Image Models

This page lists the external image models that can be used with EIR, all of which come from the great timm library.

There are 3 ways to use these models:

  • Configure and train specific architectures (e.g. ResNet with chosen number of layers) from scratch.

  • Train a specific architecture (e.g. resnet18) from scratch.

  • Use a pre-trained model (e.g. resnet18) and fine-tune it.

Please refer to this page for more detailed information about configurable architectures, and this page for a list of pre-defined architectures, with the option of using pre-trained weights.

Configurable Models

The following models can be configured and trained from scratch.

The model type is specified in the model_type field of the configuration, while the model-specific configuration is specified in the model_init_config field.

For example, the ResNet architecture includes the layers and block parameters, and can be configured as follows:

input_configurable_image_model.yaml
input_info:
  input_source: eir_tutorials/a_using_eir/05_image_tutorial/data/hot_dog_not_hot_dog/food_images
  input_name: hot_dog
  input_type: image

input_type_info:
  mixing_subtype: "cutmix"
  size:
    - 64

model_config:
  model_type: "ResNet"
  model_init_config:
    layers: [1, 1, 1, 1]
    block: "BasicBlock"

interpretation_config:
  num_samples_to_interpret: 30
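In the ResNet configuration, layers gives the number of residual blocks in each of the four stages, and block selects the block type. To get a feel for what a given setting corresponds to, the sketch below counts weighted layers the way the standard ResNet model names do. This counting rule is the usual naming convention, not something read from EIR or timm; under it, the layers: [1, 1, 1, 1] example above describes a small, ResNet-10-sized network.

```python
from typing import List

# Assumption: conventional ResNet naming counts 1 stem conv, the convs inside
# every residual block (2 per BasicBlock, 3 per Bottleneck), and 1 final fc layer.
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def resnet_depth(layers: List[int], block: str) -> int:
    """Return the nominal depth (e.g. 18 or 50) of a ResNet stage layout."""
    if block not in CONVS_PER_BLOCK:
        raise ValueError(f"unknown block type: {block!r}")
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


if __name__ == "__main__":
    print(resnet_depth([1, 1, 1, 1], "BasicBlock"))  # 10, the config above
    print(resnet_depth([2, 2, 2, 2], "BasicBlock"))  # 18, the resnet18 layout
    print(resnet_depth([3, 4, 6, 3], "Bottleneck"))  # 50, the resnet50 layout
```

Increasing any entry of layers adds blocks to that stage without changing the number of stages.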

class timm.models.beit.Beit(
img_size: int | Tuple[int, int] = 224,
patch_size: int | Tuple[int, int] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
qkv_bias: bool = True,
mlp_ratio: float = 4.0,
swiglu_mlp: bool = False,
scale_mlp: bool = False,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
init_values: float | None = None,
use_abs_pos_emb: bool = True,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = False,
head_init_scale: float = 0.001,
device=None,
dtype=None,
)

BEiT: BERT Pre-Training of Image Transformers.

Vision Transformer model with support for relative position bias and shared relative position bias across layers. Implements both BEiT v1 and v2 architectures with flexible configuration options.

fix_init_weight() → None

Fix initialization weights according to the BEiT paper.

Rescales attention and MLP weights based on layer depth to improve training stability.
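As we read the timm source, the rescaling divides the output projection of each block's attention and the second layer of its MLP by sqrt(2 * layer_id), with layer_id counted from 1. The sketch below only computes those divisors so that the depth dependence is visible; it is an illustration under that assumption, not the timm code itself.

```python
import math
from typing import List


def beit_rescale_divisors(depth: int) -> List[float]:
    """Divisor applied to the attention proj and MLP fc2 weights of each block.

    Assumption: block i (0-based) is divided by sqrt(2 * (i + 1)), so deeper
    blocks start with proportionally smaller residual-branch weights.
    """
    return [math.sqrt(2.0 * (i + 1)) for i in range(depth)]


if __name__ == "__main__":
    for i, d in enumerate(beit_rescale_divisors(4)):
        print(f"block {i}: divide by {d:.3f}")
```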

forward_features(x: Tensor) → Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor of shape (batch_size, channels, height, width).

Returns:

Feature tensor of shape (batch_size, num_tokens, embed_dim).
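For a square input, the num_tokens dimension can be worked out by hand: the patch embedding cuts a 224x224 image into (224 / 16)^2 = 196 patches, and BEiT prepends one class token. The sketch assumes a single prefix (class) token and an image size divisible by the patch size; with the 64-pixel images from the configuration example and 16x16 patches it would give 4 * 4 + 1 = 17 tokens.

```python
def vit_num_tokens(img_size: int, patch_size: int, num_prefix_tokens: int = 1) -> int:
    """Number of tokens returned by forward_features for a square input.

    Assumptions: square image, img_size divisible by patch_size, and
    num_prefix_tokens class tokens prepended (1 for BEiT).
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size
    return grid * grid + num_prefix_tokens


if __name__ == "__main__":
    # Default Beit: 224x224 input, 16x16 patches, embed_dim 768, batch of 8.
    n = vit_num_tokens(224, 16)
    print((8, n, 768))  # (8, 197, 768)
```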

forward_head(x: Tensor, pre_logits: bool = False) → Tensor

Forward pass through classification head.

Parameters:
  • x – Feature tensor of shape (batch_size, num_tokens, embed_dim).

  • pre_logits – If True, return features before final linear layer.

Returns:

Logits tensor of shape (batch_size, num_classes) or pre-logits.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass that returns intermediate feature maps.

Parameters:
  • x – Input image tensor of shape (batch_size, channels, height, width).

  • indices – Block indices to return features from. If int, returns last n blocks.

  • return_prefix_tokens – If True, return both prefix and spatial tokens.

  • norm – If True, apply normalization to intermediate features.

  • stop_early – If True, stop at last selected intermediate.

  • output_fmt – Output format (‘NCHW’ or ‘NLC’).

  • intermediates_only – If True, only return intermediate features.

Returns:

If intermediates_only is True, returns list of intermediate tensors. Otherwise, returns tuple of (final_features, intermediates).
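The two output_fmt options describe the same spatial tokens laid out differently: 'NLC' keeps a flat sequence of L = H * W tokens with C channels each, while 'NCHW' folds that sequence back into a C x H x W grid. The pure-Python sketch below shows the index mapping for a single sample (the N axis dropped), assuming prefix tokens have already been removed and tokens are in row-major order; it is illustrative only, since timm performs this with tensor reshape and permute operations.

```python
from typing import List


def nlc_to_chw(tokens: List[List[float]], height: int, width: int) -> List[List[List[float]]]:
    """Convert one sample's tokens from (L, C) to (C, H, W).

    Assumes tokens are in row-major (raster) order, so token h * width + w
    holds the features of grid cell (h, w), and that prefix tokens were removed.
    """
    if len(tokens) != height * width:
        raise ValueError("expected height * width tokens")
    channels = len(tokens[0])
    return [
        [[tokens[h * width + w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


if __name__ == "__main__":
    # 2x3 grid of tokens with 2 channels: channel 0 = token index, channel 1 = negated.
    toks = [[float(i), -float(i)] for i in range(6)]
    chw = nlc_to_chw(toks, height=2, width=3)
    print(chw[0])  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
```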

get_classifier() → Module

Get the classifier head.

Returns:

The classification head module.

group_matcher(coarse: bool = False) → Dict[str, Any]

Create parameter group matcher for optimizer parameter groups.

Parameters:

coarse – If True, use coarse grouping.

Returns:

Dictionary mapping group names to regex patterns.
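A matcher of this kind is typically consumed by code that assigns every parameter name to the first group whose pattern matches, for example to apply layer-wise learning-rate decay. The patterns below are hypothetical (loosely shaped like a ViT-style model with a stem and numbered blocks), and assign_groups is an illustrative helper, not timm's own grouping utility.

```python
import re
from typing import Dict, List

# Hypothetical matcher: group names mapped to regex patterns.
MATCHER = {
    "stem": r"^cls_token|^pos_embed|^patch_embed",
    "blocks": r"^blocks\.(\d+)",
}


def assign_groups(param_names: List[str], matcher: Dict[str, str]) -> Dict[str, str]:
    """Map each parameter name to a group label; unmatched names go to 'other'."""
    assignment = {}
    for name in param_names:
        label = "other"  # fallback for anything the matcher does not cover
        for group, pattern in matcher.items():
            m = re.match(pattern, name)
            if m:
                # Patterns with a capture group (block index) get one group per block.
                label = f"{group}.{m.group(1)}" if m.groups() and m.group(1) else group
                break
        assignment[name] = label
    return assignment


if __name__ == "__main__":
    names = ["cls_token", "patch_embed.proj.weight", "blocks.0.attn.qkv.weight", "head.weight"]
    for name, group in assign_groups(names, MATCHER).items():
        print(f"{name:28s} -> {group}")
```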

init_weights(needs_reset: bool = True) → None

Initialize model weights.

Parameters:

needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.

no_weight_decay() → Set[str]

Get parameter names that should not use weight decay.

Returns:

Set of parameter names to exclude from weight decay.
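The returned set is usually combined with a shape-based rule when building optimizer parameter groups. The sketch below assumes the common convention (followed by timm's optimizer helpers, to our knowledge) that biases and other one-dimensional parameters such as normalization weights also skip weight decay. Parameter names and shapes are hypothetical, and shape tuples stand in for real tensors so the example runs without PyTorch.

```python
from typing import Dict, List, Set, Tuple


def split_weight_decay(
    params: Dict[str, Tuple[int, ...]],
    no_weight_decay: Set[str],
) -> Tuple[List[str], List[str]]:
    """Split parameter names into (decay, no_decay) lists.

    params maps parameter names to their shapes (a stand-in for real tensors).
    A parameter skips weight decay if it is listed in no_weight_decay, is a
    bias, or has at most one dimension.
    """
    decay, no_decay = [], []
    for name, shape in params.items():
        if name in no_weight_decay or name.endswith(".bias") or len(shape) <= 1:
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay


if __name__ == "__main__":
    shapes = {
        "cls_token": (1, 1, 768),
        "pos_embed": (1, 197, 768),
        "blocks.0.mlp.fc1.weight": (3072, 768),
        "blocks.0.mlp.fc1.bias": (3072,),
        "blocks.0.norm1.weight": (768,),
    }
    decay, no_decay = split_weight_decay(shapes, {"cls_token", "pos_embed"})
    print("decay:", decay)
    print("no_decay:", no_decay)
```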

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) → List[int]

Prune layers not required for specified intermediate outputs.

Parameters:
  • indices – Indices of blocks to keep.

  • prune_norm – If True, remove final normalization.

  • prune_head – If True, remove classification head.

Returns:

List of indices that were kept.
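In terms of blocks, pruning for a set of intermediate indices amounts to keeping everything up to and including the deepest requested block and discarding the rest (and, optionally, the final norm and the head). The sketch below models that on a plain list of hypothetical block names, assuming the indices have already been resolved to explicit, non-negative block positions; returning them sorted is a choice made for this sketch, and no real modules are touched.

```python
from typing import List, Tuple


def prune_blocks(blocks: List[str], indices: List[int]) -> Tuple[List[str], List[int]]:
    """Keep blocks[0 .. max(indices)] and return (kept_blocks, kept_indices)."""
    if not indices:
        raise ValueError("indices must not be empty")
    if min(indices) < 0 or max(indices) >= len(blocks):
        raise IndexError("block index out of range")
    max_index = max(indices)
    return blocks[: max_index + 1], sorted(set(indices))


if __name__ == "__main__":
    blocks = [f"blocks.{i}" for i in range(12)]
    kept, idx = prune_blocks(blocks, [3, 7])
    print(len(kept), idx)  # 8 [3, 7]
```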

reset_classifier(num_classes: int, global_pool: str | None = None)

Reset the classification head.

Parameters:
  • num_classes – Number of classes for new head.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True)

Enable or disable gradient checkpointing.

Parameters:

enable – If True, enable gradient checkpointing.

class timm.models.byobnet.ByobNet(
cfg: ByoModelCfg,
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str | None = None,
output_stride: int = 32,
img_size: int | Tuple[int, int] | None = None,
drop_rate: float = 0.0,
drop_block_rate: float = 0.0,
drop_block_size: int = 3,
drop_path_rate: float = 0.0,
zero_init_last: bool = True,
device=None,
dtype=None,
**kwargs,
)

Bring-your-own-blocks Network.

A flexible network backbone that allows building the model stem and blocks via a dataclass cfg definition, with factory functions for module instantiation.

The current assumption is that both the stem and the blocks are in conv-bn-act order (with each block ending in an activation).

forward_features(x: Tensor) → Tensor

Forward pass through feature extraction.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(
x: Tensor,
pre_logits: bool = False,
) → Tensor

Forward pass through head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
exclude_final_conv: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

  • exclude_final_conv – Exclude final_conv from last intermediate

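Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).

get_classifier() → Module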

Get classifier module.

Returns:

Classifier module.

group_matcher(coarse: bool = False) → Dict[str, Any]

Group matcher for parameter groups.

Parameters:

coarse – Whether to use coarse grouping.

Returns:

Dictionary mapping group names to patterns.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) → List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) → None

Reset classifier.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) → None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.cait.Cait(
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
block_layers: Type[Module] = <class 'timm.models.cait.LayerScaleBlock'>,
block_layers_token: Type[Module] = <class 'timm.models.cait.LayerScaleBlockClassAttn'>,
patch_layer: Type[Module] = <class 'timm.layers.patch_embed.PatchEmbed'>,
norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
attn_block: Type[Module] = <class 'timm.models.cait.TalkingHeadAttn'>,
mlp_block: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
init_values: float = 0.0001,
attn_block_token_only: Type[Module] = <class 'timm.models.cait.ClassAttn'>,
mlp_block_token_only: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
depth_token_only: int = 2,
mlp_ratio_token_only: float = 4.0,
device=None,
dtype=None,
)

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.coat.CoaT(
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
embed_dims: Tuple[int, int, int, int] = (64, 128, 320, 512),
serial_depths: Tuple[int, int, int, int] = (3, 4, 6, 3),
parallel_depth: int = 0,
num_heads: int = 8,
mlp_ratios: Tuple[float, float, float, float] = (4, 4, 4, 4),
qkv_bias: bool = True,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
return_interm_layers: bool = False,
out_features: List[str] | None = None,
crpe_window: dict | None = None,
global_pool: str = 'token',
device=None,
dtype=None,
)

CoaT (Co-Scale Conv-Attentional Image Transformer) class.

class timm.models.convit.ConVit(
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
hybrid_backbone: Any | None = None,
norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
local_up_to_layer: int = 3,
locality_strength: float = 1.0,
use_pos_embed: bool = True,
device=None,
dtype=None,
)

Vision Transformer with support for patch or hybrid CNN input stage.

class timm.models.convmixer.ConvMixer(
dim: int,
depth: int,
kernel_size: int = 9,
patch_size: int = 7,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
drop_rate: float = 0.0,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
device=None,
dtype=None,
**kwargs,
)

class timm.models.convnext.ConvNeXt(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
output_stride: int = 32,
depths: Tuple[int, ...] = (3, 3, 9, 3),
dims: Tuple[int, ...] = (96, 192, 384, 768),
kernel_sizes: int | Tuple[int, ...] = 7,
ls_init_value: float | None = 1e-06,
stem_type: str = 'patch',
patch_size: int = 4,
head_init_scale: float = 1.0,
head_norm_first: bool = False,
head_hidden_size: int | None = None,
conv_mlp: bool = False,
conv_bias: bool = True,
use_grn: bool = False,
act_layer: str | Callable = 'gelu',
norm_layer: str | Callable | None = None,
norm_eps: float | None = None,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
)

ConvNeXt model architecture.

A PyTorch implementation of “A ConvNet for the 2020s”: https://arxiv.org/pdf/2201.03545.pdf
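With the default stem_type='patch' and patch_size=4, the stem downsamples by 4 and each of the three later stages halves the resolution again, which is where the default output_stride of 32 comes from. The sketch below tabulates the per-stage feature map shapes for the default dims (the depths only change how many blocks each stage contains, not its output shape). It assumes a square input divisible by 32 and the default, non-dilated output stride.

```python
from typing import List, Tuple


def convnext_stage_shapes(
    img_size: int = 224,
    patch_size: int = 4,
    dims: Tuple[int, ...] = (96, 192, 384, 768),
) -> List[Tuple[int, int, int]]:
    """Return (channels, height, width) at the output of each stage.

    Assumption: the first stage keeps the stem resolution and every later
    stage downsamples by 2 (total stride patch_size * 2 ** (len(dims) - 1)).
    """
    size = img_size // patch_size
    shapes = []
    for i, dim in enumerate(dims):
        if i > 0:
            size //= 2  # downsampling layer at the start of stages 2, 3 and 4
        shapes.append((dim, size, size))
    return shapes


if __name__ == "__main__":
    for stage, shape in enumerate(convnext_stage_shapes()):
        print(stage, shape)
```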

forward_features(x: Tensor) → Tensor

Forward pass through feature extraction layers.

forward_head(
x: Tensor,
pre_logits: bool = False,
) → Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier.

Returns:

Output tensor.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor.

  • indices – Take last n blocks if int, all if None, select matching indices if sequence.

  • norm – Apply norm layer to compatible intermediates.

  • stop_early – Stop iterating over blocks when last desired intermediate hit.

  • output_fmt – Shape of intermediate feature outputs.

  • intermediates_only – Only return intermediate features.

Returns:

List of intermediate features or tuple of (final features, intermediates).
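The indices rule shared by these forward_intermediates methods (the last n blocks for an int, all blocks for None, the listed blocks for a sequence) can be written out as a small helper; for a four-stage model such as ConvNeXt, the "blocks" are the stages. This is a hypothetical stdlib sketch of that rule, which additionally assumes that negative entries in a sequence count from the end as in Python indexing; timm uses its own internal helper, whose exact behavior may differ.

```python
from typing import List, Optional, Sequence, Union


def resolve_indices(num_blocks: int, indices: Optional[Union[int, Sequence[int]]]) -> List[int]:
    """Turn an `indices` argument into explicit, sorted block positions.

    - None: every block
    - int n: the last n blocks
    - sequence: the given positions (negative values count from the end)
    """
    if indices is None:
        return list(range(num_blocks))
    if isinstance(indices, int):
        if not 0 < indices <= num_blocks:
            raise ValueError("n must be between 1 and num_blocks")
        return list(range(num_blocks - indices, num_blocks))
    resolved = [i + num_blocks if i < 0 else i for i in indices]
    if any(not 0 <= i < num_blocks for i in resolved):
        raise IndexError("block index out of range")
    return sorted(resolved)


if __name__ == "__main__":
    print(resolve_indices(4, None))     # [0, 1, 2, 3]
    print(resolve_indices(4, 2))        # [2, 3]
    print(resolve_indices(4, [0, -1]))  # [0, 3]
```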

get_classifier() → Module

Get the classifier module.

group_matcher(
coarse: bool = False,
) → Dict[str, str | List]

Create regex patterns for parameter grouping.

Parameters:

coarse – Use coarse grouping.

Returns:

Dictionary mapping group names to regex patterns.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) → List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) → None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) → None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.crossvit.CrossVit(
img_size: int = 224,
img_scale: Tuple[float, ...] = (1.0, 1.0),
patch_size: Tuple[int, ...] = (8, 16),
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: Tuple[int, ...] = (192, 384),
depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
num_heads: Tuple[int, ...] = (6, 12),
mlp_ratio: Tuple[float, ...] = (2.0, 2.0, 4.0),
multi_conv: bool = False,
crop_scale: bool = False,
qkv_bias: bool = True,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
global_pool: str = 'token',
device=None,
dtype=None,
)

Vision Transformer with support for patch or hybrid CNN input stage.

class timm.models.cspnet.CspNet(
cfg: CspModelCfg,
in_chans: int = 3,
num_classes: int = 1000,
output_stride: int = 32,
global_pool: str = 'avg',
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
zero_init_last: bool = True,
device=None,
dtype=None,
**kwargs,
)

Cross Stage Partial base model.

Paper: CSPNet: A New Backbone that can Enhance Learning Capability of CNN - https://arxiv.org/abs/1911.11929

Reference implementation: https://github.com/WongKinYiu/CrossStagePartialNetworks

NOTE: There are differences in the way I handle the 1x1 ‘expansion’ conv in this implementation vs. the darknet implementation. I did it this way for simplicity and fewer special cases.

class timm.models.davit.DaVit(
in_chans: int = 3,
depths: Tuple[int, ...] = (1, 1, 3, 1),
embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_size: int = 7,
mlp_ratio: float = 4,
qkv_bias: bool = True,
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
norm_eps: float = 1e-05,
attn_types: Tuple[str, ...] = ('spatial', 'channel'),
ffn: bool = True,
cpe_act: bool = False,
down_kernel_size: int = 2,
channel_attn_v2: bool = False,
named_blocks: bool = False,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
num_classes: int = 1000,
global_pool: str = 'avg',
head_norm_first: bool = False,
device=None,
dtype=None,
)

DaViT

A PyTorch implementation of DaViT: Dual Attention Vision Transformers - https://arxiv.org/abs/2204.03645

Supports arbitrary input sizes and pyramid feature extraction.

Parameters:
  • in_chans (int) – Number of input image channels. Default: 3

  • num_classes (int) – Number of classes for classification head. Default: 1000

  • depths (tuple(int)) – Number of blocks in each stage. Default: (1, 1, 3, 1)

  • embed_dims (tuple(int)) – Patch embedding dimension. Default: (96, 192, 384, 768)

  • num_heads (tuple(int)) – Number of attention heads in different layers. Default: (3, 6, 12, 24)

  • window_size (int) – Window size. Default: 7

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim. Default: 4

  • qkv_bias (bool) – If True, add a learnable bias to query, key, value. Default: True

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.0

  • norm_layer (str) – Normalization layer. Default: 'layernorm2d'
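The stochastic depth rate is normally not applied uniformly. As far as we know, timm models (DaViT included) spread it linearly over all blocks, from 0 for the first block up to drop_path_rate for the last, and then split that list by stage according to depths. A stdlib sketch of that schedule under this assumption, using an illustrative drop_path_rate of 0.1 rather than the 0.0 default:

```python
from typing import List, Tuple


def drop_path_schedule(depths: Tuple[int, ...], drop_path_rate: float) -> List[List[float]]:
    """Per-stage lists of per-block drop-path rates, increasing linearly with depth."""
    total = sum(depths)
    step = drop_path_rate / (total - 1) if total > 1 else 0.0
    rates = [i * step for i in range(total)]  # 0.0 ... drop_path_rate
    schedule, start = [], 0
    for d in depths:
        schedule.append(rates[start:start + d])  # slice out this stage's blocks
        start += d
    return schedule


if __name__ == "__main__":
    for stage, rates in enumerate(drop_path_schedule((1, 1, 3, 1), 0.1)):
        print(stage, [round(r, 3) for r in rates])
```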

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

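Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).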

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.deit.VisionTransformerDistilled(*args, **kwargs)

Vision Transformer with Distillation Token and Head

Distillation token & head support for DeiT: Data-efficient Image Transformers.

forward_head(
x,
pre_logits: bool = False,
) → Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier.

Returns:

Output tensor.

get_classifier() → Module

Get the classifier head.

group_matcher(coarse=False)

Create regex patterns for parameter grouping.

Parameters:

coarse – Use coarse grouping.

Returns:

Dictionary mapping group names to regex patterns.

init_weights(mode='', needs_reset=True)

Initialize model weights.

Parameters:
  • mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).

  • needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
)

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

class timm.models.densenet.DenseNet(
growth_rate: int = 32,
block_config: Tuple[int, ...] = (6, 12, 24, 16),
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str = 'avg',
bn_size: int = 4,
stem_type: str = '',
act_layer: str = 'relu',
norm_layer: str = 'batchnorm2d',
aa_layer: Type[Module] | None = None,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
memory_efficient: bool = False,
aa_stem_only: bool = True,
device=None,
dtype=None,
)

DenseNet-BC model class.

Based on “Densely Connected Convolutional Networks”.

Parameters:
  • growth_rate – How many filters to add each layer (k in paper).

  • block_config – How many layers in each pooling block.

  • bn_size – Multiplicative factor for the width of the bottleneck layers (i.e. bn_size * k features in the bottleneck layer).

  • drop_rate – Dropout rate before classifier layer.

  • proj_drop_rate – Dropout rate after each dense layer.

  • num_classes – Number of classification classes.

  • memory_efficient – If True, uses checkpointing, which is much more memory efficient but slower. Default: False. See the paper “Memory-Efficient Implementation of DenseNets”.
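The growth_rate, block_config and bn_size parameters determine the channel counts directly: every dense layer concatenates growth_rate new channels onto its input, its 1x1 bottleneck conv has bn_size * growth_rate channels, and each transition between dense blocks halves the channel count. The sketch below traces this for the defaults (the DenseNet-121 layout), assuming a stem that outputs 2 * growth_rate channels as in the original DenseNet-BC design.

```python
from typing import List, Tuple


def densenet_channels(
    growth_rate: int = 32,
    block_config: Tuple[int, ...] = (6, 12, 24, 16),
    bn_size: int = 4,
) -> Tuple[List[int], int, int]:
    """Return (channels after each dense block, final feature count, bottleneck width)."""
    channels = 2 * growth_rate  # assumed stem output width
    after_blocks = []
    for i, num_layers in enumerate(block_config):
        channels += num_layers * growth_rate  # each dense layer adds growth_rate channels
        after_blocks.append(channels)
        if i != len(block_config) - 1:
            channels //= 2  # transition layer halves the channels (compression 0.5)
    return after_blocks, channels, bn_size * growth_rate


if __name__ == "__main__":
    print(densenet_channels())  # ([256, 512, 1024, 1024], 1024, 128)
```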

forward_features(x: Tensor) → Tensor

Forward pass through feature extraction layers.

forward_head(
x: Tensor,
pre_logits: bool = False,
) → Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier.

Returns:

Output tensor.

get_classifier() → Module

Get the classifier head.

group_matcher(coarse: bool = False) → Dict[str, Any]

Group parameters for optimization.

reset_classifier(num_classes: int, global_pool: str = 'avg') → None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) → None

Enable or disable gradient checkpointing.

class timm.models.dla.DLA(
levels: Tuple[int, ...],
channels: Tuple[int, ...],
output_stride: int = 32,
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str = 'avg',
cardinality: int = 1,
base_width: int = 64,
block: Type[Module] = <class 'timm.models.dla.DlaBottle2neck'>,
shortcut_root: bool = False,
drop_rate: float = 0.0,
device=None,
dtype=None,
)

class timm.models.dpn.DPN(
k_sec: Tuple[int, ...] = (3, 4, 20, 3),
inc_sec: Tuple[int, ...] = (16, 32, 24, 128),
k_r: int = 96,
groups: int = 32,
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
global_pool: str = 'avg',
small: bool = False,
num_init_features: int = 64,
b: bool = False,
drop_rate: float = 0.0,
norm_layer: str = 'batchnorm2d',
act_layer: str = 'relu',
fc_act_layer: str = 'elu',
device=None,
dtype=None,
)

class timm.models.edgenext.EdgeNeXt(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
dims: Tuple[int, ...] = (24, 48, 88, 168),
depths: Tuple[int, ...] = (3, 3, 9, 3),
global_block_counts: Tuple[int, ...] = (0, 1, 1, 1),
kernel_sizes: Tuple[int, ...] = (3, 5, 7, 9),
heads: Tuple[int, ...] = (8, 8, 8, 8),
d2_scales: Tuple[int, ...] = (2, 2, 3, 4),
use_pos_emb: Tuple[bool, ...] = (False, True, False, False),
ls_init_value: float = 1e-06,
head_init_scale: float = 1.0,
expand_ratio: float = 4,
downsample_block: bool = False,
conv_bias: bool = True,
stem_type: str = 'patch',
head_norm_first: bool = False,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
drop_path_rate: float = 0.0,
drop_rate: float = 0.0,
device=None,
dtype=None,
)

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

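Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).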

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.efficientformer.EfficientFormer(
depths: Tuple[int, ...] = (3, 2, 6, 4),
embed_dims: Tuple[int, ...] = (48, 96, 224, 448),
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
downsamples: Tuple[bool, ...] | None = None,
num_vit: int = 0,
mlp_ratios: float = 4,
pool_size: int = 3,
layer_scale_init_value: float = 1e-05,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
norm_layer_cl: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
**kwargs,
)

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

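Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).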

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.efficientformer_v2.EfficientFormerV2(
depths: Tuple[int, ...],
in_chans: int = 3,
img_size: int | Tuple[int, int] = 224,
global_pool: str = 'avg',
embed_dims: Tuple[int, ...] | None = None,
downsamples: Tuple[bool, ...] | None = None,
mlp_ratios: float | Tuple[float, ...] | Tuple[Tuple[float, ...], ...] = 4,
norm_layer: str = 'batchnorm2d',
norm_eps: float = 1e-05,
act_layer: str = 'gelu',
num_classes: int = 1000,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float | None = 1e-05,
num_vit: int = 0,
distillation: bool = True,
device=None,
dtype=None,
)

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

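Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).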

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.efficientnet.EfficientNet(
block_args: List[List[Dict[str, Any]]],
num_classes: int = 1000,
num_features: int = 1280,
in_chans: int = 3,
stem_size: int = 32,
stem_kernel_size: int = 3,
fix_stem: bool = False,
output_stride: int = 32,
pad_type: str = '',
act_layer: str | Callable | Type[Module] | None = None,
norm_layer: str | Callable | Type[Module] | None = None,
aa_layer: str | Callable | Type[Module] | None = None,
se_layer: str | Callable | Type[Module] | None = None,
round_chs_fn: Callable = <function round_channels>,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
global_pool: str = 'avg',
device=None,
dtype=None,
)

EfficientNet model architecture.

A flexible and performant PyTorch implementation of efficient network architectures, including:
  • EfficientNet-V2 Small, Medium, Large, XL & B0-B3

  • EfficientNet B0-B8, L2

  • EfficientNet-EdgeTPU

  • EfficientNet-CondConv

  • MixNet S, M, L, XL

  • MnasNet A1, B1, and small

  • MobileNet-V2

  • FBNet C

  • Single-Path NAS Pixel1

  • TinyNet
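The round_chs_fn parameter controls how channel counts are rounded when a width multiplier is applied, as in the scaled B0-B8 variants. The sketch below shows the make-divisible rule that, to our understanding, timm's default round_channels follows: scale the count, round to the nearest multiple of 8, and bump up by one multiple if rounding lost more than 10%. Treat the exact constants (divisor 8, limit 0.9) as assumptions.

```python
def make_divisible(value: float, divisor: int = 8, round_limit: float = 0.9) -> int:
    """Round value to the nearest multiple of divisor, never losing more than 10%."""
    new_value = max(divisor, int(value + divisor / 2) // divisor * divisor)
    if new_value < round_limit * value:
        new_value += divisor  # rounding down went too far; go up one multiple
    return new_value


def round_channels(channels: int, multiplier: float = 1.0, divisor: int = 8) -> int:
    """Scale a channel count by a width multiplier and make it divisible by divisor."""
    if not multiplier:
        return channels
    return make_divisible(channels * multiplier, divisor)


if __name__ == "__main__":
    print(round_channels(1280, 1.1))  # 1408
    print(round_channels(1280, 1.2))  # 1536
    print(round_channels(32, 1.1))    # 32
```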

as_sequential() → Sequential

Convert model to sequential for feature extraction.

forward_features(x: Tensor) → Tensor

Forward pass through feature extraction layers.

forward_head(
x: Tensor,
pre_logits: bool = False,
) → Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier.

Returns:

Output tensor.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
extra_blocks: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor.

  • indices – Take last n blocks if int, all if None, select matching indices if sequence.

  • norm – Apply norm layer to compatible intermediates.

  • stop_early – Stop iterating over blocks when last desired intermediate hit.

  • output_fmt – Shape of intermediate feature outputs.

  • intermediates_only – Only return intermediate features.

  • extra_blocks – Include outputs of all blocks and the head conv in the output; this does not align with feature_info.

Returns:

List of intermediate features or tuple of (final features, intermediates).

get_classifier() → Module

Get the classifier module.

group_matcher(
coarse: bool = False,
) → Dict[str, str | List]

Create regex patterns for parameter groups.

Parameters:

coarse – Use coarse (stage-level) grouping.

Returns:

Dictionary mapping group names to regex patterns.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,
) → List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layers.

  • prune_head – Whether to prune the classifier head.

  • extra_blocks – Include all blocks in indexing.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str = 'avg',
) → None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) → None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.efficientvit_mit.EfficientVit(
in_chans: int = 3,
widths: Tuple[int, ...] = (),
depths: Tuple[int, ...] = (),
head_dim: int = 32,
expand_ratio: float = 4,
norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.Hardswish'>,
global_pool: str = 'avg',
head_widths: Tuple[int, ...] = (),
drop_rate: float = 0.0,
num_classes: int = 1000,
device=None,
dtype=None,
)

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

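Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).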

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.
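The indices, stop_early and intermediates_only arguments of forward_intermediates above can be illustrated with a framework-free sketch. Here each "block" is a plain Python function applied to a number. This is a hypothetical illustration of the documented behaviour, not timm code. A real model works on tensors, may apply the final norm (norm=True), and reshapes outputs according to output_fmt.

```python
from typing import Callable, List, Optional, Sequence, Tuple, Union

Block = Callable[[float], float]


def forward_intermediates_sketch(
    blocks: List[Block],
    x: float,
    indices: Optional[Union[int, Sequence[int]]] = None,
    stop_early: bool = False,
    intermediates_only: bool = False,
) -> Union[List[float], Tuple[float, List[float]]]:
    n = len(blocks)
    if indices is None:
        take = list(range(n))  # None -> every block
    elif isinstance(indices, int):
        take = list(range(n - indices, n))  # int -> the last n blocks
    else:
        take = [i + n if i < 0 else i for i in indices]  # sequence -> matching indices
    # stop_early: do not run blocks past the deepest requested intermediate
    last = max(take) if stop_early else n - 1
    intermediates: List[float] = []
    for i, block in enumerate(blocks[: last + 1]):
        x = block(x)
        if i in take:
            intermediates.append(x)
    if intermediates_only:
        return intermediates
    return x, intermediates  # output of the last block run, plus collected intermediates


toy_blocks = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3, lambda v: v * 10]
print(forward_intermediates_sketch(toy_blocks, 1.0, indices=2))
```

In this sketch, combining stop_early=True with intermediates_only=False returns the output of the last block actually run. For that reason stop_early is mainly useful together with intermediates_only=True.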

class timm.models.efficientvit_msra.EfficientVitMsra(
img_size: int = 224,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: Tuple[int, ...] = (64, 128, 192),
key_dim: Tuple[int, ...] = (16, 16, 16),
depth: Tuple[int, ...] = (1, 2, 3),
num_heads: Tuple[int, ...] = (4, 4, 4),
window_size: Tuple[int, ...] = (7, 7, 7),
kernels: Tuple[int, ...] = (5, 5, 5, 5),
down_ops: Tuple[Tuple[str, int], ...] = (('', 1), ('subsample', 2), ('subsample', 2)),
global_pool: str = 'avg',
drop_rate: float = 0.0,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

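Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).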

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.eva.Eva(
img_size: int | Tuple[int, int] = 224,
patch_size: int | Tuple[int, int] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
qkv_bias: bool = True,
qkv_fused: bool = True,
mlp_ratio: float = 4.0,
swiglu_mlp: bool = False,
swiglu_align_to: int = 0,
scale_mlp: bool = False,
scale_attn_inner: bool = False,
attn_type: str = 'eva',
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>,
init_values: float | None = None,
class_token: bool = True,
num_reg_tokens: int = 0,
no_embed_class: bool = False,
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
rope_type: str | None = 'cat',
rope_grid_offset: float = 0.0,
rope_grid_indexing: str = 'ij',
rope_temperature: float = 10000.0,
rope_rotate_half: bool = False,
use_post_norm: bool = False,
use_pre_transformer_norm: bool = False,
use_post_transformer_norm: bool | None = None,
use_fc_norm: bool | None = None,
attn_pool_num_heads: int | None = None,
attn_pool_mlp_ratio: float | None = None,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
ref_feat_shape: int | Tuple[int, int] | None = None,
head_init_scale: float = 0.001,
device=None,
dtype=None,
)

Eva Vision Transformer with absolute and rotary position embeddings.

This class implements the EVA and EVA02 models, which are based on the BEiT ViT variant:
  • EVA - absolute position embeddings, global average pooling

  • EVA02 - absolute + rotary (RoPE) position embeddings, global average pooling, SwiGLU, scale norm in the MLP (à la NormFormer)

fix_init_weight() → None

Fix initialization weights by rescaling based on layer depth.

forward_features(
x: Tensor,
attn_mask: Tensor | None = None,
is_causal: bool = False,
) → Tensor

Forward pass through feature extraction layers.

Parameters:
  • x – Input tensor.

  • attn_mask – Optional attention mask for masked attention

  • is_causal – If True, use causal (autoregressive) masking in attention.

Returns:

Feature tensor.
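With is_causal=True, token i may only attend to tokens j <= i. The sketch below builds that pattern as a boolean matrix and applies it to one row of attention scores. It uses the convention of PyTorch's scaled_dot_product_attention, where True marks positions that take part in attention. It illustrates causal masking in general; whether Eva constructs or passes its mask in exactly this form is an assumption. Both helper names are hypothetical.

```python
import math
from typing import List


def causal_mask(num_tokens: int) -> List[List[bool]]:
    """Row i (query) may attend to column j (key) only when j <= i; True = allowed."""
    return [[j <= i for j in range(num_tokens)] for i in range(num_tokens)]


def masked_softmax(scores: List[float], allowed: List[bool]) -> List[float]:
    """Softmax restricted to allowed positions; disallowed positions get weight 0."""
    m = max(s for s, ok in zip(scores, allowed) if ok)  # subtract max for numerical stability
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(3)
print(mask)
print(masked_softmax([0.0, 0.0, 5.0], mask[1]))  # query 1 cannot see key 2 despite its high score
```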

forward_head(x: Tensor, pre_logits: bool = False) → Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return pre-logits if True.

Returns:

Output tensor.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
attn_mask: Tensor | None = None,
is_causal: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take the last n blocks if an int, all if None; if a sequence, select blocks by matching indices

  • return_prefix_tokens – Return both prefix and spatial intermediate tokens

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

  • attn_mask – Optional attention mask for masked attention

  • is_causal – If True, use causal (autoregressive) masking in attention

Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).

group_matcher(coarse: bool = False) → Dict[str, Any]

Create layer groupings for optimization.

no_weight_decay() → Set[str]

Parameters to exclude from weight decay.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

reset_classifier(num_classes: int, global_pool: str | None = None) → None

Reset the classifier head.

Parameters:
  • num_classes – Number of output classes.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) → None

Enable or disable gradient checkpointing.

set_input_size(
img_size: Tuple[int, int] | None = None,
patch_size: Tuple[int, int] | None = None,
) → None

Update the input image resolution and patch size.

Parameters:
  • img_size – New input resolution; if None, the current resolution is used.

  • patch_size – New patch size; if None, the existing patch size is used.
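For a plain patch embedding, the number of tokens follows directly from img_size and patch_size. This is why changing either one means the absolute position embeddings must cover a different grid. The arithmetic sketch below assumes non-overlapping patches (stride equal to patch size) plus the class_token and num_reg_tokens prefix tokens from the constructor. The helper names are hypothetical, and dynamic padding (dynamic_img_pad) is not modelled.

```python
from typing import Tuple


def patch_grid(img_size: Tuple[int, int], patch_size: Tuple[int, int]) -> Tuple[int, int]:
    """(rows, cols) of non-overlapping patches; this sketch requires exact divisibility."""
    (h, w), (ph, pw) = img_size, patch_size
    if h % ph or w % pw:
        raise ValueError(f"{img_size} is not divisible by patch size {patch_size}")
    return h // ph, w // pw


def num_tokens(
    img_size: Tuple[int, int],
    patch_size: Tuple[int, int],
    class_token: bool = True,
    num_reg_tokens: int = 0,
) -> int:
    """Patch tokens plus optional class / register (prefix) tokens."""
    gh, gw = patch_grid(img_size, patch_size)
    return gh * gw + int(class_token) + num_reg_tokens


print(patch_grid((224, 224), (16, 16)), num_tokens((224, 224), (16, 16)))  # (14, 14) 197
print(patch_grid((448, 448), (16, 16)), num_tokens((448, 448), (16, 16)))  # (28, 28) 785
```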

class timm.models.fasternet.FasterNet(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: int | Tuple[int, ...] = (1, 2, 8, 2),
mlp_ratio: float = 2.0,
n_div: int = 4,
patch_size: int | Tuple[int, int] = 4,
merge_size: int | Tuple[int, int] = 2,
patch_norm: bool = True,
feature_dim: int = 1280,
drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
layer_scale_init_value: float = 0.0,
act_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.activation.ReLU'>, inplace=True),
norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
pconv_fw_type: str = 'split_cat',
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

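Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).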

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.fastvit.FastVit(
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ('repmixer', 'repmixer', 'repmixer', 'repmixer'),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4, 4, 4, 4),
downsamples: Tuple[bool, ...] = (False, True, True, True),
se_downsamples: Tuple[bool, ...] = (False, False, False, False),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Module | None, ...] = (None, None, None, None),
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-05,
lkc_use_act: bool = False,
stem_use_scale_branch: bool = True,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
inference_mode: bool = False,
device=None,
dtype=None,
)
fork_feat: Final[bool]

This class implements the FastViT architecture.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

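Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).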

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.focalnet.FocalNet(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
mlp_ratio: float = 4.0,
focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
use_overlap_down: bool = False,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
head_hidden_size: int | None = None,
head_init_scale: float = 1.0,
layerscale_value: float | None = None,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
norm_layer: Type[Module] = functools.partial(<class 'timm.layers.norm.LayerNorm2d'>, eps=1e-05),
device=None,
dtype=None,
)

Focal Modulation Networks (FocalNets)

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

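Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).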

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.gcvit.GlobalContextVit(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
img_size: int | Tuple[int, int] = 224,
window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
window_size: int | Tuple[int, ...] | None = None,
embed_dim: int = 64,
depths: Tuple[int, ...] = (3, 4, 19, 5),
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
mlp_ratio: float = 3.0,
qkv_bias: bool = True,
layer_scale: float | None = None,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
weight_init: str = '',
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
norm_eps: float = 1e-05,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

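Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).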

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.ghostnet.GhostNet(
cfgs: List[List[List[int | float]]],
num_classes: int = 1000,
width: float = 1.0,
in_chans: int = 3,
output_stride: int = 32,
global_pool: str = 'avg',
drop_rate: float = 0.2,
version: str = 'v1',
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

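Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).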

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.hgnet.HighPerfGpuNet(
cfg: Dict,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
head_hidden_size: int | None = 2048,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_lab: bool = False,
device=None,
dtype=None,
**kwargs,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

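Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).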

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.hiera.Hiera(
img_size: Tuple[int, ...] = (224, 224),
in_chans: int = 3,
embed_dim: int = 96,
num_heads: int = 1,
num_classes: int = 1000,
global_pool: str = 'avg',
stages: Tuple[int, ...] = (2, 3, 16, 3),
q_pool: int = 3,
q_stride: Tuple[int, ...] = (2, 2),
mask_unit_size: Tuple[int, ...] = (8, 8),
mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
use_expand_proj: bool = True,
dim_mul: float = 2.0,
head_mul: float = 2.0,
patch_kernel: Tuple[int, ...] = (7, 7),
patch_stride: Tuple[int, ...] = (4, 4),
patch_padding: Tuple[int, ...] = (3, 3),
mlp_ratio: float = 4.0,
drop_path_rate: float = 0.0,
init_values: float | None = None,
fix_init: bool = True,
weight_init: str = '',
norm_layer: str | Type[Module] = 'LayerNorm',
drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
head_init_scale: float = 0.001,
sep_pos_embed: bool = False,
abs_win_pos_embed: bool = False,
global_pos_size: Tuple[int, int] = (14, 14),
device=None,
dtype=None,
)
forward_features(
x: Tensor,
mask: Tensor | None = None,
return_intermediates: bool = False,
) → Tensor

mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx], where #MUt, #MUy and #MUx are the numbers of mask units along each dimension. Note: 1 in the mask means keep and 0 means remove; mask.sum(dim=-1) should be the same for every sample in the batch.
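With Hiera's defaults, the expected mask width can be worked out from the constructor arguments. patch_stride (4, 4) turns a 224x224 image into a 56x56 token grid, and mask_unit_size (8, 8) groups that into 7x7 = 49 mask units. For 2D images only the y and x terms apply. The sketch below does that arithmetic and checks a candidate mask against the rules above. It assumes the token grid is img_size // patch_stride, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def mask_unit_grid(
    img_size: Sequence[int], patch_stride: Sequence[int], mask_unit_size: Sequence[int]
) -> Tuple[int, ...]:
    """Mask units per dimension: (img // patch_stride) // mask_unit_size."""
    tokens = [i // s for i, s in zip(img_size, patch_stride)]
    return tuple(t // m for t, m in zip(tokens, mask_unit_size))


def check_mask(mask: List[List[bool]], num_units: int) -> None:
    """Raise ValueError unless mask is [B, num_units] with the same keep-count in every row."""
    if any(len(row) != num_units for row in mask):
        raise ValueError(f"every mask row must have {num_units} entries")
    kept = {sum(row) for row in mask}  # True counts as 1, so this mirrors mask.sum(dim=-1)
    if len(kept) > 1:
        raise ValueError(f"mask.sum(dim=-1) differs across the batch: {sorted(kept)}")


grid = mask_unit_grid((224, 224), (4, 4), (8, 8))  # Hiera defaults from the signature above
num_units = math.prod(grid)
print(grid, num_units)  # (7, 7) 49
```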

forward_intermediates(
x: Tensor,
mask: Tensor | None = None,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = True,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
coarse: bool = True,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

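Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).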

get_random_mask(x: Tensor, mask_ratio: float) → Tensor

Generate a random mask in which a mask_ratio fraction of the mask units is dropped (1 is keep, 0 is remove). Useful for MAE, FLIP, etc.
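The following is a framework-free sketch of the same idea. It assumes every sample keeps int(num_units * (1 - mask_ratio)) units, so that mask.sum(dim=-1) is equal across the batch as forward_features requires. Here random.sample stands in for the tensor shuffling a real implementation would use, and random_keep_mask is a hypothetical name.

```python
import random
from typing import List, Optional


def random_keep_mask(
    batch_size: int, num_units: int, mask_ratio: float, seed: Optional[int] = None
) -> List[List[bool]]:
    """Per sample, keep the same number of randomly chosen mask units (True = keep)."""
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)
    len_keep = int(num_units * (1 - mask_ratio))  # identical for every sample
    mask = []
    for _ in range(batch_size):
        keep = set(rng.sample(range(num_units), len_keep))
        mask.append([u in keep for u in range(num_units)])  # False = remove
    return mask


m = random_keep_mask(4, 49, 0.6, seed=0)
print([sum(row) for row in m])  # each row keeps the same count
```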

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
coarse: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.hrnet.HighResolutionNet(
cfg: Dict,
in_chans: int = 3,
num_classes: int = 1000,
output_stride: int = 32,
global_pool: str = 'avg',
drop_rate: float = 0.0,
head: str = 'classification',
device=None,
dtype=None,
**kwargs,
)
class timm.models.inception_next.MetaNeXt

A PyTorch implementation of InceptionNeXt: When Inception Meets ConvNeXt - https://arxiv.org/abs/2303.16900

Parameters:
  • in_chans (int) – Number of input image channels. Default: 3

  • num_classes (int) – Number of classes for classification head. Default: 1000

  • depths (tuple(int)) – Number of blocks at each stage. Default: (3, 3, 9, 3)

  • dims (tuple(int)) – Feature dimension at each stage. Default: (96, 192, 384, 768)

  • token_mixers – Token mixer function. Default: nn.Identity

  • norm_layer – Normalization layer. Default: nn.BatchNorm2d

  • act_layer – Activation function for MLP. Default: nn.GELU

  • mlp_ratios (int or tuple(int)) – MLP ratios. Default: (4, 4, 4, 3)

  • drop_rate (float) – Head dropout rate

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.

  • ls_init_value (float) – Init value for Layer Scale. Default: 1e-6.
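Parameters documented as "int or tuple(int)", such as mlp_ratios, carry one value per stage. A single int is usually broadcast to every stage. The minimal sketch below shows that expansion for the default four stages (depths=(3, 3, 9, 3)). per_stage is a hypothetical helper, not part of the documented API.

```python
from typing import Sequence, Tuple, TypeVar, Union

T = TypeVar("T")


def per_stage(value: Union[T, Sequence[T]], num_stages: int) -> Tuple[T, ...]:
    """Expand a scalar to one value per stage, or validate an explicit per-stage sequence."""
    if isinstance(value, (list, tuple)):
        if len(value) != num_stages:
            raise ValueError(f"expected {num_stages} values, got {len(value)}")
        return tuple(value)
    return (value,) * num_stages


depths = (3, 3, 9, 3)  # MetaNeXt default
print(per_stage(4, len(depths)))             # (4, 4, 4, 4)
print(per_stage((4, 4, 4, 3), len(depths)))  # (4, 4, 4, 3)
```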

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

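Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).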

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.inception_resnet_v2.InceptionResnetV2(
num_classes: int = 1000,
in_chans: int = 3,
drop_rate: float = 0.0,
output_stride: int = 32,
global_pool: str = 'avg',
norm_layer: str = 'batchnorm2d',
norm_eps: float = 0.001,
act_layer: str = 'relu',
device=None,
dtype=None,
)
class timm.models.inception_v3.InceptionV3(
num_classes: int = 1000,
in_chans: int = 3,
drop_rate: float = 0.0,
global_pool: str = 'avg',
aux_logits: bool = False,
norm_layer: str = 'batchnorm2d',
norm_eps: float = 0.001,
act_layer: str = 'relu',
device=None,
dtype=None,
)

Inception-V3

class timm.models.inception_v4.InceptionV4(
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
drop_rate: float = 0.0,
global_pool: str = 'avg',
norm_layer: str = 'batchnorm2d',
norm_eps: float = 0.001,
act_layer: str = 'relu',
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

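Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).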

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.levit.Levit(
img_size: int | Tuple[int, int] = 224,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: Tuple[int, ...] = (192,),
key_dim: int = 64,
depth: Tuple[int, ...] = (12,),
num_heads: int | Tuple[int, ...] = (3,),
attn_ratio: float | Tuple[float, ...] = 2.0,
mlp_ratio: float | Tuple[float, ...] = 2.0,
stem_backbone: Module | None = None,
stem_stride: int | None = None,
stem_type: str = 's16',
down_op: str = 'subsample',
act_layer: str = 'hard_swish',
attn_act_layer: str | None = None,
use_conv: bool = False,
global_pool: str = 'avg',
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
)

Vision Transformer with support for a patch or hybrid CNN input stage.

NOTE: distillation defaults to True because the pretrained weights use it; this will cause problems with training scripts that don’t accept tuple outputs.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

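Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).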

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.maxxvit.MaxxVitCfg(
embed_dim: Tuple[int, ...] = (96, 192, 384, 768),
depths: Tuple[int, ...] = (2, 3, 5, 2),
block_type: Tuple[str | Tuple[str, ...], ...] = ('C', 'C', 'T', 'T'),
stem_width: int | Tuple[int, int] = 64,
stem_bias: bool = False,
conv_cfg: MaxxVitConvCfg = <factory>,
transformer_cfg: MaxxVitTransformerCfg = <factory>,
head_hidden_size: int | None = None,
weight_init: str = 'vit_eff',
)

Configuration for MaxxVit models.

class timm.models.metaformer.MetaFormer

A PyTorch implementation of MetaFormer Baselines for Vision - https://arxiv.org/abs/2210.13452

Parameters:
  • in_chans (int) – Number of input image channels.

  • num_classes (int) – Number of classes for classification head.

  • global_pool – Pooling for classifier head.

  • depths (list or tuple) – Number of blocks at each stage.

  • dims (list or tuple) – Feature dimension at each stage.

  • token_mixers (list, tuple or token_fcn) – Token mixer for each stage.

  • mlp_act – Activation layer for MLP.

  • mlp_bias (boolean) – Enable or disable mlp bias term.

  • drop_path_rate (float) – Stochastic depth rate.

  • drop_rate (float) – Dropout rate.

  • layer_scale_init_values (list, tuple, float or None) – Init value for Layer Scale. None disables layer scale. From: https://arxiv.org/abs/2103.17239.

  • res_scale_init_values (list, tuple, float or None) – Init value for residual scaling on residual connections. None disables residual scaling. From: https://arxiv.org/abs/2110.09456.

  • downsample_norm (nn.Module) – Norm layer used in stem and downsampling layers.

  • norm_layers (list, tuple or norm_fcn) – Norm layers for each stage.

  • output_norm – Norm layer before classifier head.

  • use_mlp_head – Use MLP classification head.
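layer_scale_init_values and res_scale_init_values scale the two terms of each residual update. The minimal sketch below assumes the update has the form y = res_scale * x + layer_scale * f(x). Here f(x) is the token-mixer or MLP branch output, and None means no scaling (a factor of 1). In the model these scales are per-channel learnable parameters that start at the given init value. In the sketch, plain floats and lists stand in for them, and scaled_residual is a hypothetical helper.

```python
from typing import List, Optional


def scaled_residual(
    x: List[float],
    fx: List[float],
    layer_scale: Optional[float] = None,
    res_scale: Optional[float] = None,
) -> List[float]:
    """y = res_scale * x + layer_scale * f(x), per channel; None means 'no scaling' (factor 1)."""
    ls = 1.0 if layer_scale is None else layer_scale
    rs = 1.0 if res_scale is None else res_scale
    return [rs * xi + ls * fi for xi, fi in zip(x, fx)]


x, fx = [1.0, 2.0], [10.0, 20.0]  # shortcut input and branch output for two channels
print(scaled_residual(x, fx))                    # plain residual: [11.0, 22.0]
print(scaled_residual(x, fx, layer_scale=1e-6))  # branch nearly switched off at init
print(scaled_residual(x, fx, res_scale=0.5))     # [10.5, 21.0]
```

A small layer-scale init such as 1e-6 makes each block start close to an identity mapping, which is the usual motivation for the option.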

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

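Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).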

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.mobilenetv3.MobileNetV3(
block_args: List[List[Dict[str, Any]]],
num_classes: int = 1000,
in_chans: int = 3,
stem_size: int = 16,
fix_stem: bool = False,
num_features: int = 1280,
head_bias: bool = True,
head_norm: bool = False,
pad_type: str = '',
act_layer: str | Callable | Type[Module] | None = None,
norm_layer: str | Callable | Type[Module] | None = None,
aa_layer: str | Callable | Type[Module] | None = None,
se_layer: str | Callable | Type[Module] | None = None,
se_from_exp: bool = True,
round_chs_fn: Callable = <function round_channels>,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float | None = None,
global_pool: str = 'avg',
device=None,
dtype=None,
)

MobileNetV3.

Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific ‘efficient head’, where global pooling is done before the head convolution without a final batch-norm layer before the classifier.

Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
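Pooling before the head convolution means the head conv runs on a single spatial position instead of the whole final feature map. Its cost therefore shrinks by a factor of H x W. The back-of-the-envelope sketch below assumes a MobileNetV3-Large-like head: a 7x7 final feature map, 960 input channels and num_features=1280. It ignores bias and activation cost, and conv1x1_macs is a hypothetical helper. The two orderings are not numerically identical, because a non-linearity follows the head conv.

```python
def conv1x1_macs(height: int, width: int, in_ch: int, out_ch: int) -> int:
    """Multiply-accumulates of a 1x1 convolution over an H x W feature map (bias ignored)."""
    return height * width * in_ch * out_ch


# Assumed MobileNetV3-Large-like head: 7x7 final map, 960 channels in, num_features=1280 out.
conv_then_pool = conv1x1_macs(7, 7, 960, 1280)  # head conv over the full map
pool_then_conv = conv1x1_macs(1, 1, 960, 1280)  # efficient head: pool first, conv on 1x1
print(conv_then_pool, pool_then_conv, conv_then_pool // pool_then_conv)  # 60211200 1228800 49
```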

Other architectures utilizing MobileNet-V3 efficient head that are supported by this impl include:

as_sequential() → Sequential

Convert model to sequential form.

Returns:

Sequential module containing all layers.

forward_features(x: Tensor) → Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(
x: Tensor,
pre_logits: bool = False,
) → Tensor

Forward pass through classifier head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
extra_blocks: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

  • extra_blocks – Include outputs of all blocks and head conv in output, does not align with feature_info

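Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).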

get_classifier() → Module

Get the classifier head.

group_matcher(
coarse: bool = False,
) → Dict[str, Any]

Group parameters for optimization.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,
) → List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

  • extra_blocks – Include outputs of all blocks.

Returns:

List of indices that were kept.

reset_classifier(num_classes: int, global_pool: str = 'avg') → None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) → None

Enable or disable gradient checkpointing.

class timm.models.mobilenetv5.MobileNetV5(
block_args: List[List[Dict[str, Any]]],
num_classes: int = 1000,
in_chans: int = 3,
stem_size: int = 16,
stem_bias: bool = True,
fix_stem: bool = False,
num_features: int = 2048,
pad_type: str = '',
use_msfa: bool = True,
msfa_indices: List[int] = (-2, -1),
msfa_output_resolution: int = 16,
act_layer: str | Callable | Type[Module] | None = None,
norm_layer: str | Callable | Type[Module] | None = None,
aa_layer: str | Callable | Type[Module] | None = None,
se_layer: str | Callable | Type[Module] | None = None,
se_from_exp: bool = True,
round_chs_fn: Callable = <function round_channels>,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float | None = None,
global_pool: str = 'avg',
device=None,
dtype=None,
)

MobileNet-V5

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
extra_blocks: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

  • extra_blocks – Include outputs of all blocks and head conv in output, does not align with feature_info

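Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).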

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,
)

Prune layers not required for specified intermediates.

class timm.models.mvitv2.MultiScaleVit(
cfg: MultiScaleVitCfg,
img_size: Tuple[int, int] = (224, 224),
in_chans: int = 3,
global_pool: str | None = None,
num_classes: int = 1000,
drop_path_rate: float = 0.0,
drop_rate: float = 0.0,
device=None,
dtype=None,
)

Improved Multiscale Vision Transformers for Classification and Detection
Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, Christoph Feichtenhofer*
https://arxiv.org/abs/2112.01526

Multiscale Vision Transformers
Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik, Christoph Feichtenhofer*
https://arxiv.org/abs/2104.11227

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) → List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers, returning intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

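Returns:

List of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).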

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.naflexvit.NaFlexVitCfg(
patch_size: int | Tuple[int, int] = 16,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
scale_mlp_norm: bool = False,
qkv_bias: bool = True,
qk_norm: bool = False,
proj_bias: bool = True,
attn_drop_rate: float = 0.0,
scale_attn_inner_norm: bool = False,
init_values: float | None = None,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
class_token: bool = False,
reg_tokens: int = 0,
pos_embed: str = 'learned',
pos_embed_grid_size: Tuple[int, int] | None = (16, 16),
pos_embed_interp_mode: str = 'bicubic',
pos_embed_ar_preserving: bool = False,
pos_embed_use_grid_sample: bool = False,
rope_type: str = '',
rope_temperature: float = 10000.0,
rope_ref_feat_shape: Tuple[int, int] | None = None,
rope_grid_offset: float = 0.0,
rope_grid_indexing: str = 'ij',
rope_rotate_half: bool = False,
dynamic_img_pad: bool = False,
pre_norm: bool = False,
final_norm: bool = True,
fc_norm: bool | None = None,
global_pool: str = 'map',
pool_include_prefix: bool = False,
attn_pool_num_heads: int | None = None,
attn_pool_mlp_ratio: float | None = None,
weight_init: str = '',
fix_init: bool = True,
embed_proj_type: str = 'linear',
input_norm_layer: str | None = None,
embed_norm_layer: str | None = None,
norm_layer: str | None = None,
act_layer: str | None = None,
block_fn: str | None = None,
mlp_layer: str | None = None,
attn_layer: str | None = None,
attn_type: str = 'standard',
swiglu_mlp: bool = False,
qkv_fused: bool = True,
enable_patch_interpolator: bool = False,
)

Configuration for the NaFlexVit model.

This dataclass contains the bulk of model configuration parameters, with core parameters (img_size, in_chans, num_classes, etc.) remaining as direct constructor arguments for API compatibility.

class timm.models.nasnet.NASNetALarge

NASNetALarge (6 @ 4032)
class timm.models.nest.Nest(
img_size: int = 224,
in_chans: int = 3,
patch_size: int = 4,
num_levels: int = 3,
embed_dims: Tuple[int, ...] = (128, 256, 512),
num_heads: Tuple[int, ...] = (4, 8, 16),
depths: Tuple[int, ...] = (2, 2, 20),
num_classes: int = 1000,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.5,
norm_layer: Type[Module] | None = None,
act_layer: Type[Module] | None = None,
pad_type: str = '',
weight_init: str = '',
global_pool: str = 'avg',
device=None,
dtype=None,
)

Nested Transformer (NesT)

A PyTorch implementation of Aggregating Nested Transformers.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

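Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.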

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.nextvit.NextViT(
in_chans: int,
num_classes: int = 1000,
global_pool: str = 'avg',
stem_chs: Tuple[int, ...] = (64, 32, 64),
depths: Tuple[int, ...] = (3, 4, 10, 3),
strides: Tuple[int, ...] = (1, 2, 2, 2),
sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
drop_path_rate: float = 0.1,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
head_dim: int = 32,
mix_block_ratio: float = 0.75,
norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
act_layer: Type[Module] | None = None,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

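Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.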

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.nfnet.NormFreeNet(
cfg: NfCfg,
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str = 'avg',
output_stride: int = 32,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
**kwargs: Any,
)

Normalization-Free Network

As described in Characterizing signal propagation to close the performance gap in unnormalized ResNets, and in High-Performance Large-Scale Image Recognition Without Normalization (https://arxiv.org/abs/2102.06171).

This model aims to cover both the NFRegNet-Bx models as detailed in the paper’s code snippets and the (preact) ResNet models described earlier in the paper.

There are a few differences:
  • channels are rounded to be divisible by 8 by default (to keep tensor core kernels happy); this changes channel dims and parameter counts slightly from the paper models.

  • activation-correcting gamma constants are moved into the ScaledStdConv, as this has less performance impact in PyTorch when done alongside the weight scaling there. This likely wasn’t a concern in the JAX impl.

  • a config option, gamma_in_act, can be enabled to apply gamma in each activation instead of in StdConv as described above. This is slightly slower and numerically different, but matches the official impl.

  • skipinit is disabled by default; it seems to have a rather drastic impact on GPU memory use and throughput for what it is/does (approx. 8-10% throughput loss).
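The channel rounding in the first bullet can be illustrated with a small helper. The name round_channels and the 0.9 round-down guard are assumptions borrowed from the common "make divisible" rule used across many CNN implementations; this is a sketch, not NormFreeNet's own code.

```python
def round_channels(channels: float, divisor: int = 8, round_limit: float = 0.9) -> int:
    """Hypothetical helper: round a channel count to the nearest multiple of `divisor`.

    If rounding down would drop below `round_limit` of the requested width,
    step up one multiple instead, so layers never shrink by more than ~10%.
    """
    new_channels = max(divisor, int(channels + divisor / 2) // divisor * divisor)
    if new_channels < round_limit * channels:
        new_channels += divisor
    return new_channels


for c in (48, 100, 333, 11):
    print(c, "->", round_channels(c))  # 48 -> 48, 100 -> 104, 333 -> 336, 11 -> 16
```

Small shifts like 100 to 104 or 333 to 336 are what make channel dims and parameter counts differ slightly from the paper models.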

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(
x: Tensor,
pre_logits: bool = False,
) Tensor

Forward pass through classifier head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

get_classifier() Module

Get the classifier head.

group_matcher(
coarse: bool = False,
) Dict[str, Any]

Group parameters for optimization.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

class timm.models.pit.PoolingVisionTransformer(
img_size: int = 224,
patch_size: int = 16,
stride: int = 8,
stem_type: str = 'overlap',
base_dims: Sequence[int] = (48, 48, 48),
depth: Sequence[int] = (2, 6, 4),
heads: Sequence[int] = (2, 4, 8),
mlp_ratio: float = 4,
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str = 'token',
distilled: bool = False,
drop_rate: float = 0.0,
pos_drop_drate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
)

Pooling-based Vision Transformer

A PyTorch implementation of ‘Rethinking Spatial Dimensions of Vision Transformers’.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

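Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.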

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.
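Two pieces of arithmetic help when choosing the PoolingVisionTransformer arguments above. The sketch assumes, following the PiT design, that each stage's embedding width is base_dims[i] * heads[i], and that the 'overlap' stem is a convolution with kernel patch_size, stride stride and no padding. Both are assumptions not stated in the signature, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence


def pit_stage_dims(base_dims: Sequence[int], heads: Sequence[int]) -> List[int]:
    """Embedding width per stage, assuming width = base_dim * num_heads."""
    return [b * h for b, h in zip(base_dims, heads)]


def conv_stem_grid(img_size: int, patch_size: int, stride: int, padding: int = 0) -> int:
    """Side length of the token grid produced by a strided conv patch embedding."""
    return math.floor((img_size + 2 * padding - patch_size) / stride + 1)


print(pit_stage_dims((48, 48, 48), (2, 4, 8)))  # [96, 192, 384]
print(conv_stem_grid(224, 16, 8))               # 27
```

Because stride is smaller than patch_size, neighbouring patches overlap, which is why the grid (27 × 27) is larger than the 14 × 14 a non-overlapping 16-pixel patch embedding would give.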

class timm.models.pnasnet.PNASNet5Large(
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
drop_rate: float = 0.0,
global_pool: str = 'avg',
pad_type: str = '',
device=None,
dtype=None,
)
class timm.models.pvt_v2.PyramidVisionTransformerV2(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
depths: Tuple[int, ...] = (3, 4, 6, 3),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
num_heads: Tuple[int, ...] = (1, 2, 4, 8),
sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
mlp_ratios: Tuple[float, ...] = (8.0, 8.0, 4.0, 4.0),
qkv_bias: bool = True,
linear: bool = False,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

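Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.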

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.
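The sr_ratios defaults shrink the key/value grid of each stage's attention. A small arithmetic sketch, assuming a 224 input, stage strides of 4, 8, 16 and 32 (not values listed in the signature above) and the default linear=False path, shows that the defaults give every stage the same 7 × 7 key/value grid. The function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def sra_token_counts(
    img_size: int = 224,
    stage_strides: Sequence[int] = (4, 8, 16, 32),  # assumed PVTv2 stage strides
    sr_ratios: Sequence[int] = (8, 4, 2, 1),        # defaults from the signature above
) -> List[Tuple[int, int]]:
    """Per stage: (number of query tokens, number of key/value tokens)."""
    counts = []
    for stride, sr in zip(stage_strides, sr_ratios):
        side = img_size // stride  # feature-map side length at this stage
        kv_side = side // sr       # keys/values downsampled by sr on each axis
        counts.append((side * side, kv_side * kv_side))
    return counts


for stage, (q, kv) in enumerate(sra_token_counts(), start=1):
    print(f"stage {stage}: {q} queries attend over {kv} keys/values")
```

Without the reduction, stage 1 would compute a 3136 × 3136 attention matrix; with sr_ratio=8 it is 3136 × 49.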

class timm.models.rdnet.RDNet(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
growth_rates: List[int] | Tuple[int] = (64, 104, 128, 128, 128, 128, 224),
num_blocks_list: List[int] | Tuple[int] = (3, 3, 3, 3, 3, 3, 3),
block_type: List[int] | Tuple[int] = ('Block', 'Block', 'BlockESE', 'BlockESE', 'BlockESE', 'BlockESE', 'BlockESE'),
is_downsample_block: List[bool] | Tuple[bool] = (None, True, True, False, False, False, True),
bottleneck_width_ratio: float = 4.0,
transition_compression_ratio: float = 0.5,
ls_init_value: float = 1e-06,
stem_type: str = 'patch',
patch_size: int = 4,
num_init_features: int = 64,
head_init_scale: float = 1.0,
head_norm_first: bool = False,
conv_bias: bool = True,
act_layer: str | Callable = 'gelu',
norm_layer: str = 'layernorm2d',
norm_eps: float | None = None,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.regnet.RegNet(
cfg: RegNetCfg,
in_chans: int = 3,
num_classes: int = 1000,
output_stride: int = 32,
global_pool: str = 'avg',
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
zero_init_last: bool = True,
device=None,
dtype=None,
**kwargs,
)

RegNet-X, Y, and Z Models.

Paper: https://arxiv.org/abs/2003.13678

Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(x: Tensor, pre_logits: bool = False) Tensor

Forward pass through classifier head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

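Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.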

get_classifier() Module

Get the classifier head.

group_matcher(coarse: bool = False) Dict[str, Any]

Group parameters for optimization.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

class timm.models.repghost.RepGhostNet(
cfgs: List[List[List]],
num_classes: int = 1000,
width: float = 1.0,
in_chans: int = 3,
output_stride: int = 32,
global_pool: str = 'avg',
drop_rate: float = 0.2,
reparam: bool = True,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

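Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.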

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.
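RepGhostNet (via reparam) and RepVit are built around structural re-parameterization: training-time branches and batch-norm layers are merged into plain convolutions for inference. The central step, folding an inference-mode BatchNorm into the convolution before it, is sketched below in plain Python for a single output position. It is a generic illustration of the technique with hypothetical function names; which branches timm actually fuses in these models is not shown here.

```python
import math
from typing import List, Sequence, Tuple


def fold_bn_into_conv(
    weight: Sequence[Sequence[float]],  # one flattened kernel per output channel
    bias: Sequence[float],              # conv bias per output channel (0.0 if the conv has none)
    gamma: Sequence[float],
    beta: Sequence[float],
    running_mean: Sequence[float],
    running_var: Sequence[float],
    eps: float = 1e-5,
) -> Tuple[List[List[float]], List[float]]:
    """Merge an inference-mode BatchNorm into the preceding convolution."""
    fused_w, fused_b = [], []
    for w_c, b_c, g, bt, mu, var in zip(weight, bias, gamma, beta, running_mean, running_var):
        scale = g / math.sqrt(var + eps)  # BN's per-channel multiplier
        fused_w.append([w * scale for w in w_c])
        fused_b.append(bt + (b_c - mu) * scale)
    return fused_w, fused_b


def conv_bn_reference(x, w_c, b_c, g, bt, mu, var, eps=1e-5):
    """Unfused path for one output channel at one position: dot product, then BatchNorm."""
    y = sum(xi * wi for xi, wi in zip(x, w_c)) + b_c
    return g * (y - mu) / math.sqrt(var + eps) + bt


# One output channel, a 2-element receptive field, arbitrary BN statistics.
w, b = [[0.5, -1.5]], [0.2]
gamma, beta, mean, var = [1.3], [0.1], [0.4], [2.0]
fw, fb = fold_bn_into_conv(w, b, gamma, beta, mean, var)

x = [2.0, 1.0]
fused_out = sum(xi * wi for xi, wi in zip(x, fw[0])) + fb[0]
reference_out = conv_bn_reference(x, w[0], b[0], gamma[0], beta[0], mean[0], var[0])
print(fused_out, reference_out)  # equal up to floating-point rounding
```

The fused layer computes the same function with one fewer operation per output, which is where the inference speed-up of re-parameterized models comes from.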

class timm.models.repvit.RepVit(
in_chans: int = 3,
img_size: int = 224,
embed_dim: Tuple[int, ...] = (48,),
depth: Tuple[int, ...] = (2,),
mlp_ratio: float = 2,
global_pool: str = 'avg',
kernel_size: int = 3,
num_classes: int = 1000,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
distillation: bool = True,
drop_rate: float = 0.0,
legacy: bool = False,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

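Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.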

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.resnet.ResNet(
block: BasicBlock | Bottleneck,
layers: Tuple[int, ...],
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
global_pool: str = 'avg',
cardinality: int = 1,
base_width: int = 64,
stem_width: int = 64,
stem_type: str = '',
replace_stem_pool: bool = False,
block_reduce_first: int = 1,
down_kernel_size: int = 1,
avg_down: bool = False,
channels: Tuple[int, ...] | None = (64, 128, 256, 512),
act_layer: str | Callable | Type[Module] = <class 'torch.nn.modules.activation.ReLU'>,
norm_layer: str | Callable | Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
aa_layer: Type[Module] | None = None,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
drop_block_rate: float = 0.0,
zero_init_last: bool = True,
block_args: Dict[str, Any] | None = None,
device=None,
dtype=None,
)

ResNet / ResNeXt / SE-ResNeXt / SE-Net

This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
  • have > 1 stride in the 3x3 conv layer of bottleneck

  • have conv-bn-act ordering

This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the ‘Bag of Tricks’ paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.

ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
  • normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet ‘v1.5’, Gluon v1b

  • c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)

  • d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample

  • e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample

  • s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)

  • t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample

  • tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample

ResNeXt
  • normal - 7x7 stem, stem_width = 64, standard cardinality and base widths

  • same c, d, e, s variants as ResNet can be enabled

SE-ResNeXt
  • normal - 7x7 stem, stem_width = 64

  • same c, d, e, s variants as ResNet can be enabled

SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

forward_head(x: Tensor, pre_logits: bool = False) Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier layer.

Returns:

Output tensor.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor.

  • indices – Take last n blocks if int, all if None, select matching indices if sequence.

  • norm – Apply norm layer to compatible intermediates.

  • stop_early – Stop iterating over blocks when last desired intermediate hit.

  • output_fmt – Shape of intermediate feature outputs.

  • intermediates_only – Only return intermediate features.

Returns:

Features and list of intermediate features or just intermediate features.

get_classifier(
name_only: bool = False,
) str | Module

Get the classifier module.

Parameters:

name_only – Return classifier module name instead of module.

Returns:

Classifier module or name.

group_matcher(coarse: bool = False) Dict[str, str]

Create regex patterns for parameter grouping.

Parameters:

coarse – Use coarse (stage-level) or fine (block-level) grouping.

Returns:

Dictionary mapping group names to regex patterns.
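The regex patterns returned by group_matcher are meant for bucketing parameters, for example for layer-wise learning-rate decay. The sketch applies hypothetical patterns of that shape to ResNet-style parameter names in plain Python; the exact patterns timm returns may differ, so call model.group_matcher(coarse=...) to see the real ones. FINE, COARSE and group_params are illustrative names.

```python
import re
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical patterns: a stem pattern plus a block pattern whose captured numbers name each group.
FINE = {"stem": r"^(conv1|bn1|maxpool)\.", "blocks": r"^layer(\d+)\.(\d+)\."}  # block-level
COARSE = {"stem": r"^(conv1|bn1|maxpool)\.", "blocks": r"^layer(\d+)\."}       # stage-level


def group_params(names: List[str], matcher: Dict[str, str]) -> Dict[Tuple, List[str]]:
    groups: Dict[Tuple, List[str]] = defaultdict(list)
    for name in names:
        if re.match(matcher["stem"], name):
            key = ("stem",)
        else:
            m = re.match(matcher["blocks"], name)
            # Parameters matching neither pattern (e.g. the classifier) fall into "other".
            key = ("blocks", *map(int, m.groups())) if m else ("other",)
        groups[key].append(name)
    return dict(groups)


names = [
    "conv1.weight", "bn1.weight",
    "layer1.0.conv1.weight", "layer1.1.conv1.weight",
    "layer2.0.conv1.weight", "fc.weight",
]
print(group_params(names, FINE))    # one group per block, e.g. ("blocks", 1, 0)
print(group_params(names, COARSE))  # one group per stage, e.g. ("blocks", 1)
```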

init_weights(zero_init_last: bool = True) None

Initialize model weights.

Parameters:

zero_init_last – Zero-initialize the last BN in each residual branch.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layers.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(num_classes: int, global_pool: str = 'avg') None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.resnetv2.ResNetV2(
layers: List[int],
channels: Tuple[int, ...] = (256, 512, 1024, 2048),
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str = 'avg',
output_stride: int = 32,
width_factor: int = 1,
stem_chs: int = 64,
stem_type: str = '',
avg_down: bool = False,
preact: bool = True,
basic: bool = False,
bottle_ratio: float = 0.25,
act_layer: Callable = <class 'torch.nn.modules.activation.ReLU'>,
norm_layer: Callable = functools.partial(<class 'timm.layers.norm_act.GroupNormAct'>, num_groups=32),
conv_layer: Callable = <class 'timm.layers.std_conv.StdConv2d'>,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
zero_init_last: bool = False,
device=None,
dtype=None,
)

Implementation of Pre-activation (v2) ResNet model.

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(
x: Tensor,
pre_logits: bool = False,
) Tensor

Forward pass through classifier head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

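Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.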

get_classifier() Module

Get the classifier head.

group_matcher(coarse: bool = False) Dict[str, Any]

Group parameters for optimization.

init_weights(zero_init_last: bool = True) None

Initialize model weights.

load_pretrained(checkpoint_path: str, prefix: str = 'resnet/') None

Load pretrained weights.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

class timm.models.rexnet.RexNet(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
output_stride: int = 32,
initial_chs: int = 16,
final_chs: int = 180,
width_mult: float = 1.0,
depth_mult: float = 1.0,
se_ratio: float = 0.08333333333333333,
ch_div: int = 1,
act_layer: str = 'swish',
dw_act_layer: str = 'relu6',
drop_rate: float = 0.2,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
)

ReXNet model architecture.

Based on ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network - https://arxiv.org/abs/2007.00992

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(x: Tensor, pre_logits: bool = False) Tensor

Forward pass through head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

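Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.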

get_classifier() Module

Get the classifier module.

Returns:

Classifier module.

group_matcher(coarse: bool = False) Dict[str, Any]

Group matcher for parameter groups.

Parameters:

coarse – Whether to use coarse grouping.

Returns:

Dictionary of grouped parameters.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
device=None,
dtype=None,
) None

Reset the classifier.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.selecsls.SelecSls(
cfg,
num_classes: int = 1000,
in_chans: int = 3,
drop_rate: float = 0.0,
global_pool: str = 'avg',
device=None,
dtype=None,
)

SelecSls42 / SelecSls60 / SelecSls84

Parameters:
  • cfg (dict) – Network config dictionary specifying block type, feature, and head args.

  • num_classes (int, default 1000) – Number of classification classes.

  • in_chans (int, default 3) – Number of input (color) channels.

  • drop_rate (float, default 0.) – Dropout probability before the classifier, applied during training.

  • global_pool (str, default 'avg') – Global pooling type. One of ‘avg’, ‘max’, ‘avgmax’, ‘catavgmax’.
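The global_pool options listed above differ in how each channel's spatial map is reduced before the classifier. The sketch works on one channel given as nested lists and assumes that 'avgmax' is the mean of average and max pooling and that 'catavgmax' concatenates the two, which doubles the number of features fed to the classifier. The function name is hypothetical.

```python
from typing import List, Sequence


def pool_channel(feature_map: Sequence[Sequence[float]], pool_type: str = "avg") -> List[float]:
    """Reduce one channel's H x W map to the pooled value(s) for `pool_type`."""
    values = [v for row in feature_map for v in row]
    avg = sum(values) / len(values)
    mx = max(values)
    if pool_type == "avg":
        return [avg]
    if pool_type == "max":
        return [mx]
    if pool_type == "avgmax":
        return [0.5 * (avg + mx)]  # one value: mean of the two poolings
    if pool_type == "catavgmax":
        return [avg, mx]           # two values: both poolings kept side by side
    raise ValueError(f"unknown pool_type: {pool_type}")


fmap = [[1.0, 2.0], [3.0, 6.0]]
for p in ("avg", "max", "avgmax", "catavgmax"):
    print(p, pool_channel(fmap, p))  # [3.0], [6.0], [4.5], [3.0, 6.0]
```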

class timm.models.senet.SENet(
block: Type[Module],
layers: Tuple[int, ...],
groups: int,
reduction: int,
drop_rate: float = 0.2,
in_chans: int = 3,
inplanes: int = 64,
input_3x3: bool = False,
downsample_kernel_size: int = 1,
downsample_padding: int = 0,
num_classes: int = 1000,
global_pool: str = 'avg',
device=None,
dtype=None,
)
class timm.models.sequencer.Sequencer2d(
num_classes: int = 1000,
img_size: int = 224,
in_chans: int = 3,
global_pool: str = 'avg',
layers: Tuple[int, ...] = (4, 3, 8, 3),
patch_sizes: Tuple[int, ...] = (7, 2, 2, 1),
embed_dims: Tuple[int, ...] = (192, 384, 384, 384),
hidden_sizes: Tuple[int, ...] = (48, 96, 96, 96),
mlp_ratios: Tuple[float, ...] = (3.0, 3.0, 3.0, 3.0),
block_layer: Type[Module] = <class 'timm.models.sequencer.Sequencer2dBlock'>,
rnn_layer: Type[Module] = <class 'timm.models.sequencer.LSTM2d'>,
mlp_layer: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
num_rnn_layers: int = 1,
bidirectional: bool = True,
union: str = 'cat',
with_fc: bool = True,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
nlhb: bool = False,
stem_norm: bool = False,
device=None,
dtype=None,
)
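The default hidden_sizes are a quarter of the matching embed_dims. Assuming each Sequencer block runs a bidirectional LSTM vertically and another horizontally and concatenates the results when union='cat' (an assumption about the internal wiring, not stated in the signature), the concatenated width is 4 × hidden_size. The arithmetic below checks that against the defaults; with with_fc=True a linear layer would then project this width back to embed_dim. The function name is hypothetical.

```python
def lstm2d_cat_width(hidden_size: int, bidirectional: bool = True) -> int:
    """Width of the concatenated vertical and horizontal LSTM outputs (union='cat')."""
    per_axis = 2 * hidden_size if bidirectional else hidden_size  # forward + backward states
    return 2 * per_axis  # vertical and horizontal results are concatenated


embed_dims = (192, 384, 384, 384)  # defaults from the signature above
hidden_sizes = (48, 96, 96, 96)
for d, h in zip(embed_dims, hidden_sizes):
    print(f"embed_dim={d}, hidden_size={h}, concatenated width={lstm2d_cat_width(h)}")
```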
class timm.models.shvit.SHViT(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: Tuple[int, int, int] = (128, 256, 384),
partial_dim: Tuple[int, int, int] = (32, 64, 96),
qk_dim: Tuple[int, int, int] = (16, 16, 16),
depth: Tuple[int, int, int] = (1, 2, 3),
types: Tuple[str, str, str] = ('s', 's', 's'),
drop_rate: float = 0.0,
norm_layer: Type[Module] = <class 'timm.layers.norm.GroupNorm1'>,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.ReLU'>,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

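Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.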

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.starnet.StarNet(
base_dim: int = 32,
depths: List[int] = [3, 3, 12, 5],
mlp_ratio: int = 4,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.ReLU6'>,
num_classes: int = 1000,
in_chans: int = 3,
global_pool: str = 'avg',
output_stride: int = 32,
device=None,
dtype=None,
**kwargs,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

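Returns:

List of intermediate features if intermediates_only is True; otherwise a tuple of the final features and the list of intermediate features.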

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.swiftformer.SwiftFormer(
layers: List[int] = [3, 3, 6, 4],
embed_dims: List[int] = [48, 56, 112, 220],
mlp_ratios: int = 4,
downsamples: List[bool] = [False, True, True, True],
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
down_patch_size: int = 3,
down_stride: int = 2,
down_pad: int = 1,
num_classes: int = 1000,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-05,
global_pool: str = 'avg',
output_stride: int = 32,
in_chans: int = 3,
device=None,
dtype=None,
**kwargs,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

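Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).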

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.
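In SwiftFormer, each entry of embed_dims sets a stage's width, and each entry of downsamples decides whether a strided convolution (down_patch_size, down_stride, down_pad) precedes that stage. The sketch below tabulates the resulting feature-map shapes for a square input. It is plain arithmetic rather than timm code, and it assumes a stem of two 3x3, stride-2, padding-1 convolutions (overall stride 4).

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, pad: int) -> int:
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * pad - kernel) // stride + 1


def swiftformer_stage_shapes(
    img_size: int = 224,
    embed_dims: Tuple[int, ...] = (48, 56, 112, 220),
    downsamples: Tuple[bool, ...] = (False, True, True, True),
    down_patch_size: int = 3,
    down_stride: int = 2,
    down_pad: int = 1,
) -> List[Tuple[int, int, int]]:
    """Return (channels, height, width) for each stage of a square input."""
    # Assumption: the stem is two 3x3, stride-2, padding-1 convolutions.
    side = conv_out(conv_out(img_size, 3, 2, 1), 3, 2, 1)
    shapes = []
    for dim, downsample in zip(embed_dims, downsamples):
        if downsample:
            # Strided downsampling before the stage, using the documented defaults.
            side = conv_out(side, down_patch_size, down_stride, down_pad)
        shapes.append((dim, side, side))
    return shapes


print(swiftformer_stage_shapes())
# [(48, 56, 56), (56, 28, 28), (112, 14, 14), (220, 7, 7)]
```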

class timm.models.swin_transformer.SwinTransformer(
img_size: int | Tuple[int, int] = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
head_dim: int | None = None,
window_size: int | Tuple[int, int] = 7,
always_partition: bool = False,
strict_img_size: bool = True,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
embed_layer: Type[Module] = <class 'timm.layers.patch_embed.PatchEmbed'>,
norm_layer: str | Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
weight_init: str = '',
device=None,
dtype=None,
**kwargs,
)

Swin Transformer.

A PyTorch implementation of Swin Transformer: Hierarchical Vision Transformer using Shifted Windows - https://arxiv.org/pdf/2103.14030
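Swin Transformer is hierarchical: the patch embedding reduces the input by patch_size, and each later stage merges neighbouring 2x2 patches, halving the resolution and doubling the channel count. The sketch below derives the per-stage layout from the constructor defaults above. It is plain arithmetic, not a call into timm, and assumes a square input divisible by patch_size * 2 ** (len(depths) - 1) and head_dim=None, so the head size is channels // num_heads.

```python
from typing import List, Tuple


def swin_stage_layout(
    img_size: int = 224,
    patch_size: int = 4,
    embed_dim: int = 96,
    depths: Tuple[int, ...] = (2, 2, 6, 2),
    num_heads: Tuple[int, ...] = (3, 6, 12, 24),
) -> List[dict]:
    """Per-stage resolution, channels and head size for a square input."""
    grid = img_size // patch_size  # tokens per side after patch embedding
    layout = []
    for i, (depth, heads) in enumerate(zip(depths, num_heads)):
        dim = embed_dim * 2 ** i   # channels double at every patch merging
        side = grid // 2 ** i      # resolution halves at every patch merging
        layout.append({"stage": i, "blocks": depth, "resolution": (side, side),
                       "channels": dim, "head_dim": dim // heads})
    return layout


# Resolutions 56, 28, 14, 7; channels 96, 192, 384, 768; head_dim 32 in every stage.
for stage in swin_stage_layout():
    print(stage)
```

With the default window_size=7 and a 224-pixel input, the last stage's 7x7 map holds exactly one window.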

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

forward_head(
x: Tensor,
pre_logits: bool = False,
) Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier.

Returns:

Output tensor.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor.

  • indices – Take last n blocks if int, all if None, select matching indices if sequence.

  • norm – Apply norm layer to compatible intermediates.

  • stop_early – Stop iterating over blocks when last desired intermediate hit.

  • output_fmt – Shape of intermediate feature outputs.

  • intermediates_only – Only return intermediate features.

Returns:

List of intermediate features or tuple of (final features, intermediates).

get_classifier() Module

Get the classifier head.

group_matcher(
coarse: bool = False,
) Dict[str, Any]

Group parameters for optimization.

init_weights(mode: str = '', needs_reset: bool = True) None

Initialize model weights.

Parameters:
  • mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).

  • needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.

no_weight_decay() Set[str]

Parameters that should not use weight decay.
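The names returned by no_weight_decay() are meant to be excluded from weight decay when building optimizer parameter groups. A common convention, sketched below in plain Python, puts those names plus all 1-D parameters (biases, normalization scales) into a zero-decay group. The helper split_weight_decay, its 0.05 default and the example parameter names and shapes are hypothetical; timm's own optimizer factory may apply additional rules.

```python
from typing import Dict, List, Set, Tuple


def split_weight_decay(
    param_shapes: Dict[str, Tuple[int, ...]],
    skip: Set[str],
    weight_decay: float = 0.05,
) -> List[dict]:
    """Split parameter names into a no-decay group and a decay group."""
    decay: List[str] = []
    no_decay: List[str] = []
    for name, shape in param_shapes.items():
        # Names from no_weight_decay(), plus 1-D tensors (biases, norm scales),
        # are conventionally excluded from weight decay.
        if name in skip or len(shape) <= 1:
            no_decay.append(name)
        else:
            decay.append(name)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


# Hypothetical names and shapes; with a real model they would come from
# model.named_parameters(), and model.no_weight_decay() would be passed as `skip`.
shapes = {
    "patch_embed.proj.weight": (96, 3, 4, 4),
    "patch_embed.proj.bias": (96,),
    "layers.0.blocks.0.attn.relative_position_bias_table": (169, 3),
    "head.fc.weight": (1000, 768),
}
groups = split_weight_decay(shapes, skip={"layers.0.blocks.0.attn.relative_position_bias_table"})
for group in groups:
    print(group["weight_decay"], group["params"])
```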

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

set_input_size(
img_size: Tuple[int, int] | None = None,
patch_size: Tuple[int, int] | None = None,
window_size: Tuple[int, int] | None = None,
window_ratio: int = 8,
always_partition: bool | None = None,
) None

Update the image resolution and window size.

Parameters:
  • img_size – New input resolution, if None current resolution is used.

  • patch_size – New patch size, if None use current patch size.

  • window_size – New window size, if None computed from the new patch grid size // window_ratio.

  • window_ratio – Divisor for calculating window size from grid size.

  • always_partition – Always partition into windows and shift (even if feat size < window size).
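When window_size is None, set_input_size() derives the window from the patch grid and window_ratio, as the parameter list above describes. The hypothetical helper below reproduces only that arithmetic; any clamping of the window to the feature size that the real method may perform is not modeled.

```python
from typing import Optional, Tuple


def derive_window_size(
    img_size: Tuple[int, int],
    patch_size: Tuple[int, int] = (4, 4),
    window_size: Optional[Tuple[int, int]] = None,
    window_ratio: int = 8,
) -> Tuple[int, int]:
    """Window size implied by set_input_size()-style arguments (sketch)."""
    if window_size is not None:
        return window_size  # an explicit window size takes precedence
    grid_h, grid_w = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
    return grid_h // window_ratio, grid_w // window_ratio


print(derive_window_size((224, 224)))  # (7, 7): matches the default window_size
print(derive_window_size((256, 256)))  # (8, 8)
print(derive_window_size((384, 384)))  # (12, 12)
```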

class timm.models.swin_transformer_v2.SwinTransformerV2(
img_size: int | Tuple[int, int] = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_size: int | Tuple[int, int] = 7,
always_partition: bool = False,
strict_img_size: bool = True,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
act_layer: str | Callable = 'gelu',
norm_layer: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
device=None,
dtype=None,
**kwargs,
)

Swin Transformer V2.

A hierarchical vision transformer using shifted windows for efficient self-attention computation with continuous position bias.

A PyTorch implementation of Swin Transformer V2: Scaling Up Capacity and Resolution

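
Swin V2's continuous position bias feeds log-spaced relative coordinates into a small MLP instead of indexing a learned bias table. The sketch below computes those coordinates along one axis using the Swin V2 paper's log spacing, sign(d) * log(1 + |d|). The normalization to [-8, 8] before the log, and the use of a non-zero pretrained_window_sizes entry as the normalizer, are assumptions about the implementation.

```python
import math
from typing import List


def log_spaced_offsets(window: int, pretrained_window: int = 0) -> List[float]:
    """Log-spaced relative coordinates for offsets -(window-1) .. window-1."""
    # Assumption: offsets are divided by (pretrained_window - 1) if a pretrained
    # window size is given, otherwise by (window - 1), then scaled to [-8, 8].
    norm = max((pretrained_window if pretrained_window > 0 else window) - 1, 1)
    coords = []
    for d in range(-(window - 1), window):
        t = 8.0 * d / norm
        # sign(t) * log2(1 + |t|), divided by log2(8) = 3
        coords.append(math.copysign(math.log2(1.0 + abs(t)), t) / math.log2(8))
    return coords


pretrain = log_spaced_offsets(8)
fine_tune = log_spaced_offsets(12, pretrained_window=8)
print(len(pretrain), len(fine_tune))
print(pretrain[-1], fine_tune[-1])
```

Normalizing a larger fine-tuning window by the smaller pretrained window yields coordinates beyond the pretrained range; the log spacing compresses how far they extend.
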
forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor of shape (B, C, H, W).

Returns:

Feature tensor of shape (B, H’, W’, C).

forward_head(
x: Tensor,
pre_logits: bool = False,
) Tensor

Forward pass through classification head.

Parameters:
  • x – Feature tensor of shape (B, H, W, C).

  • pre_logits – If True, return features before final linear layer.

Returns:

Logits tensor of shape (B, num_classes) or pre-logits.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

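Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).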

get_classifier() Module

Get the classifier head.

Returns:

The classification head module.

group_matcher(
coarse: bool = False,
) Dict[str, Any]

Create parameter group matcher for optimizer parameter groups.

Parameters:

coarse – If True, use coarse grouping.

Returns:

Dictionary mapping group names to regex patterns.

init_weights(needs_reset: bool = True) None

Initialize model weights.

Parameters:

needs_reset – If True (the default, appropriate after to_empty()), call reset_parameters() on modules. If False, skip reset_parameters() (used in __init__, where modules have already self-initialized).

no_weight_decay() Set[str]

Get parameter names that should not use weight decay.

Returns:

Set of parameter names to exclude from weight decay.

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classification head.

Parameters:
  • num_classes – Number of classes for new head.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

Parameters:

enable – If True, enable gradient checkpointing.

set_input_size(
img_size: Tuple[int, int] | None = None,
patch_size: Tuple[int, int] | None = None,
window_size: Tuple[int, int] | None = None,
window_ratio: int | None = 8,
always_partition: bool | None = None,
)

Update the image resolution and window size, and with them the pair-wise relative positions.

Parameters:
  • img_size (Optional[Tuple[int, int]]) – New input resolution, if None current resolution is used.

  • patch_size (Optional[Tuple[int, int]]) – New patch size, if None use current patch size.

  • window_size (Optional[Tuple[int, int]]) – New window size, if None computed from the patch grid size // window_ratio.

  • window_ratio (int) – Divisor for calculating window size from patch grid size.

  • always_partition – Always partition / shift windows, even if the feature size is smaller than the window.

class timm.models.swin_transformer_v2_cr.SwinTransformerV2Cr(
img_size: Tuple[int, int] = (224, 224),
patch_size: int = 4,
window_size: int | None = None,
window_ratio: int = 8,
always_partition: bool = False,
strict_img_size: bool = True,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
mlp_ratio: float = 4.0,
init_values: float | None = 0.0,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False,
global_pool: str = 'avg',
weight_init: str = 'reset',
device=None,
dtype=None,
**kwargs: Any,
)

Swin Transformer V2

A PyTorch implementation of Swin Transformer V2: Scaling Up Capacity and Resolution - https://arxiv.org/pdf/2111.09883

Parameters:
  • img_size – Input resolution.

  • window_size – Window size. If None, computed as the patch grid size // window_ratio.

  • window_ratio – Window size to patch grid ratio.

  • patch_size – Patch size.

  • in_chans – Number of input channels.

  • depths – Depth of each stage (number of blocks per stage).

  • num_heads – Number of attention heads in each stage.

  • embed_dim – Patch embedding dimension.

  • num_classes – Number of output classes.

  • mlp_ratio – Ratio of the hidden dimension in the FFN to the input channels.

  • drop_rate – Dropout rate.

  • proj_drop_rate – Projection dropout rate.

  • attn_drop_rate – Dropout rate of attention map.

  • drop_path_rate – Stochastic depth rate.

  • norm_layer – Type of normalization layer to use.

  • extra_norm_period – Insert an extra norm layer on the main branch every N (period) blocks within a stage.

  • extra_norm_stage – End each stage with an extra norm layer on the main branch.

  • sequential_attn – If True, sequential self-attention is performed.
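extra_norm_period and extra_norm_stage control where extra norm layers are placed on the main branch. The hypothetical helper below lists, for one stage, which block indices would end with an extra norm, assuming a block gets one when (index + 1) is a multiple of the period and that extra_norm_stage adds one after the stage's last block.

```python
from typing import List


def extra_norm_blocks(
    depth: int,
    extra_norm_period: int = 0,
    extra_norm_stage: bool = False,
) -> List[int]:
    """0-based block indices within a stage that end with an extra norm (sketch)."""
    blocks = set()
    if extra_norm_period > 0:
        # Every `extra_norm_period`-th block gets an extra norm.
        blocks.update(i for i in range(depth) if (i + 1) % extra_norm_period == 0)
    if extra_norm_stage and depth > 0:
        blocks.add(depth - 1)  # close the stage with an extra norm
    return sorted(blocks)


print(extra_norm_blocks(6))           # [] with the defaults
print(extra_norm_blocks(6, 2))        # [1, 3, 5]
print(extra_norm_blocks(6, 4, True))  # [3, 5]
```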

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

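Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).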

get_classifier() Module

Return the classification head of the model.

Returns:

Current classification head (nn.Module).

init_weights(needs_reset: bool = True) None

Initialize model weights.

Parameters:

needs_reset – If True (the default, appropriate after to_empty()), call reset_parameters() on modules. If False, skip reset_parameters() (used in __init__, where modules have already self-initialized).

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classification head.

Parameters:
  • num_classes (int) – Number of classes to be predicted

  • global_pool (str) – Unused

set_input_size(
img_size: Tuple[int, int] | None = None,
window_size: Tuple[int, int] | None = None,
window_ratio: int = 8,
always_partition: bool | None = None,
) None

Update the image resolution and window size, and with them the pair-wise relative positions.

Parameters:
  • img_size (Optional[Tuple[int, int]]) – New input resolution, if None current resolution is used.

  • window_size (Optional[Tuple[int, int]]) – New window size, if None computed from the patch grid size // window_ratio.

  • window_ratio (int) – Divisor for calculating window size from patch grid size.

  • always_partition – Always partition / shift windows, even if the feature size is smaller than the window.

class timm.models.tiny_vit.TinyVit(
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_sizes: Tuple[int, ...] = (7, 7, 14, 7),
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
use_checkpoint: bool = False,
mbconv_expand_ratio: float = 4.0,
local_conv_size: int = 3,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
device=None,
dtype=None,
)
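
drop_path_rate sets the stochastic depth rate. Hierarchical timm models such as Swin and TinyViT typically spread it linearly over all blocks, so the first block is never dropped and the last block uses the full rate. The plain-Python sketch below computes such a schedule for the TinyVit defaults (depths=(2, 2, 6, 2), drop_path_rate=0.1); treat the exact per-block values as an assumption about the implementation.

```python
from typing import List, Sequence


def drop_path_schedule(depths: Sequence[int], drop_path_rate: float) -> List[List[float]]:
    """Per-block stochastic-depth rates, linearly spaced and grouped by stage."""
    total = sum(depths)
    if total > 1:
        flat = [drop_path_rate * i / (total - 1) for i in range(total)]
    else:
        flat = [0.0] * total  # a lone block has nothing to spread over (assumption)
    stages, start = [], 0
    for depth in depths:
        stages.append(flat[start:start + depth])
        start += depth
    return stages


for i, rates in enumerate(drop_path_schedule((2, 2, 6, 2), 0.1)):
    print(f"stage {i}:", [round(r, 3) for r in rates])
```
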
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

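Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).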

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.tnt.TNT(
img_size: int | Tuple[int, int] = 224,
patch_size: int | Tuple[int, int] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'token',
embed_dim: int = 768,
inner_dim: int = 48,
depth: int = 12,
num_heads_inner: int = 4,
num_heads_outer: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
first_stride: int = 4,
legacy: bool = False,
device=None,
dtype=None,
)

Transformer in Transformer - https://arxiv.org/abs/2103.00112
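TNT runs two transformers: an outer one over patch tokens and an inner one over pixel-level tokens inside each patch. The sketch below counts both for the defaults (img_size=224, patch_size=16, first_stride=4). It assumes each patch is split into ceil(patch_size / first_stride) ** 2 inner tokens and that the outer sequence carries one class token, as implied by global_pool='token'; the helper itself is illustrative.

```python
import math
from typing import Tuple


def tnt_token_counts(
    img_size: int = 224,
    patch_size: int = 16,
    first_stride: int = 4,
    class_token: bool = True,
) -> Tuple[int, int]:
    """Return (outer sequence length, inner tokens per patch) for a square input."""
    grid = img_size // patch_size
    outer = grid * grid + (1 if class_token else 0)
    # Assumption: each patch is cut into ceil(patch_size / first_stride)^2 pixel tokens.
    inner_side = math.ceil(patch_size / first_stride)
    return outer, inner_side * inner_side


outer, inner = tnt_token_counts()
print(outer, inner)  # 197 16
```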

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if an int; if a sequence, select blocks by matching indices

  • return_prefix_tokens – Return both prefix and spatial intermediate tokens

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

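Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).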

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.tresnet.TResNet(
layers: List[int],
in_chans: int = 3,
num_classes: int = 1000,
width_factor: float = 1.0,
v2: bool = False,
global_pool: str = 'fast',
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

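Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).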

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.twins.Twins(
img_size: int | Tuple[int, int] = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
num_heads: Tuple[int, ...] = (1, 2, 4, 8),
mlp_ratios: Tuple[float, ...] = (4, 4, 4, 4),
depths: Tuple[int, ...] = (3, 4, 6, 3),
sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
wss: Tuple[int, ...] | None = None,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
block_cls: Any = <class 'timm.models.twins.Block'>,
device=None,
dtype=None,
)

Twins Vision Transformer (Revisiting Spatial Attention)

Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
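The sr_ratios argument sets the spatial-reduction factor of the PVT-style global attention in each stage: keys and values come from a feature map downsampled by sr_ratio per side, so each query attends to sr_ratio ** 2 times fewer tokens. The sketch below tabulates query and key/value token counts for the defaults. It assumes a stride-4 first patch embedding (patch_size=4) and stride-2 embeddings for the later stages, and it ignores the locally-grouped attention that wss enables.

```python
from typing import List, Sequence, Tuple


def twins_attention_sizes(
    img_size: int = 224,
    patch_size: int = 4,
    sr_ratios: Sequence[int] = (8, 4, 2, 1),
) -> List[Tuple[int, int]]:
    """(query tokens, key/value tokens) of the global attention in each stage."""
    sizes = []
    side = img_size // patch_size
    for i, sr in enumerate(sr_ratios):
        if i > 0:
            side //= 2  # assumption: later stages use a stride-2 patch embedding
        kv_side = side // sr  # keys/values computed on a map reduced by sr per side
        sizes.append((side * side, kv_side * kv_side))
    return sizes


print(twins_attention_sizes())
# [(3136, 49), (784, 49), (196, 49), (49, 49)]
```

With the defaults, every stage attends to a 7x7 grid of keys and values, which keeps the cost of the high-resolution first stage manageable.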

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

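Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).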

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.vgg.VGG(cfg: List[Any], num_classes: int = 1000, in_chans: int = 3, output_stride: int = 32, mlp_ratio: float = 1.0, act_layer: Type[Module] = <class 'torch.nn.modules.activation.ReLU'>, conv_layer: Type[Module] = <class 'torch.nn.modules.conv.Conv2d'>, norm_layer: Type[Module] | None = None, global_pool: str = 'avg', drop_rate: float = 0.0, device=None, dtype=None)

VGG model architecture.

Based on Very Deep Convolutional Networks for Large-Scale Image Recognition - https://arxiv.org/abs/1409.1556
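The cfg argument is a list describing the convolutional trunk. The sketch below assumes the usual VGG cfg convention (an integer adds a 3x3 convolution with that many output channels, 'M' adds a 2x2 max-pool) and summarizes such a list. The example list is configuration A (VGG-11) from the paper, included for illustration; the helper is not part of timm.

```python
from typing import Any, Dict, List

# Configuration "A" (VGG-11): ints are 3x3 conv widths, "M" is a 2x2 max-pool.
VGG11_CFG: List[Any] = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def summarize_vgg_cfg(cfg: List[Any]) -> Dict[str, int]:
    """Count conv layers, overall stride and final width of a VGG cfg list."""
    convs, pools, channels = 0, 0, 0
    for item in cfg:
        if item == "M":
            pools += 1  # each 2x2 max-pool halves the resolution
        else:
            convs += 1
            channels = int(item)
    return {"conv_layers": convs, "stride": 2 ** pools, "out_channels": channels}


print(summarize_vgg_cfg(VGG11_CFG))
# {'conv_layers': 8, 'stride': 32, 'out_channels': 512}
```

The 8 convolutions plus the three fully connected layers of the original classifier give VGG-11 its name.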

forward_features(x: Tensor) Tensor

Forward pass through feature extraction layers.

Parameters:

x – Input tensor.

Returns:

Feature tensor.

forward_head(x: Tensor, pre_logits: bool = False) Tensor

Forward pass through head.

Parameters:
  • x – Input features.

  • pre_logits – Return features before final linear layer.

Returns:

Classification logits or features.

get_classifier() Module

Get the classifier module.

Returns:

Classifier module.

group_matcher(coarse: bool = False) Dict[str, Any]

Group matcher for parameter groups.

Parameters:

coarse – Whether to use coarse grouping.

Returns:

Dictionary of grouped parameters.

reset_classifier(num_classes: int, global_pool: str | None = None) None

Reset the classifier.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.visformer.Visformer(
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
init_channels: int | None = 32,
embed_dim: int = 384,
depth: int | tuple = 12,
num_heads: int = 6,
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm2d'>,
attn_stage: str = '111',
use_pos_embed: bool = True,
spatial_conv: str = '111',
vit_stem: bool = False,
group: int = 8,
global_pool: str = 'avg',
conv_init: bool = False,
embed_norm: Type[Module] | None = None,
device=None,
dtype=None,
)
class timm.models.vision_transformer.VisionTransformer(
img_size: int | Tuple[int, int] = 224,
patch_size: int | Tuple[int, int] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map', 'prr'] = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
scale_attn_norm: bool = False,
scale_mlp_norm: bool = False,
proj_bias: bool = True,
init_values: float | None = None,
class_token: bool = True,
pos_embed: str = 'learn',
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
final_norm: bool = True,
fc_norm: bool | None = None,
pool_include_prefix: bool = False,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
weight_init: Literal['skip', 'reset', 'jax', 'jax_nlhb', 'moco', ''] = '',
fix_init: bool = False,
embed_layer: Callable = <class 'timm.layers.patch_embed.PatchEmbed'>,
embed_norm_layer: str | Callable | Type[Module] | None = None,
norm_layer: str | Callable | Type[Module] | None = None,
act_layer: str | Callable | Type[Module] | None = None,
block_fn: Type[Module] = <class 'timm.models.vision_transformer.Block'>,
mlp_layer: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
attn_layer: str | Callable | Type[Module] = <class 'timm.layers.attention.Attention'>,
device=None,
dtype=None,
)

Vision Transformer

A PyTorch implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

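
The length of the token sequence a VisionTransformer processes follows from img_size, patch_size, class_token and reg_tokens. The sketch below does that arithmetic for the defaults (224 / 16 = 14 patches per side, plus one class token). It assumes dynamic_img_size and dynamic_img_pad stay False, so the image must divide evenly into patches; the helper is illustrative.

```python
from typing import Tuple, Union


def vit_sequence_length(
    img_size: Union[int, Tuple[int, int]] = 224,
    patch_size: Union[int, Tuple[int, int]] = 16,
    class_token: bool = True,
    reg_tokens: int = 0,
) -> int:
    """Number of tokens entering the transformer blocks."""
    h, w = (img_size, img_size) if isinstance(img_size, int) else img_size
    ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
    if h % ph or w % pw:
        raise ValueError("image size must be divisible by patch size (no dynamic padding)")
    num_patches = (h // ph) * (w // pw)
    num_prefix_tokens = int(class_token) + reg_tokens  # class token + register tokens
    return num_prefix_tokens + num_patches


print(vit_sequence_length())              # 197
print(vit_sequence_length(reg_tokens=4))  # 201
print(vit_sequence_length(384))           # 577
```
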
fix_init_weight() None

Apply weight initialization fix (scaling w/ layer index).

forward_features(
x: Tensor,
attn_mask: Tensor | None = None,
is_causal: bool = False,
) Tensor

Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm).

forward_head(
x: Tensor,
pre_logits: bool = False,
) Tensor

Forward pass through classifier head.

Parameters:
  • x – Feature tensor.

  • pre_logits – Return features before final classifier.

Returns:

Output tensor.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
output_dict: bool = False,
attn_mask: Tensor | None = None,
is_causal: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]] | Dict[str, Any]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • return_prefix_tokens – Return both prefix and spatial intermediate tokens

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

  • output_dict – Return outputs as a dictionary with ‘image_features’ and ‘image_intermediates’ keys

  • attn_mask – Optional attention mask for masked attention (e.g., for NaFlex)

  • is_causal – If True, use causal (autoregressive) masking in attention

Returns:

A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing ‘image_features’ and ‘image_intermediates’ (and optionally ‘image_intermediates_prefix’)

get_classifier() Module

Get the classifier head.

get_intermediate_layers(
x: Tensor,
n: int | List[int] | Tuple[int] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
attn_mask: Tensor | None = None,
) List[Tensor]

Get intermediate layer outputs (DINO interface compatibility).

NOTE: This API is kept for backwards compatibility; prefer calling forward_intermediates() directly.

Parameters:
  • x – Input tensor.

  • n – Number or indices of layers.

  • reshape – Reshape to NCHW format.

  • return_prefix_tokens – Return prefix tokens.

  • norm – Apply normalization.

Returns:

List of intermediate features.

group_matcher(
coarse: bool = False,
) Dict[str, str | List]

Create regex patterns for parameter grouping.

Parameters:

coarse – Use coarse grouping.

Returns:

Dictionary mapping group names to regex patterns.

init_weights(mode: str = '', needs_reset: bool = True) None

Initialize model weights.

Parameters:
  • mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).

  • needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.

load_pretrained(
checkpoint_path: str,
prefix: str = '',
) None

Load pretrained weights.

Parameters:
  • checkpoint_path – Path to checkpoint.

  • prefix – Prefix for state dict keys.

no_weight_decay() Set[str]

Set of parameters that should not use weight decay.

pool(
x: Tensor,
pool_type: str | None = None,
) Tensor

Apply pooling to feature tokens.

Parameters:
  • x – Feature tensor.

  • pool_type – Pooling type override.

Returns:

Pooled features.
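pool() reduces the token sequence to one vector per image according to the pooling type. The plain-Python sketch below illustrates two of the documented global_pool options on a list of token vectors: 'token' takes the first (class) token, while 'avg' averages only the patch tokens, skipping prefix tokens as with pool_include_prefix=False. The helper and its num_prefix_tokens argument are illustrative, not timm's code.

```python
from typing import List, Sequence


def pool_tokens(
    tokens: Sequence[Sequence[float]],
    pool_type: str = "token",
    num_prefix_tokens: int = 1,
) -> List[float]:
    """Pool a (num_tokens, dim) list of token vectors into one vector."""
    if pool_type == "token":
        return list(tokens[0])  # the class token sits at position 0
    if pool_type == "avg":
        patch_tokens = tokens[num_prefix_tokens:]  # skip class/register tokens
        n = len(patch_tokens)
        return [sum(column) / n for column in zip(*patch_tokens)]
    raise ValueError(f"pool_type {pool_type!r} is not covered by this sketch")


tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]  # class token + 2 patch tokens
print(pool_tokens(tokens, "token"))  # [9.0, 9.0]
print(pool_tokens(tokens, "avg"))    # [2.0, 3.0]
```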

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune the classifier head.

Returns:

List of indices that were kept.

reset_classifier(
num_classes: int,
global_pool: str | None = None,
) None

Reset the classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) None

Enable or disable gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

set_input_size(
img_size: Tuple[int, int] | None = None,
patch_size: Tuple[int, int] | None = None,
) None

Update the input image resolution and patch size.

Parameters:
  • img_size – New input resolution, if None current resolution is used.

  • patch_size – New patch size, if None existing patch size is used.

class timm.models.vision_transformer_relpos.VisionTransformerRelPos(
img_size: int | Tuple[int, int] = 224,
patch_size: int | Tuple[int, int] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: float | None = 1e-06,
class_token: bool = False,
fc_norm: bool = False,
rel_pos_type: str = 'mlp',
rel_pos_dim: int | None = None,
shared_rel_pos: bool = False,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
weight_init: Literal['skip', 'reset', 'jax', 'moco', ''] = 'reset',
fix_init: bool = False,
embed_layer: Type[Module] = <class 'timm.layers.patch_embed.PatchEmbed'>,
norm_layer: str | Callable | Type[Module] | None = None,
act_layer: str | Callable | Type[Module] | None = None,
block_fn: Type[Module] = <class 'timm.models.vision_transformer_relpos.RelPosBlock'>,
device=None,
dtype=None,
)

Vision Transformer w/ Relative Position Bias

Unlike the classic ViT, this implementation:
  • uses a relative position index (as in Swin V1 / BEiT) or relative log coordinates + MLP (as in Swin V2) for position embedding

  • defaults to no class token (can be enabled)

  • defaults to global average pooling for the head (can be changed)

  • enables layer-scale (residual branch gain)
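A relative position index of the Swin V1 / BEiT kind maps every (query, key) pair in an h x w grid to one of (2h - 1) * (2w - 1) learned bias entries. The sketch below builds such an index in plain Python with the usual shift-and-flatten scheme; it illustrates the idea and is not timm's implementation.

```python
from typing import List


def relative_position_index(h: int, w: int) -> List[List[int]]:
    """(h*w) x (h*w) table of bias-table indices for every query/key pair."""
    coords = [(i, j) for i in range(h) for j in range(w)]
    index = []
    for qi, qj in coords:
        row = []
        for ki, kj in coords:
            dh = qi - ki + (h - 1)  # shift row offset into [0, 2h - 2]
            dw = qj - kj + (w - 1)  # shift column offset into [0, 2w - 2]
            row.append(dh * (2 * w - 1) + dw)  # flatten the 2-D offset
        index.append(row)
    return index


idx = relative_position_index(3, 4)
print(len(idx), len({v for row in idx for v in row}))  # 12 35
```

For a 14 x 14 grid (224-pixel input, 16-pixel patches) this gives 27 * 27 = 729 bias entries per head.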

fix_init_weight() None

Apply weight initialization fix (scaling w/ layer index).

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward pass through the feature layers that also returns intermediate features.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • return_prefix_tokens – Return both prefix and spatial intermediate tokens

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks when last desired intermediate hit

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

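Returns:

List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).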

init_weights(
mode: str = '',
needs_reset: bool = True,
) None

Initialize model weights.

Parameters:
  • mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).

  • needs_reset – If True (the default, appropriate after to_empty()), call reset_parameters() on modules. If False, skip reset_parameters() (used in __init__, where modules have already self-initialized).

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.vision_transformer_sam.VisionTransformerSAM(img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, num_classes: int = 768, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, init_values: float | None = None, pre_norm: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, weight_init: str = '', embed_layer: Type[Module] = functools.partial(<class 'timm.layers.patch_embed.PatchEmbed'>, output_fmt=<Format.NHWC: 'NHWC'>, strict_img_size=False), norm_layer: Type[Module] | None = <class 'torch.nn.modules.normalization.LayerNorm'>, act_layer: Type[Module] | None = <class 'torch.nn.modules.activation.GELU'>, block_fn: Type[Module] = <class 'timm.models.vision_transformer_sam.Block'>, mlp_layer: Type[Module] = <class 'timm.layers.mlp.Mlp'>, use_abs_pos: bool = True, use_rel_pos: bool = False, use_rope: bool = False, window_size: int = 14, global_attn_indexes: Tuple[int, ...] = (), neck_chans: int = 256, global_pool: str = 'avg', head_hidden_size: int | None = None, ref_feat_shape: Tuple[Tuple[int, int], Tuple[int, int]] | None = None, device=None, dtype=None)

Vision Transformer for Segment Anything Model (SAM)

A PyTorch implementation of Exploring Plain Vision Transformer Backbones for Object Detection, or of Segment Anything Model (SAM)

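
In this encoder, blocks use windowed attention of size window_size (14 by default) unless their index is listed in global_attn_indexes (empty by default). With the defaults img_size=1024 and patch_size=16 the token grid is 64 x 64, which is not a multiple of 14. The sketch below computes the padding and window count, assuming the grid is zero-padded up to the next multiple of the window size before partitioning.

```python
from typing import Tuple


def window_partition_stats(
    img_size: int = 1024,
    patch_size: int = 16,
    window_size: int = 14,
) -> Tuple[int, int, int]:
    """Return (grid side, padding per side dimension, number of windows)."""
    grid = img_size // patch_size
    pad = (window_size - grid % window_size) % window_size  # pad up to a multiple
    per_side = (grid + pad) // window_size
    return grid, pad, per_side * per_side


print(window_partition_stats())     # (64, 6, 25)
print(window_partition_stats(224))  # (14, 0, 1)
```
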
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks once the last desired intermediate is reached

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

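Returns:

A list of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).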
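The `indices` rule above ("last n blocks if int, all if None, matching indices if sequence") can be made concrete with a small sketch. The helper `take_indices` is hypothetical and is not timm's internal implementation. For simplicity it rejects negative indices rather than interpreting them.

```python
from typing import List, Optional, Sequence, Union


def take_indices(num_blocks: int,
                 indices: Optional[Union[int, Sequence[int]]] = None) -> List[int]:
    """Resolve `indices` into block positions, following the documented rule.

    Hypothetical helper: int -> last n blocks, None -> all blocks,
    sequence -> exactly the listed block indices.
    """
    if indices is None:
        return list(range(num_blocks))
    if isinstance(indices, int):
        if not 0 < indices <= num_blocks:
            raise ValueError(f"expected 1..{num_blocks} blocks, got {indices}")
        return list(range(num_blocks - indices, num_blocks))
    selected = list(indices)
    if any(not 0 <= i < num_blocks for i in selected):
        raise ValueError(f"index out of range for {num_blocks} blocks: {selected}")
    return selected


print(take_indices(12, 3))           # [9, 10, 11]
print(take_indices(4))               # [0, 1, 2, 3]
print(take_indices(12, [2, 5, 11]))  # [2, 5, 11]
```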

prune_intermediate_layers(
indices: int | List[int] | None = None,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.
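Conceptually, pruning keeps only the blocks up to the deepest requested intermediate, because later blocks cannot affect those outputs. It can also optionally drop the final norm and the classification head. The following sketch models that bookkeeping on plain integers, under the assumption of this truncate-after-the-deepest-index behavior. `PrunePlan` and `prune_plan` are hypothetical names, not timm APIs, and no real modules are touched.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union


@dataclass
class PrunePlan:
    take: List[int]    # resolved intermediate indices
    blocks_kept: int   # leading blocks that must be retained
    keep_norm: bool    # whether the final norm layer is retained
    keep_head: bool    # whether the classification head is retained


def prune_plan(num_blocks: int, indices: Optional[Union[int, Sequence[int]]] = None,
               prune_norm: bool = False, prune_head: bool = True) -> PrunePlan:
    """Hypothetical bookkeeping for prune_intermediate_layers."""
    if indices is None:
        take = list(range(num_blocks))                        # all blocks
    elif isinstance(indices, int):
        take = list(range(num_blocks - indices, num_blocks))  # last n blocks
    else:
        take = list(indices)                                  # explicit indices
    # Blocks after the deepest requested index cannot affect the intermediates.
    blocks_kept = max(take) + 1
    return PrunePlan(take=take, blocks_kept=blocks_kept,
                     keep_norm=not prune_norm, keep_head=not prune_head)


print(prune_plan(12, indices=[3, 7]))
# PrunePlan(take=[3, 7], blocks_kept=8, keep_norm=True, keep_head=False)
```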

class timm.models.volo.VOLO(layers: ~typing.List[int], img_size: int = 224, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'token', patch_size: int = 8, stem_hidden_dim: int = 64, embed_dims: ~typing.List[int] | None = None, num_heads: ~typing.List[int] | None = None, downsamples: ~typing.Tuple[bool, ...] = (True, False, False, False), outlook_attention: ~typing.Tuple[bool, ...] = (True, False, False, False), mlp_ratio: float = 3.0, qkv_bias: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.normalization.LayerNorm'>, post_layers: ~typing.Tuple[str, ...] | None = ('ca', 'ca'), use_aux_head: bool = True, use_mix_token: bool = False, pooling_scale: int = 2, device=None, dtype=None)

Vision Outlooker (VOLO) model.

forward_cls(x: Tensor) -> Tensor

Forward pass through class attention blocks.

Parameters:

x – Input token tensor of shape (B, N, C).

Returns:

Output tensor with class token of shape (B, N+1, C).
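The shape change from (B, N, C) to (B, N+1, C) comes from prepending one learnable class token to every sequence before the class-attention blocks run. The sketch below reproduces only that shape bookkeeping with nested Python lists, so it runs without PyTorch. The helper `prepend_cls_token` is hypothetical. The real method works on tensors and also applies the class-attention blocks.

```python
from typing import List

Token = List[float]


def prepend_cls_token(x: List[List[Token]], cls_token: Token) -> List[List[Token]]:
    """Prepend the same class token to every sample: (B, N, C) -> (B, N + 1, C)."""
    if any(len(tok) != len(cls_token) for sample in x for tok in sample):
        raise ValueError("class token width must match channel dim C")
    return [[list(cls_token)] + [list(tok) for tok in sample] for sample in x]


B, N, C = 2, 3, 4
x = [[[0.0] * C for _ in range(N)] for _ in range(B)]
out = prepend_cls_token(x, cls_token=[1.0] * C)
print(len(out), len(out[0]), len(out[0][0]))  # 2 4 4
```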

forward_features(x: Tensor) -> Tensor

Forward pass through feature extraction.

Parameters:

x – Input tensor of shape (B, C, H, W).

Returns:

Feature tensor.

forward_head(x: Tensor, pre_logits: bool = False) -> Tensor

Forward pass through classification head.

Parameters:
  • x – Input feature tensor.

  • pre_logits – Whether to return pre-logits features.

Returns:

Classification logits or pre-logits features.

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks once the last desired intermediate is reached

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

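Returns:

A list of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).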
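The effect of `stop_early` can be shown with a toy forward loop in which each "block" is a plain function. This is an illustrative sketch, not the model's actual code. `run_blocks` is a hypothetical helper, and the toy blocks simply add a constant. With `stop_early=True`, the loop exits after the deepest requested block, so fewer blocks execute and the returned "final" value is that block's output rather than the full-depth output.

```python
from typing import Callable, List, Sequence, Tuple


def run_blocks(x: float, blocks: Sequence[Callable[[float], float]],
               take: Sequence[int], stop_early: bool = False) -> Tuple[float, List[float], int]:
    """Toy forward pass that collects outputs of the blocks listed in `take`.

    Returns (last computed output, intermediates, number of blocks executed).
    """
    last = max(take)
    intermediates: List[float] = []
    executed = 0
    for i, block in enumerate(blocks):
        x = block(x)
        executed += 1
        if i in take:
            intermediates.append(x)
        if stop_early and i == last:
            break  # no later block is needed for the requested intermediates
    return x, intermediates, executed


blocks = [lambda v, k=k: v + k for k in range(6)]  # toy block k adds k
print(run_blocks(0.0, blocks, take=[1, 3]))                   # (15.0, [1.0, 6.0], 6)
print(run_blocks(0.0, blocks, take=[1, 3], stop_early=True))  # (6.0, [1.0, 6.0], 4)
```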

forward_tokens(x: Tensor) -> Tensor

Forward pass through token processing stages.

Parameters:

x – Input tensor of shape (B, H, W, C).

Returns:

Token tensor of shape (B, N, C).
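Going from a channels-last (B, H, W, C) map to (B, N, C) tokens amounts to flattening each H x W grid row by row, so N = H * W. The sketch below does that flattening on nested lists to make the index mapping explicit. `flatten_tokens` is a hypothetical helper, and the real method performs the equivalent reshape on a tensor after its token-processing stages.

```python
from typing import List

Grid = List[List[List[float]]]  # one sample: H x W x C (channels last)


def flatten_tokens(x: List[Grid]) -> List[List[List[float]]]:
    """Row-major flatten of each H x W grid into N = H * W tokens: (B, H, W, C) -> (B, N, C)."""
    return [[pixel for row in sample for pixel in row] for sample in x]


B, H, W, C = 2, 3, 5, 4
# Each pixel stores its flat position h * W + w in every channel.
x = [[[[float(h * W + w)] * C for w in range(W)] for h in range(H)] for _ in range(B)]
tokens = flatten_tokens(x)
print(len(tokens), len(tokens[0]), len(tokens[0][0]))  # 2 15 4
print(tokens[0][7][0])  # 7.0: token n comes from grid position (n // W, n % W)
```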

forward_train(
x: Tensor,
) -> Tensor | Tuple[Tensor, Tensor, Tuple[int, int, int, int]]

Forward pass for training with mix token support.

Parameters:

x – Input tensor of shape (B, C, H, W).

Returns:

If training with mix_token: tuple of (class_token, aux_tokens, bbox). Otherwise: class_token tensor.
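Because `forward_train` returns a tuple only when training with mix_token, calling code has to branch on the result. The following is a minimal sketch of that caller-side handling. It uses placeholder values in place of tensors, unpacks in the (class_token, aux_tokens, bbox) order described above, and uses the hypothetical helper name `split_train_output`.

```python
from typing import Any, Optional, Tuple

BBox = Tuple[int, int, int, int]


def split_train_output(out: Any) -> Tuple[Any, Optional[Any], Optional[BBox]]:
    """Normalize both documented return forms to (class_token, aux_tokens, bbox)."""
    if isinstance(out, tuple):
        # Training with mix_token: (class_token, aux_tokens, bbox).
        class_token, aux_tokens, bbox = out
        return class_token, aux_tokens, bbox
    # Otherwise only the class-token output is returned.
    return out, None, None


print(split_train_output("cls"))                           # ('cls', None, None)
print(split_train_output(("cls", "aux", (0, 0, 4, 4))))    # ('cls', 'aux', (0, 0, 4, 4))
```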

get_classifier() -> Module

Get classifier module.

Returns:

The classifier head module.

group_matcher(coarse: bool = False) -> Dict[str, Any]

Get parameter grouping for optimizer.

Parameters:

coarse – Whether to use coarse grouping.

Returns:

Parameter grouping dictionary.

no_weight_decay() -> set

Get set of parameters that should not have weight decay.

Returns:

Set of parameter names.
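A typical use of the names returned by `no_weight_decay()` is to split parameters into two optimizer groups, one of which has weight decay disabled. The sketch below does this on plain parameter names so it runs without PyTorch. The helper `split_decay_groups`, the example names, and the 0.05 decay value are all hypothetical. In practice, the names would come from `model.named_parameters()` and the groups would hold the parameter tensors themselves before being passed to an optimizer.

```python
from typing import Dict, Iterable, List, Set


def split_decay_groups(param_names: Iterable[str], skip: Set[str],
                       weight_decay: float = 0.05) -> List[Dict[str, object]]:
    """Put names listed in `skip` into a group with weight_decay=0.0."""
    decay: List[str] = []
    no_decay: List[str] = []
    for name in param_names:
        (no_decay if name in skip else decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


names = ["cls_token", "pos_embed", "patch_embed.proj.weight", "head.weight"]
groups = split_decay_groups(names, skip={"cls_token", "pos_embed"})
print(groups)
```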

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
) -> List[int]

Prune layers not required for specified intermediates.

Parameters:
  • indices – Indices of intermediate layers to keep.

  • prune_norm – Whether to prune normalization layer.

  • prune_head – Whether to prune classification head.

Returns:

List of kept intermediate indices.

reset_classifier(num_classes: int, global_pool: str | None = None) -> None

Reset classifier head.

Parameters:
  • num_classes – Number of classes for new classifier.

  • global_pool – Global pooling type.

set_grad_checkpointing(enable: bool = True) -> None

Set gradient checkpointing.

Parameters:

enable – Whether to enable gradient checkpointing.

class timm.models.vovnet.VovNet(
cfg: dict,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
output_stride: int = 32,
norm_layer: Type[Module] = <class 'timm.layers.norm_act.BatchNormAct2d'>,
act_layer: Type[Module] = <class 'torch.nn.modules.activation.ReLU'>,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
device=None,
dtype=None,
**kwargs,
)
forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to compatible intermediates

  • stop_early – Stop iterating over blocks once the last desired intermediate is reached

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

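Returns:

A list of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).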
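`output_fmt` changes only the axis order of each returned intermediate, not its contents. As a small illustration, the hypothetical helper below maps an NCHW shape to the requested layout. Whether a particular model accepts 'NHWC' in addition to the default 'NCHW' is an assumption that should be checked per model.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def shape_in_format(shape_nchw: Shape, output_fmt: str = "NCHW") -> Shape:
    """Reorder an (N, C, H, W) shape into the requested layout."""
    n, c, h, w = shape_nchw
    if output_fmt == "NCHW":
        return (n, c, h, w)
    if output_fmt == "NHWC":
        return (n, h, w, c)  # channels moved to the last axis
    raise ValueError(f"unsupported output_fmt: {output_fmt!r}")


print(shape_in_format((8, 256, 14, 14), "NHWC"))  # (8, 14, 14, 256)
```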

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.

class timm.models.xception.Xception(
num_classes: int = 1000,
in_chans: int = 3,
drop_rate: float = 0.0,
global_pool: str = 'avg',
device=None,
dtype=None,
)

Xception optimized for the ImageNet dataset, as specified in https://arxiv.org/pdf/1610.02357.pdf

class timm.models.xception_aligned.XceptionAligned(block_cfg: ~typing.List[~typing.Dict], num_classes: int = 1000, in_chans: int = 3, output_stride: int = 32, preact: bool = False, act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU'>, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, drop_rate: float = 0.0, drop_path_rate: float = 0.0, global_pool: str = 'avg', device=None, dtype=None)

Modified Aligned Xception

class timm.models.xcit.Xcit(
img_size: int | Tuple[int, int] = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
act_layer: Type[Module] | None = None,
norm_layer: Type[Module] | None = None,
cls_attn_layers: int = 2,
use_pos_embed: bool = True,
eta: float = 1.0,
tokens_norm: bool = False,
device=None,
dtype=None,
)

Based on the timm and DeiT code bases: https://github.com/rwightman/pytorch-image-models/tree/master/timm and https://github.com/facebookresearch/deit/
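With the defaults `img_size=224` and `patch_size=16`, the patch embedding reduces the image to a grid of (224 / 16) x (224 / 16) = 14 x 14 = 196 patch tokens, and this count sets the cost of each attention block. The sketch below computes the token count under the assumption of a grid downsampled by exactly `patch_size` in each direction. The helper `num_patch_tokens` is hypothetical, and its divisibility check is a simplification.

```python
from typing import Tuple, Union


def num_patch_tokens(img_size: Union[int, Tuple[int, int]], patch_size: int) -> int:
    """Token count for a patch grid downsampled by `patch_size`: (H // p) * (W // p)."""
    h, w = (img_size, img_size) if isinstance(img_size, int) else img_size
    if h % patch_size or w % patch_size:
        raise ValueError(f"image size {(h, w)} not divisible by patch size {patch_size}")
    return (h // patch_size) * (w // patch_size)


print(num_patch_tokens(224, 16))        # 196
print(num_patch_tokens((64, 128), 16))  # 32
```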

forward_intermediates(
x: Tensor,
indices: int | List[int] | None = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> List[Tensor] | Tuple[Tensor, List[Tensor]]

Forward features that returns intermediates.

Parameters:
  • x – Input image tensor

  • indices – Take last n blocks if int, all if None, select matching indices if sequence

  • norm – Apply norm layer to all intermediates

  • stop_early – Stop iterating over blocks once the last desired intermediate is reached

  • output_fmt – Shape of intermediate feature outputs

  • intermediates_only – Only return intermediate features

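Returns:

A list of intermediate feature tensors if intermediates_only is True; otherwise a tuple of (final features, list of intermediate features).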

prune_intermediate_layers(
indices: int | List[int] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)

Prune layers not required for specified intermediates.