Image Models
This page contains the list of external image models, from the great timm library, that can be used with EIR.
There are 3 ways to use these models:
Configure and train specific architectures (e.g. ResNet with a chosen number of layers) from scratch.
Train a specific architecture (e.g. resnet18) from scratch.
Use a pre-trained model (e.g. resnet18) and fine-tune it.
Please refer to this page for more detailed information about configurable architectures, and this page for a list of pre-defined architectures, with the option of using pre-trained weights.
Configurable Models
The following models can be configured and trained from scratch.
The model type is specified in the model_type field of the configuration, while the model-specific configuration is specified in the model_init_config field.
For example, the ResNet architecture includes the layers and block parameters, and can be configured as follows:
input_info:
  input_source: eir_tutorials/a_using_eir/05_image_tutorial/data/hot_dog_not_hot_dog/food_images
  input_name: hot_dog
  input_type: image
input_type_info:
  mixing_subtype: "cutmix"
  size:
    - 64
model_config:
  model_type: "ResNet"
  model_init_config:
    layers: [1, 1, 1, 1]
    block: "BasicBlock"
interpretation_config:
  num_samples_to_interpret: 30
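To see what the layers and block fields control, the conventional depth in a ResNet name can be recomputed from them. This is a small sketch assuming the usual counting convention (two convolutions per BasicBlock, three per Bottleneck, plus the stem convolution and the final classifier layer); resnet_depth is a hypothetical helper, not part of EIR or timm.

```python
from typing import List

# Convolutions inside one residual block of each type
# (shortcut/downsampling convolutions are not counted).
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def resnet_depth(layers: List[int], block: str) -> int:
    """Count the weighted layers implied by a ResNet `layers`/`block` setting.

    Each stage contributes layers[i] blocks; the stem convolution and the
    final fully connected classifier add 2 more layers.
    """
    return sum(layers) * CONVS_PER_BLOCK[block] + 2


print(resnet_depth([1, 1, 1, 1], "BasicBlock"))  # the config above: 10
print(resnet_depth([2, 2, 2, 2], "BasicBlock"))  # resnet18 layout: 18
print(resnet_depth([3, 4, 6, 3], "Bottleneck"))  # resnet50 layout: 50
```

So layers: [1, 1, 1, 1] with BasicBlock gives a ResNet-10-style network, smaller than the predefined resnet18.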
- class timm.models.beit.Beit(
- img_size: int | Tuple[int, int] = 224,
- patch_size: int | Tuple[int, int] = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- qkv_bias: bool = True,
- mlp_ratio: float = 4.0,
- swiglu_mlp: bool = False,
- scale_mlp: bool = False,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
- init_values: float | None = None,
- use_abs_pos_emb: bool = True,
- use_rel_pos_bias: bool = False,
- use_shared_rel_pos_bias: bool = False,
- head_init_scale: float = 0.001,
- device=None,
- dtype=None,
BEiT: BERT Pre-Training of Image Transformers.
Vision Transformer model with support for relative position bias and shared relative position bias across layers. Implements both BEiT v1 and v2 architectures with flexible configuration options.
- fix_init_weight() → None
Fix initialization weights according to BEiT paper.
Rescales attention and MLP weights based on layer depth to improve training stability.
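The depth-dependent rescaling can be made concrete with a few lines of arithmetic. This sketch assumes the BEiT recipe as implemented in timm, where in block i (counting from 1) the attention output projection and the second MLP layer are divided by sqrt(2 * i); init_rescale_factors is a hypothetical helper that only computes the factors and does not touch a model.

```python
import math
from typing import List


def init_rescale_factors(depth: int) -> List[float]:
    """Multiplier applied to block i's projection weights (i is 1-based).

    Deeper blocks are scaled down more, by a factor of 1 / sqrt(2 * i).
    """
    return [1.0 / math.sqrt(2.0 * layer_id) for layer_id in range(1, depth + 1)]


# The default Beit depth is 12.
for layer_id, factor in enumerate(init_rescale_factors(12), start=1):
    print(f"block {layer_id:2d}: projection weights scaled by {factor:.3f}")
```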
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor of shape (batch_size, channels, height, width).
- Returns:
Feature tensor of shape (batch_size, num_tokens, embed_dim).
- forward_head(x: Tensor, pre_logits: bool = False) → Tensor
Forward pass through classification head.
- Parameters:
x – Feature tensor of shape (batch_size, num_tokens, embed_dim).
pre_logits – If True, return features before final linear layer.
- Returns:
Logits tensor of shape (batch_size, num_classes) or pre-logits.
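For the default configuration, the shapes above work out as follows. This is a pure-Python sketch assuming a single class token is prepended to the patch tokens and that global_pool='avg' averages the patch tokens (excluding the class token) before the linear classifier; beit_shapes is a hypothetical helper.

```python
from typing import Dict, Tuple


def beit_shapes(
    batch_size: int = 8,
    in_chans: int = 3,
    img_size: int = 224,
    patch_size: int = 16,
    embed_dim: int = 768,
    num_classes: int = 1000,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes at each stage of forward_features() and forward_head()."""
    grid = img_size // patch_size  # patches along each side
    num_tokens = grid * grid + 1   # patch tokens plus one class token
    return {
        "input": (batch_size, in_chans, img_size, img_size),
        "forward_features": (batch_size, num_tokens, embed_dim),
        # global_pool='avg' averages over the token axis, removing it.
        "pre_logits": (batch_size, embed_dim),
        "logits": (batch_size, num_classes),
    }


for stage, shape in beit_shapes().items():
    print(f"{stage:>16}: {shape}")
```

With a 224x224 input and 16x16 patches this gives 14 * 14 + 1 = 197 tokens per image.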
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- return_prefix_tokens: bool = False,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate feature maps.
- Parameters:
x – Input image tensor of shape (batch_size, channels, height, width).
indices – Block indices to return features from. If int, returns last n blocks.
return_prefix_tokens – If True, return both prefix and spatial tokens.
norm – If True, apply normalization to intermediate features.
stop_early – If True, stop at last selected intermediate.
output_fmt – Output format (‘NCHW’ or ‘NLC’).
intermediates_only – If True, only return intermediate features.
- Returns:
If intermediates_only is True, returns list of intermediate tensors. Otherwise, returns tuple of (final_features, intermediates).
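The documented indices rule can be written out as a small resolver. timm resolves indices internally; the hypothetical select_block_indices helper below only mirrors the rule stated above (int: last n blocks, None: all blocks, sequence: those blocks) and, as a simplifying assumption, accepts only non-negative indices.

```python
from typing import List, Optional, Sequence, Union


def select_block_indices(
    num_blocks: int, indices: Optional[Union[int, Sequence[int]]] = None
) -> List[int]:
    """Resolve `indices` to the block indices whose outputs are returned."""
    if indices is None:
        # None: return the output of every block.
        return list(range(num_blocks))
    if isinstance(indices, int):
        # int n: the last n blocks.
        if not 0 < indices <= num_blocks:
            raise ValueError(f"cannot take the last {indices} of {num_blocks} blocks")
        return list(range(num_blocks - indices, num_blocks))
    # Sequence: exactly the requested block indices.
    selected = list(indices)
    if any(not 0 <= i < num_blocks for i in selected):
        raise ValueError(f"indices {selected} out of range for {num_blocks} blocks")
    return selected


print(select_block_indices(12, 4))          # [8, 9, 10, 11]
print(select_block_indices(12, [2, 5, 8]))  # [2, 5, 8]
# With stop_early=True, blocks after the deepest selected one need not run.
print(max(select_block_indices(12, [2, 5, 8])) + 1, "of 12 blocks needed")
```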
- get_classifier() → Module
Get the classifier head.
- Returns:
The classification head module.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Create parameter group matcher for optimizer parameter groups.
- Parameters:
coarse – If True, use coarse grouping.
- Returns:
Dictionary mapping group names to regex patterns.
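As an illustration of how such a dictionary is consumed, the sketch below assigns parameter names to groups with re.match, which is what layer-wise learning-rate decay needs. The patterns and parameter names are hypothetical examples in the style of a ViT-like model, not the exact dictionary that Beit.group_matcher() returns.

```python
import re
from typing import Dict, Optional

# Hypothetical matcher in the shape group_matcher() returns: group name -> regex.
MATCHER: Dict[str, str] = {
    "stem": r"^(?:cls_token|pos_embed|patch_embed)",
    "blocks": r"^blocks\.(\d+)",
    "head": r"^(?:fc_norm|head)",
}


def assign_group(param_name: str, matcher: Dict[str, str]) -> Optional[str]:
    """Return the group a parameter belongs to, or None if nothing matches."""
    for group, pattern in matcher.items():
        match = re.match(pattern, param_name)
        if match is None:
            continue
        # A captured block index gives each block its own group, e.g. "blocks.3".
        return f"{group}.{match.group(1)}" if match.groups() else group
    return None


for name in ["patch_embed.proj.weight", "blocks.3.attn.qkv.weight", "head.weight"]:
    print(name, "->", assign_group(name, MATCHER))
```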
- init_weights(needs_reset: bool = True) → None
Initialize model weights.
- Parameters:
needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.
- no_weight_decay() → Set[str]
Get parameter names that should not use weight decay.
- Returns:
Set of parameter names to exclude from weight decay.
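A common use of this set is to build two optimizer parameter groups. The sketch below works on names only; the parameter names are hypothetical stand-ins for model.named_parameters(), the weight decay value of 0.05 is an arbitrary assumption, and with a real model you would place the parameter tensors in each group instead of their names.

```python
from typing import Any, Dict, Iterable, List, Set


def split_weight_decay(
    param_names: Iterable[str], skip: Set[str], weight_decay: float = 0.05
) -> List[Dict[str, Any]]:
    """Build decay / no-decay groups from the names returned by no_weight_decay()."""
    decay: List[str] = []
    no_decay: List[str] = []
    for name in param_names:
        (no_decay if name in skip else decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Hypothetical parameter names standing in for model.named_parameters().
names = ["cls_token", "pos_embed", "blocks.0.attn.qkv.weight", "head.weight"]
groups = split_weight_decay(names, skip={"cls_token", "pos_embed"})
print(groups)
```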
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediate outputs.
- Parameters:
indices – Indices of blocks to keep.
prune_norm – If True, remove final normalization.
prune_head – If True, remove classification head.
- Returns:
List of indices that were kept.
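The pruning bookkeeping comes down to one observation: no block after the deepest requested one is ever needed. This is a sketch with a hypothetical plan_pruning helper; it assumes indices follows the same int / sequence convention as forward_intermediates() and only computes what would be kept, without touching a real model (prune_norm and prune_head additionally remove the final norm and the classifier).

```python
from typing import List, Tuple, Union


def plan_pruning(depth: int, indices: Union[int, List[int]] = 1) -> Tuple[List[int], int]:
    """Return (kept indices, number of blocks left after truncation)."""
    if isinstance(indices, int):
        kept = list(range(depth - indices, depth))  # int n: last n blocks
    else:
        kept = list(indices)
    # Blocks deeper than the deepest requested index can be dropped.
    return kept, max(kept) + 1


print(plan_pruning(12))          # ([11], 12): no blocks can be removed
print(plan_pruning(12, [2, 5]))  # ([2, 5], 6): blocks 6 to 11 can be removed
```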
- reset_classifier(num_classes: int, global_pool: str | None = None)
Reset the classification head.
- Parameters:
num_classes – Number of classes for new head.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True)
Enable or disable gradient checkpointing.
- Parameters:
enable – If True, enable gradient checkpointing.
- class timm.models.byobnet.ByobNet(
- cfg: ByoModelCfg,
- num_classes: int = 1000,
- in_chans: int = 3,
- global_pool: str | None = None,
- output_stride: int = 32,
- img_size: int | Tuple[int, int] | None = None,
- drop_rate: float = 0.0,
- drop_block_rate: float = 0.0,
- drop_block_size: int = 3,
- drop_path_rate: float = 0.0,
- zero_init_last: bool = True,
- device=None,
- dtype=None,
- **kwargs,
Bring-your-own-blocks Network.
A flexible network backbone that builds the model stem and blocks from a dataclass config definition, with factory functions for module instantiation.
It currently assumes that both the stem and the blocks are in conv-bn-act order (with each block ending in an activation).
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- exclude_final_conv: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
exclude_final_conv – Exclude final_conv from last intermediate
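- Returns:
List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).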
- get_classifier() → Module
Get classifier module.
- Returns:
Classifier module.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Group matcher for parameter groups.
- Parameters:
coarse – Whether to use coarse grouping.
- Returns:
Dictionary mapping group names to patterns.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset classifier.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.cait.Cait(
- img_size: int = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'token',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- block_layers: Type[Module] = <class 'timm.models.cait.LayerScaleBlock'>,
- block_layers_token: Type[Module] = <class 'timm.models.cait.LayerScaleBlockClassAttn'>,
- patch_layer: Type[Module] = <class 'timm.layers.patch_embed.PatchEmbed'>,
- norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- attn_block: Type[Module] = <class 'timm.models.cait.TalkingHeadAttn'>,
- mlp_block: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
- init_values: float = 0.0001,
- attn_block_token_only: Type[Module] = <class 'timm.models.cait.ClassAttn'>,
- mlp_block_token_only: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
- depth_token_only: int = 2,
- mlp_ratio_token_only: float = 4.0,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.coat.CoaT(
- img_size: int = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- embed_dims: Tuple[int, int, int, int] = (64, 128, 320, 512),
- serial_depths: Tuple[int, int, int, int] = (3, 4, 6, 3),
- parallel_depth: int = 0,
- num_heads: int = 8,
- mlp_ratios: Tuple[float, float, float, float] = (4, 4, 4, 4),
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
- return_interm_layers: bool = False,
- out_features: List[str] | None = None,
- crpe_window: dict | None = None,
- global_pool: str = 'token',
- device=None,
- dtype=None,
CoaT (Co-Scale Conv-Attentional Image Transformer) model.
- class timm.models.convit.ConVit(
- img_size: int = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'token',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = False,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- hybrid_backbone: Any | None = None,
- norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
- local_up_to_layer: int = 3,
- locality_strength: float = 1.0,
- use_pos_embed: bool = True,
- device=None,
- dtype=None,
Vision Transformer with support for patch or hybrid CNN input stage
- class timm.models.convmixer.ConvMixer(
- dim: int,
- depth: int,
- kernel_size: int = 9,
- patch_size: int = 7,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- drop_rate: float = 0.0,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- device=None,
- dtype=None,
- **kwargs,
- class timm.models.convnext.ConvNeXt(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- output_stride: int = 32,
- depths: Tuple[int, ...] = (3, 3, 9, 3),
- dims: Tuple[int, ...] = (96, 192, 384, 768),
- kernel_sizes: int | Tuple[int, ...] = 7,
- ls_init_value: float | None = 1e-06,
- stem_type: str = 'patch',
- patch_size: int = 4,
- head_init_scale: float = 1.0,
- head_norm_first: bool = False,
- head_hidden_size: int | None = None,
- conv_mlp: bool = False,
- conv_bias: bool = True,
- use_grn: bool = False,
- act_layer: str | Callable = 'gelu',
- norm_layer: str | Callable | None = None,
- norm_eps: float | None = None,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
ConvNeXt model architecture.
A PyTorch implementation of A ConvNet for the 2020s - https://arxiv.org/pdf/2201.03545.pdf
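The default stage layout translates into the following feature-map shapes. This sketch assumes the 'patch' stem downsamples by patch_size and that every stage after the first starts with a 2x downsampling layer, which gives the default output_stride of 4 * 2 * 2 * 2 = 32; depths only sets how many blocks run within each stage and does not change the shapes. convnext_stage_shapes is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def convnext_stage_shapes(
    img_size: int = 224,
    dims: Sequence[int] = (96, 192, 384, 768),
    patch_size: int = 4,
) -> List[Tuple[int, int, int]]:
    """(channels, height, width) of each stage output for a square input."""
    shapes = []
    size = img_size // patch_size  # 'patch' stem: non-overlapping patch_size x patch_size conv
    for stage_idx, channels in enumerate(dims):
        if stage_idx > 0:
            size //= 2  # each later stage begins with a 2x downsampling layer
        shapes.append((channels, size, size))
    return shapes


print(convnext_stage_shapes())
# [(96, 56, 56), (192, 28, 28), (384, 14, 14), (768, 7, 7)]
```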
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier.
- Returns:
Output tensor.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor.
indices – Take last n blocks if int, all if None, select matching indices if sequence.
norm – Apply norm layer to compatible intermediates.
stop_early – Stop iterating over blocks when last desired intermediate hit.
output_fmt – Shape of intermediate feature outputs.
intermediates_only – Only return intermediate features.
- Returns:
List of intermediate features or tuple of (final features, intermediates).
- get_classifier() → Module
Get the classifier module.
- group_matcher(
- coarse: bool = False,
Create regex patterns for parameter grouping.
- Parameters:
coarse – Use coarse grouping.
- Returns:
Dictionary mapping group names to regex patterns.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.crossvit.CrossVit(
- img_size: int = 224,
- img_scale: Tuple[float, ...] = (1.0, 1.0),
- patch_size: Tuple[int, ...] = (8, 16),
- in_chans: int = 3,
- num_classes: int = 1000,
- embed_dim: Tuple[int, ...] = (192, 384),
- depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
- num_heads: Tuple[int, ...] = (6, 12),
- mlp_ratio: Tuple[float, ...] = (2.0, 2.0, 4.0),
- multi_conv: bool = False,
- crop_scale: bool = False,
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
- global_pool: str = 'token',
- device=None,
- dtype=None,
Vision Transformer with support for patch or hybrid CNN input stage
- class timm.models.cspnet.CspNet(
- cfg: CspModelCfg,
- in_chans: int = 3,
- num_classes: int = 1000,
- output_stride: int = 32,
- global_pool: str = 'avg',
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- zero_init_last: bool = True,
- device=None,
- dtype=None,
- **kwargs,
Cross Stage Partial base model.
Paper: CSPNet: A New Backbone that can Enhance Learning Capability of CNN - https://arxiv.org/abs/1911.11929
Reference implementation: https://github.com/WongKinYiu/CrossStagePartialNetworks
NOTE: This implementation handles the 1x1 ‘expansion’ conv differently from the Darknet implementation, for simplicity and fewer special cases.
- class timm.models.davit.DaVit(
- in_chans: int = 3,
- depths: Tuple[int, ...] = (1, 1, 3, 1),
- embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
- num_heads: Tuple[int, ...] = (3, 6, 12, 24),
- window_size: int = 7,
- mlp_ratio: float = 4,
- qkv_bias: bool = True,
- norm_layer: str = 'layernorm2d',
- norm_layer_cl: str = 'layernorm',
- norm_eps: float = 1e-05,
- attn_types: Tuple[str, ...] = ('spatial', 'channel'),
- ffn: bool = True,
- cpe_act: bool = False,
- down_kernel_size: int = 2,
- channel_attn_v2: bool = False,
- named_blocks: bool = False,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- head_norm_first: bool = False,
- device=None,
- dtype=None,
- DaViT
A PyTorch implementation of DaViT: Dual Attention Vision Transformers - https://arxiv.org/abs/2204.03645
Supports arbitrary input sizes and pyramid feature extraction.
- Parameters:
in_chans (int) – Number of input image channels. Default: 3
num_classes (int) – Number of classes for classification head. Default: 1000
depths (tuple(int)) – Number of blocks in each stage. Default: (1, 1, 3, 1)
embed_dims (tuple(int)) – Patch embedding dimension. Default: (96, 192, 384, 768)
num_heads (tuple(int)) – Number of attention heads in different layers. Default: (3, 6, 12, 24)
window_size (int) – Window size. Default: 7
mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool) – If True, add a learnable bias to query, key, value. Default: True
drop_path_rate (float) – Stochastic depth rate. Default: 0.0
norm_layer (str) – Normalization layer. Default: 'layernorm2d'
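With the defaults listed above, every stage ends up with the same per-head width. This sketch assumes the usual multi-head convention of splitting each stage's embedding dimension evenly across its heads; davit_head_dims is a hypothetical helper, and the window arithmetic assumes spatial attention is computed within window_size x window_size windows.

```python
from typing import List, Sequence


def davit_head_dims(
    embed_dims: Sequence[int] = (96, 192, 384, 768),
    num_heads: Sequence[int] = (3, 6, 12, 24),
) -> List[int]:
    """Per-head channel count in each stage (embed_dim split across heads)."""
    head_dims = []
    for dim, heads in zip(embed_dims, num_heads):
        if dim % heads:
            raise ValueError(f"embed dim {dim} is not divisible by {heads} heads")
        head_dims.append(dim // heads)
    return head_dims


print(davit_head_dims())  # [32, 32, 32, 32]
window_size = 7
print(window_size * window_size, "tokens per attention window")  # 49
```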
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
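- Returns:
List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).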
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.deit.VisionTransformerDistilled(*args, **kwargs)
Vision Transformer w/ Distillation Token and Head
- Distillation token & head support for DeiT: Data-efficient Image Transformers
- forward_head(
- x,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier.
- Returns:
Output tensor.
- get_classifier() → Module
Get the classifier head.
- group_matcher(coarse=False)
Create regex patterns for parameter grouping.
- Parameters:
coarse – Use coarse grouping.
- Returns:
Dictionary mapping group names to regex patterns.
- init_weights(mode='', needs_reset=True)
Initialize model weights.
- Parameters:
mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).
needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- class timm.models.densenet.DenseNet(
- growth_rate: int = 32,
- block_config: Tuple[int, ...] = (6, 12, 24, 16),
- num_classes: int = 1000,
- in_chans: int = 3,
- global_pool: str = 'avg',
- bn_size: int = 4,
- stem_type: str = '',
- act_layer: str = 'relu',
- norm_layer: str = 'batchnorm2d',
- aa_layer: Type[Module] | None = None,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- memory_efficient: bool = False,
- aa_stem_only: bool = True,
- device=None,
- dtype=None,
DenseNet-BC model class.
Based on “Densely Connected Convolutional Networks”
- Parameters:
growth_rate – How many filters to add each layer (k in paper).
block_config – How many layers in each pooling block.
bn_size – Multiplicative factor for the width of the bottleneck layers (i.e. bn_size * k features in each bottleneck layer).
drop_rate – Dropout rate before classifier layer.
proj_drop_rate – Dropout rate after each dense layer.
num_classes – Number of classification classes.
memory_efficient – If True, uses checkpointing. Much more memory efficient, but slower. Default: False. See “paper”.
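The interplay of growth_rate, block_config and bn_size can be traced by hand. This sketch assumes the stem outputs 2 * growth_rate channels, each dense layer adds growth_rate channels, and each transition layer between blocks halves the channel count (the "compression" of DenseNet-BC); densenet_channels is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def densenet_channels(
    growth_rate: int = 32,
    block_config: Sequence[int] = (6, 12, 24, 16),
    bn_size: int = 4,
) -> Tuple[List[int], int]:
    """Channels after each dense block, and the width of each bottleneck layer."""
    num_features = 2 * growth_rate  # assumed stem width
    after_blocks = []
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate  # each dense layer adds k channels
        after_blocks.append(num_features)
        if i != len(block_config) - 1:
            num_features //= 2  # transition layer halves the channels
    return after_blocks, bn_size * growth_rate


print(densenet_channels())  # ([256, 512, 1024, 1024], 128)
```

For the defaults (the DenseNet-121 layout), the classifier therefore sees 1024 features, and every 1x1 bottleneck convolution outputs 4 * 32 = 128 channels.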
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier.
- Returns:
Output tensor.
- get_classifier() → Module
Get the classifier head.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Group parameters for optimization.
- reset_classifier(num_classes: int, global_pool: str = 'avg') → None
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- class timm.models.dla.DLA(
- levels: Tuple[int, ...],
- channels: Tuple[int, ...],
- output_stride: int = 32,
- num_classes: int = 1000,
- in_chans: int = 3,
- global_pool: str = 'avg',
- cardinality: int = 1,
- base_width: int = 64,
- block: Type[Module] = <class 'timm.models.dla.DlaBottle2neck'>,
- shortcut_root: bool = False,
- drop_rate: float = 0.0,
- device=None,
- dtype=None,
- class timm.models.dpn.DPN(
- k_sec: Tuple[int, ...] = (3, 4, 20, 3),
- inc_sec: Tuple[int, ...] = (16, 32, 24, 128),
- k_r: int = 96,
- groups: int = 32,
- num_classes: int = 1000,
- in_chans: int = 3,
- output_stride: int = 32,
- global_pool: str = 'avg',
- small: bool = False,
- num_init_features: int = 64,
- b: bool = False,
- drop_rate: float = 0.0,
- norm_layer: str = 'batchnorm2d',
- act_layer: str = 'relu',
- fc_act_layer: str = 'elu',
- device=None,
- dtype=None,
- class timm.models.edgenext.EdgeNeXt(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- dims: Tuple[int, ...] = (24, 48, 88, 168),
- depths: Tuple[int, ...] = (3, 3, 9, 3),
- global_block_counts: Tuple[int, ...] = (0, 1, 1, 1),
- kernel_sizes: Tuple[int, ...] = (3, 5, 7, 9),
- heads: Tuple[int, ...] = (8, 8, 8, 8),
- d2_scales: Tuple[int, ...] = (2, 2, 3, 4),
- use_pos_emb: Tuple[bool, ...] = (False, True, False, False),
- ls_init_value: float = 1e-06,
- head_init_scale: float = 1.0,
- expand_ratio: float = 4,
- downsample_block: bool = False,
- conv_bias: bool = True,
- stem_type: str = 'patch',
- head_norm_first: bool = False,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- drop_path_rate: float = 0.0,
- drop_rate: float = 0.0,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
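- Returns:
List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).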
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.efficientformer.EfficientFormer(
- depths: Tuple[int, ...] = (3, 2, 6, 4),
- embed_dims: Tuple[int, ...] = (48, 96, 224, 448),
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- downsamples: Tuple[bool, ...] | None = None,
- num_vit: int = 0,
- mlp_ratios: float = 4,
- pool_size: int = 3,
- layer_scale_init_value: float = 1e-05,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
- norm_layer_cl: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
- **kwargs,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
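- Returns:
List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).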
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.efficientformer_v2.EfficientFormerV2(
- depths: Tuple[int, ...],
- in_chans: int = 3,
- img_size: int | Tuple[int, int] = 224,
- global_pool: str = 'avg',
- embed_dims: Tuple[int, ...] | None = None,
- downsamples: Tuple[bool, ...] | None = None,
- mlp_ratios: float | Tuple[float, ...] | Tuple[Tuple[float, ...], ...] = 4,
- norm_layer: str = 'batchnorm2d',
- norm_eps: float = 1e-05,
- act_layer: str = 'gelu',
- num_classes: int = 1000,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- layer_scale_init_value: float | None = 1e-05,
- num_vit: int = 0,
- distillation: bool = True,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
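- Returns:
List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).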
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.efficientnet.EfficientNet(
- block_args: List[List[Dict[str, Any]]],
- num_classes: int = 1000,
- num_features: int = 1280,
- in_chans: int = 3,
- stem_size: int = 32,
- stem_kernel_size: int = 3,
- fix_stem: bool = False,
- output_stride: int = 32,
- pad_type: str = '',
- act_layer: str | Callable | Type[Module] | None = None,
- norm_layer: str | Callable | Type[Module] | None = None,
- aa_layer: str | Callable | Type[Module] | None = None,
- se_layer: str | Callable | Type[Module] | None = None,
- round_chs_fn: Callable = <function round_channels>,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- global_pool: str = 'avg',
- device=None,
- dtype=None,
EfficientNet model architecture.
- A flexible and performant PyTorch implementation of efficient network architectures, including:
EfficientNet-V2 Small, Medium, Large, XL & B0-B3
EfficientNet B0-B8, L2
EfficientNet-EdgeTPU
EfficientNet-CondConv
MixNet S, M, L, XL
MnasNet A1, B1, and small
MobileNet-V2
FBNet C
Single-Path NAS Pixel1
TinyNet
References
EfficientNet: https://arxiv.org/abs/1905.11946
EfficientNetV2: https://arxiv.org/abs/2104.00298
MixNet: https://arxiv.org/abs/1907.09595
MnasNet: https://arxiv.org/abs/1807.11626
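The default round_chs_fn=round_channels scales channel counts by a width multiplier (as used by the wider variants listed above) and snaps them to multiples of a divisor. The sketch below assumes the common "make divisible" rule with a divisor of 8 and a 90% round-down limit; it illustrates that rule and is not a copy of timm's function.

```python
from typing import Optional


def make_divisible(value: float, divisor: int = 8, min_value: Optional[int] = None,
                   round_limit: float = 0.9) -> int:
    """Round to the nearest multiple of divisor without losing more than 10%."""
    min_value = min_value or divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    if new_value < round_limit * value:
        new_value += divisor  # rounding down removed too many channels
    return new_value


def round_channels(channels: int, multiplier: float = 1.0, divisor: int = 8) -> int:
    """Scale a channel count by a width multiplier, then make it divisible."""
    if not multiplier:
        return channels
    return make_divisible(channels * multiplier, divisor)


print(round_channels(32))       # 32
print(round_channels(40, 1.1))  # 48 (40 * 1.1 = 44, rounded to a multiple of 8)
print(round_channels(16, 1.4))  # 24 (16 * 1.4 = 22.4)
```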
- as_sequential() → Sequential
Convert model to sequential for feature extraction.
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier.
- Returns:
Output tensor.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- extra_blocks: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor.
indices – Take last n blocks if int, all if None, select matching indices if sequence.
norm – Apply norm layer to compatible intermediates.
stop_early – Stop iterating over blocks when last desired intermediate hit.
output_fmt – Shape of intermediate feature outputs.
intermediates_only – Only return intermediate features.
extra_blocks – Include the outputs of all blocks and the head conv in the output; these do not align with feature_info.
- Returns:
List of intermediate features or tuple of (final features, intermediates).
- get_classifier() → Module
Get the classifier module.
- group_matcher(
- coarse: bool = False,
Create regex patterns for parameter groups.
- Parameters:
coarse – Use coarse (stage-level) grouping.
- Returns:
Dictionary mapping group names to regex patterns.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
- extra_blocks: bool = False,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layers.
prune_head – Whether to prune the classifier head.
extra_blocks – Include all blocks in indexing.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str = 'avg',
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.efficientvit_mit.EfficientVit(
- in_chans: int = 3,
- widths: Tuple[int, ...] = (),
- depths: Tuple[int, ...] = (),
- head_dim: int = 32,
- expand_ratio: float = 4,
- norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.Hardswish'>,
- global_pool: str = 'avg',
- head_widths: Tuple[int, ...] = (),
- drop_rate: float = 0.0,
- num_classes: int = 1000,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward pass that returns intermediate features.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
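- Returns:
List of intermediate features if intermediates_only is True, otherwise a tuple of (final features, intermediates).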
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.efficientvit_msra.EfficientVitMsra(
- img_size: int = 224,
- in_chans: int = 3,
- num_classes: int = 1000,
- embed_dim: Tuple[int, ...] = (64, 128, 192),
- key_dim: Tuple[int, ...] = (16, 16, 16),
- depth: Tuple[int, ...] = (1, 2, 3),
- num_heads: Tuple[int, ...] = (4, 4, 4),
- window_size: Tuple[int, ...] = (7, 7, 7),
- kernels: Tuple[int, ...] = (5, 5, 5, 5),
- down_ops: Tuple[Tuple[str, int], ...] = (('', 1), ('subsample', 2), ('subsample', 2)),
- global_pool: str = 'avg',
- drop_rate: float = 0.0,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.eva.Eva(
- img_size: int | Tuple[int, int] = 224,
- patch_size: int | Tuple[int, int] = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- qkv_bias: bool = True,
- qkv_fused: bool = True,
- mlp_ratio: float = 4.0,
- swiglu_mlp: bool = False,
- swiglu_align_to: int = 0,
- scale_mlp: bool = False,
- scale_attn_inner: bool = False,
- attn_type: str = 'eva',
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- patch_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>,
- init_values: float | None = None,
- class_token: bool = True,
- num_reg_tokens: int = 0,
- no_embed_class: bool = False,
- use_abs_pos_emb: bool = True,
- use_rot_pos_emb: bool = False,
- rope_type: str | None = 'cat',
- rope_grid_offset: float = 0.0,
- rope_grid_indexing: str = 'ij',
- rope_temperature: float = 10000.0,
- rope_rotate_half: bool = False,
- use_post_norm: bool = False,
- use_pre_transformer_norm: bool = False,
- use_post_transformer_norm: bool | None = None,
- use_fc_norm: bool | None = None,
- attn_pool_num_heads: int | None = None,
- attn_pool_mlp_ratio: float | None = None,
- dynamic_img_size: bool = False,
- dynamic_img_pad: bool = False,
- ref_feat_shape: int | Tuple[int, int] | None = None,
- head_init_scale: float = 0.001,
- device=None,
- dtype=None,
Eva Vision Transformer with absolute and rotary position embeddings.
- This class implements the EVA and EVA02 models, which are based on the BEiT ViT variant:
EVA - absolute position embedding, global average pooling
EVA02 - absolute + rotary (RoPE) position embedding, global average pooling, SwiGLU, scale norm in the MLP (à la NormFormer)
- fix_init_weight() None
Fix initialization weights by rescaling based on layer depth.
- forward_features(
- x: Tensor,
- attn_mask: Tensor | None = None,
- is_causal: bool = False,
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
attn_mask – Optional attention mask for masked attention
is_causal – If True, use causal (autoregressive) masking in attention.
- Returns:
Feature tensor.
- forward_head(x: Tensor, pre_logits: bool = False) Tensor
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return pre-logits if True.
- Returns:
Output tensor.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- return_prefix_tokens: bool = False,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- attn_mask: Tensor | None = None,
- is_causal: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if an int; if a sequence, select by matching indices
return_prefix_tokens – Return both prefix and spatial intermediate tokens
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
attn_mask – Optional attention mask for masked attention
is_causal – If True, use causal (autoregressive) masking in attention
- group_matcher(coarse: bool = False) Dict[str, Any]
Create layer groupings for optimization.
- no_weight_decay() Set[str]
Parameters to exclude from weight decay.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- reset_classifier(num_classes: int, global_pool: str | None = None) None
Reset the classifier head.
- Parameters:
num_classes – Number of output classes.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) None
Enable or disable gradient checkpointing.
- set_input_size(
- img_size: Tuple[int, int] | None = None,
- patch_size: Tuple[int, int] | None = None,
Update the input image resolution and patch size.
- Parameters:
img_size – New input resolution, if None current resolution is used.
patch_size – New patch size, if None existing patch size is used.
- class timm.models.fasternet.FasterNet(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: int = 96,
- depths: int | Tuple[int, ...] = (1, 2, 8, 2),
- mlp_ratio: float = 2.0,
- n_div: int = 4,
- patch_size: int | Tuple[int, int] = 4,
- merge_size: int | Tuple[int, int] = 2,
- patch_norm: bool = True,
- feature_dim: int = 1280,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.1,
- layer_scale_init_value: float = 0.0,
- act_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.activation.ReLU'>, inplace=True),
- norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
- pconv_fw_type: str = 'split_cat',
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.fastvit.FastVit(
- in_chans: int = 3,
- layers: Tuple[int, ...] = (2, 2, 6, 2),
- token_mixers: Tuple[str, ...] = ('repmixer', 'repmixer', 'repmixer', 'repmixer'),
- embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
- mlp_ratios: Tuple[float, ...] = (4, 4, 4, 4),
- downsamples: Tuple[bool, ...] = (False, True, True, True),
- se_downsamples: Tuple[bool, ...] = (False, False, False, False),
- repmixer_kernel_size: int = 3,
- num_classes: int = 1000,
- pos_embs: Tuple[Module | None, ...] = (None, None, None, None),
- down_patch_size: int = 7,
- down_stride: int = 2,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- layer_scale_init_value: float = 1e-05,
- lkc_use_act: bool = False,
- stem_use_scale_branch: bool = True,
- fork_feat: bool = False,
- cls_ratio: float = 2.0,
- global_pool: str = 'avg',
- norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- inference_mode: bool = False,
- device=None,
- dtype=None,
- fork_feat: Final[bool]
This class implements the FastViT architecture.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.focalnet.FocalNet(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: int = 96,
- depths: Tuple[int, ...] = (2, 2, 6, 2),
- mlp_ratio: float = 4.0,
- focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
- focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
- use_overlap_down: bool = False,
- use_post_norm: bool = False,
- use_post_norm_in_modulation: bool = False,
- normalize_modulator: bool = False,
- head_hidden_size: int | None = None,
- head_init_scale: float = 1.0,
- layerscale_value: float | None = None,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- drop_path_rate: float = 0.1,
- norm_layer: Type[Module] = functools.partial(<class 'timm.layers.norm.LayerNorm2d'>, eps=1e-05),
- device=None,
- dtype=None,
Focal Modulation Networks (FocalNets)
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.gcvit.GlobalContextVit(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- img_size: int | Tuple[int, int] = 224,
- window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
- window_size: int | Tuple[int, ...] | None = None,
- embed_dim: int = 64,
- depths: Tuple[int, ...] = (3, 4, 19, 5),
- num_heads: Tuple[int, ...] = (2, 4, 8, 16),
- mlp_ratio: float = 3.0,
- qkv_bias: bool = True,
- layer_scale: float | None = None,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- weight_init: str = '',
- act_layer: str = 'gelu',
- norm_layer: str = 'layernorm2d',
- norm_layer_cl: str = 'layernorm',
- norm_eps: float = 1e-05,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.ghostnet.GhostNet(
- cfgs: List[List[List[int | float]]],
- num_classes: int = 1000,
- width: float = 1.0,
- in_chans: int = 3,
- output_stride: int = 32,
- global_pool: str = 'avg',
- drop_rate: float = 0.2,
- version: str = 'v1',
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.hgnet.HighPerfGpuNet(
- cfg: Dict,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- head_hidden_size: int | None = 2048,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- use_lab: bool = False,
- device=None,
- dtype=None,
- **kwargs,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.hiera.Hiera(
- img_size: Tuple[int, ...] = (224, 224),
- in_chans: int = 3,
- embed_dim: int = 96,
- num_heads: int = 1,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- stages: Tuple[int, ...] = (2, 3, 16, 3),
- q_pool: int = 3,
- q_stride: Tuple[int, ...] = (2, 2),
- mask_unit_size: Tuple[int, ...] = (8, 8),
- mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
- use_expand_proj: bool = True,
- dim_mul: float = 2.0,
- head_mul: float = 2.0,
- patch_kernel: Tuple[int, ...] = (7, 7),
- patch_stride: Tuple[int, ...] = (4, 4),
- patch_padding: Tuple[int, ...] = (3, 3),
- mlp_ratio: float = 4.0,
- drop_path_rate: float = 0.0,
- init_values: float | None = None,
- fix_init: bool = True,
- weight_init: str = '',
- norm_layer: str | Type[Module] = 'LayerNorm',
- drop_rate: float = 0.0,
- patch_drop_rate: float = 0.0,
- head_init_scale: float = 0.001,
- sep_pos_embed: bool = False,
- abs_win_pos_embed: bool = False,
- global_pos_size: Tuple[int, int] = (14, 14),
- device=None,
- dtype=None,
- forward_features(
- x: Tensor,
- mask: Tensor | None = None,
- return_intermediates: bool = False,
mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx], where #MU is the number of mask units in that dimension. Note: 1 in the mask means keep and 0 means remove; mask.sum(dim=-1) should be the same across the batch.
- forward_intermediates(
- x: Tensor,
- mask: Tensor | None = None,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = True,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- coarse: bool = True,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- get_random_mask(x: Tensor, mask_ratio: float) Tensor
Generates a random mask in which a mask_ratio fraction of the mask units is dropped (1 means keep, 0 means remove). Useful for MAE, FLIP, etc.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
- coarse: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.hrnet.HighResolutionNet(
- cfg: Dict,
- in_chans: int = 3,
- num_classes: int = 1000,
- output_stride: int = 32,
- global_pool: str = 'avg',
- drop_rate: float = 0.0,
- head: str = 'classification',
- device=None,
- dtype=None,
- **kwargs,
- class timm.models.inception_next.MetaNeXt
A PyTorch implementation of InceptionNeXt: When Inception Meets ConvNeXt - https://arxiv.org/abs/2303.16900
- Parameters:
in_chans (int) – Number of input image channels. Default: 3
num_classes (int) – Number of classes for classification head. Default: 1000
depths (tuple(int)) – Number of blocks at each stage. Default: (3, 3, 9, 3)
dims (tuple(int)) – Feature dimension at each stage. Default: (96, 192, 384, 768)
token_mixers – Token mixer function. Default: nn.Identity
norm_layer – Normalization layer. Default: nn.BatchNorm2d
act_layer – Activation function for MLP. Default: nn.GELU
mlp_ratios (int or tuple(int)) – MLP ratios. Default: (4, 4, 4, 3)
drop_rate (float) – Head dropout rate
drop_path_rate (float) – Stochastic depth rate. Default: 0.
ls_init_value (float) – Init value for Layer Scale. Default: 1e-6.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.inception_resnet_v2.InceptionResnetV2(
- num_classes: int = 1000,
- in_chans: int = 3,
- drop_rate: float = 0.0,
- output_stride: int = 32,
- global_pool: str = 'avg',
- norm_layer: str = 'batchnorm2d',
- norm_eps: float = 0.001,
- act_layer: str = 'relu',
- device=None,
- dtype=None,
- class timm.models.inception_v3.InceptionV3(
- num_classes: int = 1000,
- in_chans: int = 3,
- drop_rate: float = 0.0,
- global_pool: str = 'avg',
- aux_logits: bool = False,
- norm_layer: str = 'batchnorm2d',
- norm_eps: float = 0.001,
- act_layer: str = 'relu',
- device=None,
- dtype=None,
Inception-V3
- class timm.models.inception_v4.InceptionV4(
- num_classes: int = 1000,
- in_chans: int = 3,
- output_stride: int = 32,
- drop_rate: float = 0.0,
- global_pool: str = 'avg',
- norm_layer: str = 'batchnorm2d',
- norm_eps: float = 0.001,
- act_layer: str = 'relu',
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.levit.Levit(
- img_size: int | Tuple[int, int] = 224,
- in_chans: int = 3,
- num_classes: int = 1000,
- embed_dim: Tuple[int, ...] = (192,),
- key_dim: int = 64,
- depth: Tuple[int, ...] = (12,),
- num_heads: int | Tuple[int, ...] = (3,),
- attn_ratio: float | Tuple[float, ...] = 2.0,
- mlp_ratio: float | Tuple[float, ...] = 2.0,
- stem_backbone: Module | None = None,
- stem_stride: int | None = None,
- stem_type: str = 's16',
- down_op: str = 'subsample',
- act_layer: str = 'hard_swish',
- attn_act_layer: str | None = None,
- use_conv: bool = False,
- global_pool: str = 'avg',
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
Vision Transformer with support for patch or hybrid CNN input stage
NOTE: distillation defaults to True because the pretrained weights use it; this will cause problems with training scripts that do not accept tuple outputs.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.maxxvit.MaxxVitCfg(
- embed_dim: Tuple[int, ...] = (96, 192, 384, 768),
- depths: Tuple[int, ...] = (2, 3, 5, 2),
- block_type: Tuple[str | Tuple[str, ...], ...] = ('C', 'C', 'T', 'T'),
- stem_width: int | Tuple[int, int] = 64,
- stem_bias: bool = False,
- conv_cfg: timm.models.maxxvit.MaxxVitConvCfg = <factory>,
- transformer_cfg: timm.models.maxxvit.MaxxVitTransformerCfg = <factory>,
- head_hidden_size: int | None = None,
- weight_init: str = 'vit_eff',
Configuration for MaxxVit models.
- class timm.models.metaformer.MetaFormer
- A PyTorch implementation of MetaFormer Baselines for Vision.
- Parameters:
in_chans (int) – Number of input image channels.
num_classes (int) – Number of classes for classification head.
global_pool – Pooling for classifier head.
depths (list or tuple) – Number of blocks at each stage.
dims (list or tuple) – Feature dimension at each stage.
token_mixers (list, tuple or token_fcn) – Token mixer for each stage.
mlp_act – Activation layer for MLP.
mlp_bias (boolean) – Enable or disable mlp bias term.
drop_path_rate (float) – Stochastic depth rate.
drop_rate (float) – Dropout rate.
layer_scale_init_values (list, tuple, float or None) – Init value for Layer Scale. None means layer scale is not used. From: https://arxiv.org/abs/2103.17239.
res_scale_init_values (list, tuple, float or None) – Init value for residual scale on residual connections. None means residual scale is not used. From: https://arxiv.org/abs/2110.09456.
downsample_norm (nn.Module) – Norm layer used in stem and downsampling layers.
norm_layers (list, tuple or norm_fcn) – Norm layers for each stage.
output_norm – Norm layer before classifier head.
use_mlp_head – Use MLP classification head.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.mobilenetv3.MobileNetV3(
- block_args: List[List[Dict[str, Any]]],
- num_classes: int = 1000,
- in_chans: int = 3,
- stem_size: int = 16,
- fix_stem: bool = False,
- num_features: int = 1280,
- head_bias: bool = True,
- head_norm: bool = False,
- pad_type: str = '',
- act_layer: str | Callable | Type[Module] | None = None,
- norm_layer: str | Callable | Type[Module] | None = None,
- aa_layer: str | Callable | Type[Module] | None = None,
- se_layer: str | Callable | Type[Module] | None = None,
- se_from_exp: bool = True,
- round_chs_fn: Callable = <function round_channels>,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- layer_scale_init_value: float | None = None,
- global_pool: str = 'avg',
- device=None,
- dtype=None,
MobileNetV3.
Based on the timm EfficientNet implementation and building blocks, this model uses the MobileNet-V3-specific ‘efficient head’, where global pooling is done before the head convolution and there is no final batch-norm layer before the classifier.
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
- Other architectures using the MobileNet-V3 efficient head that are supported by this implementation include:
HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (the definition in hardcorenas.py uses this class)
FBNet-V3 - https://arxiv.org/abs/2006.02049
LCNet - https://arxiv.org/abs/2109.15099
MobileNet-V4 - https://arxiv.org/abs/2404.10518
- as_sequential() Sequential
Convert model to sequential form.
- Returns:
Sequential module containing all layers.
- forward_features(x: Tensor) Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- extra_blocks: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
extra_blocks – Include outputs of all blocks and head conv in output, does not align with feature_info
Returns:
- get_classifier() Module
Get the classifier head.
- group_matcher(
- coarse: bool = False,
Group parameters for optimization.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
- extra_blocks: bool = False,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
extra_blocks – Include outputs of all blocks.
- Returns:
List of indices that were kept.
- reset_classifier(num_classes: int, global_pool: str = 'avg') None
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) None
Enable or disable gradient checkpointing.
- class timm.models.mobilenetv5.MobileNetV5(
- block_args: List[List[Dict[str, Any]]],
- num_classes: int = 1000,
- in_chans: int = 3,
- stem_size: int = 16,
- stem_bias: bool = True,
- fix_stem: bool = False,
- num_features: int = 2048,
- pad_type: str = '',
- use_msfa: bool = True,
- msfa_indices: List[int] = (-2, -1),
- msfa_output_resolution: int = 16,
- act_layer: str | Callable | Type[Module] | None = None,
- norm_layer: str | Callable | Type[Module] | None = None,
- aa_layer: str | Callable | Type[Module] | None = None,
- se_layer: str | Callable | Type[Module] | None = None,
- se_from_exp: bool = True,
- round_chs_fn: Callable = <function round_channels>,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- layer_scale_init_value: float | None = None,
- global_pool: str = 'avg',
- device=None,
- dtype=None,
MobileNet-V5
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- extra_blocks: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
extra_blocks – Include outputs of all blocks and head conv in output, does not align with feature_info
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
- extra_blocks: bool = False,
Prune layers not required for specified intermediates.
- class timm.models.mvitv2.MultiScaleVit(
- cfg: MultiScaleVitCfg,
- img_size: Tuple[int, int] = (224, 224),
- in_chans: int = 3,
- global_pool: str | None = None,
- num_classes: int = 1000,
- drop_path_rate: float = 0.0,
- drop_rate: float = 0.0,
- device=None,
- dtype=None,
Improved Multiscale Vision Transformers for Classification and Detection
Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik, Christoph Feichtenhofer*
https://arxiv.org/abs/2112.01526
Multiscale Vision Transformers
Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik, Christoph Feichtenhofer*
https://arxiv.org/abs/2104.11227
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.naflexvit.NaFlexVitCfg(
- patch_size: int | Tuple[int, int] = 16,
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- scale_mlp_norm: bool = False,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- proj_bias: bool = True,
- attn_drop_rate: float = 0.0,
- scale_attn_inner_norm: bool = False,
- init_values: float | None = None,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- patch_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- class_token: bool = False,
- reg_tokens: int = 0,
- pos_embed: str = 'learned',
- pos_embed_grid_size: Tuple[int, int] | None = (16, 16),
- pos_embed_interp_mode: str = 'bicubic',
- pos_embed_ar_preserving: bool = False,
- pos_embed_use_grid_sample: bool = False,
- rope_type: str = '',
- rope_temperature: float = 10000.0,
- rope_ref_feat_shape: Tuple[int, int] | None = None,
- rope_grid_offset: float = 0.0,
- rope_grid_indexing: str = 'ij',
- rope_rotate_half: bool = False,
- dynamic_img_pad: bool = False,
- pre_norm: bool = False,
- final_norm: bool = True,
- fc_norm: bool | None = None,
- global_pool: str = 'map',
- pool_include_prefix: bool = False,
- attn_pool_num_heads: int | None = None,
- attn_pool_mlp_ratio: float | None = None,
- weight_init: str = '',
- fix_init: bool = True,
- embed_proj_type: str = 'linear',
- input_norm_layer: str | None = None,
- embed_norm_layer: str | None = None,
- norm_layer: str | None = None,
- act_layer: str | None = None,
- block_fn: str | None = None,
- mlp_layer: str | None = None,
- attn_layer: str | None = None,
- attn_type: str = 'standard',
- swiglu_mlp: bool = False,
- qkv_fused: bool = True,
- enable_patch_interpolator: bool = False,
Configuration for FlexVit model.
This dataclass contains the bulk of model configuration parameters, with core parameters (img_size, in_chans, num_classes, etc.) remaining as direct constructor arguments for API compatibility.
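The split described above is a common Python pattern: architecture settings live in a config dataclass, while core arguments such as img_size, in_chans and num_classes stay on the constructor. The sketch below illustrates the pattern with a hypothetical `TinyVitCfg` that holds a small subset of the NaFlexVitCfg fields. It is not the timm class.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyVitCfg:
    # Hypothetical subset of architecture fields; defaults copied from the
    # NaFlexVitCfg signature above.
    patch_size: int = 16
    embed_dim: int = 768
    depth: int = 12
    num_heads: int = 12
    mlp_ratio: float = 4.0
    global_pool: str = "map"


class TinyVit:
    """Hypothetical model: core parameters remain direct constructor args."""

    def __init__(self, cfg: TinyVitCfg = TinyVitCfg(), img_size: int = 224,
                 in_chans: int = 3, num_classes: int = 1000):
        if cfg.embed_dim % cfg.num_heads:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.cfg = cfg
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.grid = img_size // cfg.patch_size  # patches per side
        self.head_dim = cfg.embed_dim // cfg.num_heads


# Derive a smaller variant by overriding only the fields that change.
small_cfg = replace(TinyVitCfg(), embed_dim=384, num_heads=6)
model = TinyVit(small_cfg, img_size=256, num_classes=2)
```

Freezing the dataclass lets you share one default instance safely. `dataclasses.replace` then produces variants without mutating it.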
- class timm.models.nasnet.NASNetALarge(6 @ 4032)
- class timm.models.nest.Nest(
- img_size: int = 224,
- in_chans: int = 3,
- patch_size: int = 4,
- num_levels: int = 3,
- embed_dims: Tuple[int, ...] = (128, 256, 512),
- num_heads: Tuple[int, ...] = (4, 8, 16),
- depths: Tuple[int, ...] = (2, 2, 20),
- num_classes: int = 1000,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.5,
- norm_layer: Type[Module] | None = None,
- act_layer: Type[Module] | None = None,
- pad_type: str = '',
- weight_init: str = '',
- global_pool: str = 'avg',
- device=None,
- dtype=None,
Nested Transformer (NesT)
- A PyTorch implementation of ‘Aggregating Nested Transformers’
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
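With the defaults above (img_size=224, patch_size=4, num_levels=3), the patch-token grid is 56 x 56. The sketch assumes the NesT scheme in which each level merges 2 x 2 neighbouring blocks. Under that scheme, the block count falls by a factor of 4 per level while the number of tokens per block stays fixed. `nest_layout` is a hypothetical helper, not part of timm.

```python
import math
from typing import List, Tuple


def nest_layout(img_size: int = 224, patch_size: int = 4,
                num_levels: int = 3) -> Tuple[List[int], int]:
    """Return (blocks per level, tokens per block) for a NesT-style hierarchy."""
    grid = img_size // patch_size  # patch tokens per side at the first level
    # Assumed: 4x fewer blocks per level, ending with one block at the top.
    num_blocks = [4 ** (num_levels - 1 - i) for i in range(num_levels)]
    blocks_per_side = math.isqrt(num_blocks[0])
    if grid % blocks_per_side:
        raise ValueError("token grid must split evenly into blocks")
    # Each merge also halves the grid, so tokens per block are level-invariant.
    tokens_per_block = (grid // blocks_per_side) ** 2
    return num_blocks, tokens_per_block


print(nest_layout())
```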
- class timm.models.nextvit.NextViT(
- in_chans: int,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- stem_chs: Tuple[int, ...] = (64, 32, 64),
- depths: Tuple[int, ...] = (3, 4, 10, 3),
- strides: Tuple[int, ...] = (1, 2, 2, 2),
- sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
- drop_path_rate: float = 0.1,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- head_dim: int = 32,
- mix_block_ratio: float = 0.75,
- norm_layer: Type[Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>,
- act_layer: Type[Module] | None = None,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.nfnet.NormFreeNet(
- cfg: NfCfg,
- num_classes: int = 1000,
- in_chans: int = 3,
- global_pool: str = 'avg',
- output_stride: int = 32,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
- **kwargs: Any,
Normalization-Free Network
As described in ‘Characterizing signal propagation to close the performance gap in unnormalized ResNets’ and in ‘High-Performance Large-Scale Image Recognition Without Normalization’ (https://arxiv.org/abs/2102.06171).
This model aims to cover both the NFRegNet-Bx models as detailed in the paper’s code snippets and the (preact) ResNet models described earlier in the paper.
- There are a few differences:
- channels are rounded to be divisible by 8 by default (to keep tensor core kernels happy); this changes channel dims and param counts slightly from the paper models
- activation-correcting gamma constants are moved into the ScaledStdConv, as this has less performance impact in PyTorch when done together with the weight scaling there. This likely wasn’t a concern in the JAX impl.
- a config option gamma_in_act can be enabled to apply gamma in each activation instead of in StdConv as described above. This is slightly slower and numerically different, but matches the official impl.
- skipinit is disabled by default; it seems to have a rather drastic impact on GPU memory use and throughput for what it is/does (approx. 8-10% throughput loss).
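The first difference, rounding channels to a multiple of 8, can be sketched in a few lines. This follows the common `make_divisible` idiom used in many PyTorch model codebases. It is an illustration under that assumption, not the exact timm code path. `round_channels` and `paper_widths` are hypothetical names.

```python
from typing import List, Optional


def round_channels(channels: float, divisor: int = 8,
                   min_value: Optional[int] = None,
                   round_limit: float = 0.9) -> int:
    """Round a channel count to the nearest multiple of `divisor`.

    The result never drops below `min_value` (default: `divisor`). If rounding
    down would remove more than (1 - round_limit) of the channels, one extra
    `divisor` is added back.
    """
    min_value = min_value or divisor
    rounded = max(min_value, int(channels + divisor / 2) // divisor * divisor)
    if rounded < round_limit * channels:
        rounded += divisor
    return rounded


paper_widths: List[float] = [24, 100, 260, 3]  # hypothetical unrounded widths
rounded_widths = [round_channels(w) for w in paper_widths]
print(rounded_widths)
```

Small shifts such as 100 -> 104 explain why channel dims and parameter counts differ slightly from the paper models.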
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- get_classifier() → Module
Get the classifier head.
- group_matcher(
- coarse: bool = False,
Group parameters for optimization.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- class timm.models.pit.PoolingVisionTransformer(
- img_size: int = 224,
- patch_size: int = 16,
- stride: int = 8,
- stem_type: str = 'overlap',
- base_dims: Sequence[int] = (48, 48, 48),
- depth: Sequence[int] = (2, 6, 4),
- heads: Sequence[int] = (2, 4, 8),
- mlp_ratio: float = 4,
- num_classes: int = 1000,
- in_chans: int = 3,
- global_pool: str = 'token',
- distilled: bool = False,
- drop_rate: float = 0.0,
- pos_drop_drate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
Pooling-based Vision Transformer
- A PyTorch implementation of ‘Rethinking Spatial Dimensions of Vision Transformers’
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
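With patch_size=16 and stride=8, the stem produces overlapping patches, which is what stem_type='overlap' suggests. Each later stage then pools the token grid. The sketch below computes the grid side and channel width per stage. It rests on three assumptions: the patch embedding is an unpadded strided conv, the pooling layer is a 3x3 / stride-2 / padding-1 conv, and the stage width is `base_dims[i] * heads[i]`. `pit_layout` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def pit_layout(img_size: int = 224, patch_size: int = 16, stride: int = 8,
               base_dims: Sequence[int] = (48, 48, 48),
               heads: Sequence[int] = (2, 4, 8)) -> Tuple[List[int], List[int]]:
    """Return (token grid side per stage, channel width per stage)."""
    # Assumed: unpadded conv embedding, kernel=patch_size, given stride.
    side = (img_size - patch_size) // stride + 1
    sides = []
    for i in range(len(heads)):
        if i > 0:
            # Assumed pooling between stages: 3x3 conv, stride 2, padding 1.
            side = (side + 2 * 1 - 3) // 2 + 1
        sides.append(side)
    # Assumed: stage width = base dim per head * number of heads.
    dims = [b * h for b, h in zip(base_dims, heads)]
    return sides, dims


print(pit_layout())
```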
- class timm.models.pnasnet.PNASNet5Large(
- num_classes: int = 1000,
- in_chans: int = 3,
- output_stride: int = 32,
- drop_rate: float = 0.0,
- global_pool: str = 'avg',
- pad_type: str = '',
- device=None,
- dtype=None,
- class timm.models.pvt_v2.PyramidVisionTransformerV2(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- depths: Tuple[int, ...] = (3, 4, 6, 3),
- embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
- num_heads: Tuple[int, ...] = (1, 2, 4, 8),
- sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
- mlp_ratios: Tuple[float, ...] = (8.0, 8.0, 4.0, 4.0),
- qkv_bias: bool = True,
- linear: bool = False,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm'>,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
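The sr_ratios default (8, 4, 2, 1) controls spatial-reduction attention. In each stage, the keys and values are computed on a grid shrunk by the ratio along each axis, while queries keep the full grid. The sketch assumes a stride-4 first patch embedding followed by stride-2 embeddings between stages, and it ignores the `linear` option. `sra_token_counts` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def sra_token_counts(img_size: int = 224,
                     sr_ratios: Sequence[int] = (8, 4, 2, 1),
                     first_stride: int = 4) -> List[Tuple[int, int, int]]:
    """Per stage: (grid side, query tokens, key/value tokens after reduction)."""
    side = img_size // first_stride  # assumed stride-4 first patch embedding
    rows = []
    for i, r in enumerate(sr_ratios):
        if i > 0:
            side //= 2  # assumed stride-2 patch embedding between stages
        queries = side * side
        keys = (side // r) ** 2  # K/V grid shrunk by r along each axis
        rows.append((side, queries, keys))
    return rows


for side, q, kv in sra_token_counts():
    print(f"{side}x{side}: {q} queries attend to {kv} keys")
```

With the defaults, every stage attends to the same number of keys. This keeps the attention cost of the high-resolution early stages manageable.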
- class timm.models.rdnet.RDNet(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- growth_rates: List[int] | Tuple[int] = (64, 104, 128, 128, 128, 128, 224),
- num_blocks_list: List[int] | Tuple[int] = (3, 3, 3, 3, 3, 3, 3),
- block_type: List[int] | Tuple[int] = ('Block', 'Block', 'BlockESE', 'BlockESE', 'BlockESE', 'BlockESE', 'BlockESE'),
- is_downsample_block: List[bool] | Tuple[bool] = (None, True, True, False, False, False, True),
- bottleneck_width_ratio: float = 4.0,
- transition_compression_ratio: float = 0.5,
- ls_init_value: float = 1e-06,
- stem_type: str = 'patch',
- patch_size: int = 4,
- num_init_features: int = 64,
- head_init_scale: float = 1.0,
- head_norm_first: bool = False,
- conv_bias: bool = True,
- act_layer: str | Callable = 'gelu',
- norm_layer: str = 'layernorm2d',
- norm_eps: float | None = None,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.regnet.RegNet(
- cfg: RegNetCfg,
- in_chans: int = 3,
- num_classes: int = 1000,
- output_stride: int = 32,
- global_pool: str = 'avg',
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- zero_init_last: bool = True,
- device=None,
- dtype=None,
- **kwargs,
RegNet-X, Y, and Z Models.
Paper: https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(x: Tensor, pre_logits: bool = False) → Tensor
Forward pass through classifier head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- get_classifier() → Module
Get the classifier head.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Group parameters for optimization.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
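The cfg argument encodes the RegNet design-space parameters from the paper linked above. Per-block widths start from a linear rule u_j = w0 + wa * j. Each width is snapped to the nearest power of wm times w0 and rounded to a multiple of 8. Consecutive equal widths form one stage. The sketch below uses illustrative values (w0=24, wa=36.44, wm=2.49, depth=13), which are assumed to match the small RegNetX-200MF configuration. It omits the group-width adjustments that timm also applies. `regnet_widths` and `to_stages` are hypothetical helpers.

```python
import math
from itertools import groupby
from typing import List, Tuple


def regnet_widths(w0: int, wa: float, wm: float, depth: int, quant: int = 8) -> List[int]:
    """Per-block widths from the quantized linear rule of the RegNet paper."""
    widths = []
    for j in range(depth):
        u = w0 + wa * j                              # linear width u_j
        s = round(math.log(u / w0) / math.log(wm))   # nearest power of wm
        w = w0 * wm ** s
        widths.append(round(w / quant) * quant)      # multiple of quant
    return widths


def to_stages(widths: List[int]) -> Tuple[List[int], List[int]]:
    """Consecutive blocks with equal width form one stage."""
    runs = [(w, len(list(g))) for w, g in groupby(widths)]
    return [w for w, _ in runs], [n for _, n in runs]


stage_widths, stage_depths = to_stages(regnet_widths(w0=24, wa=36.44, wm=2.49, depth=13))
print(stage_widths, stage_depths)
```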
- class timm.models.repghost.RepGhostNet(
- cfgs: List[List[List]],
- num_classes: int = 1000,
- width: float = 1.0,
- in_chans: int = 3,
- output_stride: int = 32,
- global_pool: str = 'avg',
- drop_rate: float = 0.2,
- reparam: bool = True,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.repvit.RepVit(
- in_chans: int = 3,
- img_size: int = 224,
- embed_dim: Tuple[int, ...] = (48,),
- depth: Tuple[int, ...] = (2,),
- mlp_ratio: float = 2,
- global_pool: str = 'avg',
- kernel_size: int = 3,
- num_classes: int = 1000,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- distillation: bool = True,
- drop_rate: float = 0.0,
- legacy: bool = False,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.resnet.ResNet(block: ~timm.models.resnet.BasicBlock | ~timm.models.resnet.Bottleneck, layers: ~typing.Tuple[int, ...], num_classes: int = 1000, in_chans: int = 3, output_stride: int = 32, global_pool: str = 'avg', cardinality: int = 1, base_width: int = 64, stem_width: int = 64, stem_type: str = '', replace_stem_pool: bool = False, block_reduce_first: int = 1, down_kernel_size: int = 1, avg_down: bool = False, channels: ~typing.Tuple[int, ...] | None = (64, 128, 256, 512), act_layer: str | ~typing.Callable | ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU'>, norm_layer: str | ~typing.Callable | ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, aa_layer: ~typing.Type[~torch.nn.modules.module.Module] | None = None, drop_rate: float = 0.0, drop_path_rate: float = 0.0, drop_block_rate: float = 0.0, zero_init_last: bool = True, block_args: ~typing.Dict[str, ~typing.Any] | None = None, device=None, dtype=None)
ResNet / ResNeXt / SE-ResNeXt / SE-Net
- This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that have a stride > 1 in the 3x3 conv layer of the bottleneck and use conv-bn-act ordering.
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the ‘Bag of Tricks’ paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to the torchvision default.
- ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet ‘v1.5’, Gluon v1b
c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
- ResNeXt
normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
same c, d, e, s variants as ResNet can be enabled
- SE-ResNeXt
normal - 7x7 stem, stem_width = 64
same c, d, e, s variants as ResNet can be enabled
- SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
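The variant letters above roughly map onto three ResNet constructor arguments: `stem_width`, `stem_type`, and `avg_down`. The sketch below makes two assumptions that you should check against your installed versions. First, it assumes the stem-type strings `'deep'` and `'deep_tiered'`. Second, it assumes that EIR forwards extra `model_init_config` keys to the timm constructor unchanged. `RESNET_VARIANTS` and `model_init_config` are hypothetical names, and the `tn` variant is omitted.

```python
from typing import Any, Dict, Sequence

# Assumed mapping from variant letter to ResNet constructor arguments.
RESNET_VARIANTS: Dict[str, Dict[str, Any]] = {
    "b": dict(stem_width=64, stem_type="", avg_down=False),
    "c": dict(stem_width=32, stem_type="deep", avg_down=False),
    "d": dict(stem_width=32, stem_type="deep", avg_down=True),
    "e": dict(stem_width=64, stem_type="deep", avg_down=True),
    "s": dict(stem_width=64, stem_type="deep", avg_down=False),
    "t": dict(stem_width=32, stem_type="deep_tiered", avg_down=True),
}


def model_init_config(variant: str, layers: Sequence[int],
                      block: str = "BasicBlock") -> Dict[str, Any]:
    """Build a model_init_config-style dict for one ResNet variant."""
    if variant not in RESNET_VARIANTS:
        raise KeyError(f"unknown variant {variant!r}; choose from {sorted(RESNET_VARIANTS)}")
    return {"layers": list(layers), "block": block, **RESNET_VARIANTS[variant]}


cfg = model_init_config("d", layers=[1, 1, 1, 1])
print(cfg)
```

The resulting dict has the same shape as the `model_init_config` block in the YAML example at the top of this page, extended with the stem options.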
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- forward_head(x: Tensor, pre_logits: bool = False) → Tensor
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier layer.
- Returns:
Output tensor.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor.
indices – Take last n blocks if int, all if None, select matching indices if sequence.
norm – Apply norm layer to compatible intermediates.
stop_early – Stop iterating over blocks when last desired intermediate hit.
output_fmt – Shape of intermediate feature outputs.
intermediates_only – Only return intermediate features.
- Returns:
Features and list of intermediate features or just intermediate features.
- get_classifier(
- name_only: bool = False,
Get the classifier module.
- Parameters:
name_only – Return classifier module name instead of module.
- Returns:
Classifier module or name.
- group_matcher(coarse: bool = False) → Dict[str, str]
Create regex patterns for parameter grouping.
- Parameters:
coarse – Use coarse (stage-level) or fine (block-level) grouping.
- Returns:
Dictionary mapping group names to regex patterns.
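A common use of such regex groups is to assign different learning rates per stage or per block, as in layer-wise decay. The sketch below applies illustrative ResNet-style patterns to parameter names. The patterns are an assumption, not copied from timm; in practice you would take them from the model's own `group_matcher`. `group_params` is a hypothetical helper.

```python
import re
from typing import Dict, Iterable, List

STEM = r"^conv1|^bn1|^stem"
COARSE = {"stem": STEM, "blocks": r"^layer(\d+)"}           # one group per stage
FINE = {"stem": STEM, "blocks": r"^layer(\d+)\.(\d+)"}      # one group per block


def group_params(names: Iterable[str], patterns: Dict[str, str]) -> Dict[str, List[str]]:
    """Assign each parameter name to the first matching group."""
    groups: Dict[str, List[str]] = {}
    for name in names:
        for group, pattern in patterns.items():
            match = re.match(pattern, name)
            if match:
                # Captured stage/block numbers become part of the group key.
                key = ".".join([group, *match.groups()])
                groups.setdefault(key, []).append(name)
                break
        else:
            groups.setdefault("head", []).append(name)  # e.g. the fc classifier
    return groups


names = ["conv1.weight", "bn1.weight", "layer1.0.conv1.weight",
         "layer1.1.conv1.weight", "layer2.0.conv1.weight", "fc.weight"]
coarse = group_params(names, COARSE)
fine = group_params(names, FINE)
print(list(coarse), list(fine))
```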
- init_weights(zero_init_last: bool = True) → None
Initialize model weights.
- Parameters:
zero_init_last – Zero-initialize the last BN in each residual branch.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layers.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(num_classes: int, global_pool: str = 'avg') → None
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.resnetv2.ResNetV2(layers: ~typing.List[int], channels: ~typing.Tuple[int, ...] = (256, 512, 1024, 2048), num_classes: int = 1000, in_chans: int = 3, global_pool: str = 'avg', output_stride: int = 32, width_factor: int = 1, stem_chs: int = 64, stem_type: str = '', avg_down: bool = False, preact: bool = True, basic: bool = False, bottle_ratio: float = 0.25, act_layer: ~typing.Callable = <class 'torch.nn.modules.activation.ReLU'>, norm_layer: ~typing.Callable = functools.partial(<class 'timm.layers.norm_act.GroupNormAct'>, num_groups=32), conv_layer: ~typing.Callable = <class 'timm.layers.std_conv.StdConv2d'>, drop_rate: float = 0.0, drop_path_rate: float = 0.0, zero_init_last: bool = False, device=None, dtype=None)
Implementation of Pre-activation (v2) ResNet model.
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- get_classifier() → Module
Get the classifier head.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Group parameters for optimization.
- init_weights(zero_init_last: bool = True) → None
Initialize model weights.
- load_pretrained(checkpoint_path: str, prefix: str = 'resnet/') → None
Load pretrained weights.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- class timm.models.rexnet.RexNet(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- output_stride: int = 32,
- initial_chs: int = 16,
- final_chs: int = 180,
- width_mult: float = 1.0,
- depth_mult: float = 1.0,
- se_ratio: float = 0.08333333333333333,
- ch_div: int = 1,
- act_layer: str = 'swish',
- dw_act_layer: str = 'relu6',
- drop_rate: float = 0.2,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
ReXNet model architecture.
Based on ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network - https://arxiv.org/abs/2007.00992
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(x: Tensor, pre_logits: bool = False) → Tensor
Forward pass through head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- get_classifier() → Module
Get the classifier module.
- Returns:
Classifier module.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Group matcher for parameter groups.
- Parameters:
coarse – Whether to use coarse grouping.
- Returns:
Dictionary of grouped parameters.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
- device=None,
- dtype=None,
Reset the classifier.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.selecsls.SelecSls(
- cfg,
- num_classes: int = 1000,
- in_chans: int = 3,
- drop_rate: float = 0.0,
- global_pool: str = 'avg',
- device=None,
- dtype=None,
SelecSls42 / SelecSls60 / SelecSls84
- Parameters:
cfg – Network config dictionary specifying block type, feature, and head args.
num_classes (int, default 1000) – Number of classification classes.
in_chans (int, default 3) – Number of input (color) channels.
drop_rate (float, default 0.) – Dropout probability before classifier, for training
global_pool (str, default 'avg') – Global pooling type. One of ‘avg’, ‘max’, ‘avgmax’, ‘catavgmax’
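The four global_pool options differ in how each channel's spatial map is reduced to a number. The sketch below illustrates them on plain nested lists. It assumes that 'avgmax' averages the mean and the max, and that 'catavgmax' concatenates them, which doubles the pooled feature width and therefore the classifier's input size. `pool` is a hypothetical helper, not a timm function.

```python
from typing import List

Grid = List[List[float]]  # one channel's H x W activations


def pool(channels: List[Grid], pool_type: str = "avg") -> List[float]:
    """Reduce each channel's H x W grid according to `pool_type`."""
    avgs = [sum(map(sum, g)) / sum(len(row) for row in g) for g in channels]
    maxs = [max(max(row) for row in g) for g in channels]
    if pool_type == "avg":
        return avgs
    if pool_type == "max":
        return maxs
    if pool_type == "avgmax":
        return [0.5 * (a + m) for a, m in zip(avgs, maxs)]  # same width as input
    if pool_type == "catavgmax":
        return avgs + maxs  # 2 x channels features
    raise ValueError(f"unknown pool_type {pool_type!r}")


feats = [[[1, 2], [3, 4]], [[0, 0], [0, 8]]]  # 2 channels, 2 x 2 each
for kind in ("avg", "max", "avgmax", "catavgmax"):
    print(kind, pool(feats, kind))
```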
- class timm.models.senet.SENet(
- block: Type[Module],
- layers: Tuple[int, ...],
- groups: int,
- reduction: int,
- drop_rate: float = 0.2,
- in_chans: int = 3,
- inplanes: int = 64,
- input_3x3: bool = False,
- downsample_kernel_size: int = 1,
- downsample_padding: int = 0,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- device=None,
- dtype=None,
- class timm.models.sequencer.Sequencer2d(
- num_classes: int = 1000,
- img_size: int = 224,
- in_chans: int = 3,
- global_pool: str = 'avg',
- layers: Tuple[int, ...] = (4, 3, 8, 3),
- patch_sizes: Tuple[int, ...] = (7, 2, 2, 1),
- embed_dims: Tuple[int, ...] = (192, 384, 384, 384),
- hidden_sizes: Tuple[int, ...] = (48, 96, 96, 96),
- mlp_ratios: Tuple[float, ...] = (3.0, 3.0, 3.0, 3.0),
- block_layer: Type[Module] = <class 'timm.models.sequencer.Sequencer2dBlock'>,
- rnn_layer: Type[Module] = <class 'timm.models.sequencer.LSTM2d'>,
- mlp_layer: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
- norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- num_rnn_layers: int = 1,
- bidirectional: bool = True,
- union: str = 'cat',
- with_fc: bool = True,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- nlhb: bool = False,
- stem_norm: bool = False,
- device=None,
- dtype=None,
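The Sequencer2d defaults can be read stage by stage. The sketch rests on two assumptions. First, each stage starts with a non-overlapping patch merge whose stride equals its patch size. Second, the vertical and horizontal bidirectional LSTMs are concatenated (union='cat'), giving 4 * hidden_size features before the optional fc projection. `sequencer_stages` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def sequencer_stages(img_size: int = 224,
                     patch_sizes: Sequence[int] = (7, 2, 2, 1),
                     embed_dims: Sequence[int] = (192, 384, 384, 384),
                     hidden_sizes: Sequence[int] = (48, 96, 96, 96),
                     layers: Sequence[int] = (4, 3, 8, 3)) -> List[Tuple[int, int, int, int]]:
    """Per stage: (grid side, embed dim, assumed LSTM output width, blocks)."""
    side = img_size
    stages = []
    for p, dim, hidden, n in zip(patch_sizes, embed_dims, hidden_sizes, layers):
        if side % p:
            raise ValueError(f"grid side {side} not divisible by patch size {p}")
        side //= p  # assumed non-overlapping merge, stride == patch size
        rnn_out = 4 * hidden  # assumed: 2 directions x (vertical + horizontal)
        stages.append((side, dim, rnn_out, n))
    return stages


for stage in sequencer_stages():
    print(stage)
```

With the defaults, the assumed LSTM output width equals embed_dim in every stage.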
- class timm.models.shvit.SHViT(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: Tuple[int, int, int] = (128, 256, 384),
- partial_dim: Tuple[int, int, int] = (32, 64, 96),
- qk_dim: Tuple[int, int, int] = (16, 16, 16),
- depth: Tuple[int, int, int] = (1, 2, 3),
- types: Tuple[str, str, str] = ('s', 's', 's'),
- drop_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'timm.layers.norm.GroupNorm1'>,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.ReLU'>,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.starnet.StarNet(base_dim: int = 32, depths: ~typing.List[int] = [3, 3, 12, 5], mlp_ratio: int = 4, drop_rate: float = 0.0, drop_path_rate: float = 0.0, act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU6'>, num_classes: int = 1000, in_chans: int = 3, global_pool: str = 'avg', output_stride: int = 32, device=None, dtype=None, **kwargs)
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.swiftformer.SwiftFormer(layers: ~typing.List[int] = [3, 3, 6, 4], embed_dims: ~typing.List[int] = [48, 56, 112, 220], mlp_ratios: int = 4, downsamples: ~typing.List[bool] = [False, True, True, True], act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>, down_patch_size: int = 3, down_stride: int = 2, down_pad: int = 1, num_classes: int = 1000, drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-05, global_pool: str = 'avg', output_stride: int = 32, in_chans: int = 3, device=None, dtype=None, **kwargs)
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.swin_transformer.SwinTransformer(
- img_size: int | ~typing.Tuple[int, int] = 224,
- patch_size: int = 4,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: int = 96,
- depths: Tuple[int, ...] = (2, 2, 6, 2),
- num_heads: Tuple[int, ...] = (3, 6, 12, 24),
- head_dim: int | None = None,
- window_size: int | ~typing.Tuple[int, int] = 7,
- always_partition: bool = False,
- strict_img_size: bool = True,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.1,
- embed_layer: Type[Module] = <class 'timm.layers.patch_embed.PatchEmbed'>,
- norm_layer: str | Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
- weight_init: str = '',
- device=None,
- dtype=None,
- **kwargs,
Swin Transformer.
- A PyTorch impl of Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
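The "shifted windows" in the title work in two steps. Self-attention is computed inside non-overlapping windows. In alternating blocks, the feature map is first cyclically shifted by half a window, so that information crosses window borders. The sketch below shows both steps in plain Python on a small 2-D grid of token ids, which stands in for the (H, W) token map. Real implementations operate on tensors and also mask attention between tokens that wrap around after the shift; that masking is omitted here.

```python
from typing import List

Grid = List[List[int]]


def window_partition(grid: Grid, window_size: int) -> List[Grid]:
    """Split an H x W grid into non-overlapping window_size x window_size windows."""
    height, width = len(grid), len(grid[0])
    if height % window_size or width % window_size:
        raise ValueError("grid size must be divisible by window_size")
    windows = []
    for top in range(0, height, window_size):
        for left in range(0, width, window_size):
            windows.append([row[left:left + window_size] for row in grid[top:top + window_size]])
    return windows


def cyclic_shift(grid: Grid, shift: int) -> Grid:
    """Roll the grid up and left by `shift` (like torch.roll with negative shifts)."""
    height, width = len(grid), len(grid[0])
    return [[grid[(r + shift) % height][(c + shift) % width] for c in range(width)]
            for r in range(height)]


window_size = 2
grid = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4 x 4 token ids
regular = window_partition(grid, window_size)
shifted = window_partition(cyclic_shift(grid, window_size // 2), window_size)
print(regular[0])  # [[0, 1], [4, 5]]
print(shifted[0])  # [[5, 6], [9, 10]]
```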
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier.
- Returns:
Output tensor.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor.
indices – Take last n blocks if int, all if None, select matching indices if sequence.
norm – Apply norm layer to compatible intermediates.
stop_early – Stop iterating over blocks when last desired intermediate hit.
output_fmt – Shape of intermediate feature outputs.
intermediates_only – Only return intermediate features.
- Returns:
List of intermediate features or tuple of (final features, intermediates).
- get_classifier() → Module
Get the classifier head.
- group_matcher(
- coarse: bool = False,
Group parameters for optimization.
- init_weights(mode: str = '', needs_reset: bool = True) → None
Initialize model weights.
- Parameters:
mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).
needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.
- no_weight_decay() → Set[str]
Parameters that should not use weight decay.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- set_input_size(
- img_size: Tuple[int, int] | None = None,
- patch_size: Tuple[int, int] | None = None,
- window_size: Tuple[int, int] | None = None,
- window_ratio: int = 8,
- always_partition: bool | None = None,
Update the image resolution and window size.
- Parameters:
img_size – New input resolution, if None current resolution is used.
patch_size – New patch size, if None use current patch size.
window_size – New window size. If None, it is derived as patch grid size // window_ratio.
window_ratio – Divisor for calculating window size from grid size.
always_partition – Always partition into windows and shift, even if the feature size is smaller than the window size.
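When `window_size` is left as None, the window size comes from the patch grid and `window_ratio`. Below is a minimal sketch of that arithmetic. It assumes grid size = image size // patch size, and the helper name is hypothetical. With the defaults (224-pixel input, 4-pixel patches, ratio 8) it reproduces the default `window_size` of 7.

```python
from typing import Tuple


def derive_window_size(
    img_size: Tuple[int, int], patch_size: Tuple[int, int], window_ratio: int = 8
) -> Tuple[int, int]:
    """Window size = patch grid size // window_ratio, per spatial dimension."""
    grid = tuple(i // p for i, p in zip(img_size, patch_size))
    return tuple(g // window_ratio for g in grid)


print(derive_window_size((224, 224), (4, 4)))  # (7, 7)
print(derive_window_size((384, 384), (4, 4)))  # (12, 12)
```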
- class timm.models.swin_transformer_v2.SwinTransformerV2(
- img_size: int | ~typing.Tuple[int, int] = 224,
- patch_size: int = 4,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dim: int = 96,
- depths: Tuple[int, ...] = (2, 2, 6, 2),
- num_heads: Tuple[int, ...] = (3, 6, 12, 24),
- window_size: int | ~typing.Tuple[int, int] = 7,
- always_partition: bool = False,
- strict_img_size: bool = True,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.1,
- act_layer: str | Callable = 'gelu',
- norm_layer: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
- pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
- device=None,
- dtype=None,
- **kwargs,
Swin Transformer V2.
A hierarchical vision transformer using shifted windows for efficient self-attention computation with continuous position bias.
- A PyTorch impl of Swin Transformer V2: Scaling Up Capacity and Resolution
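Swin V1 uses a table of learned position biases. The continuous position bias replaces it with a small MLP evaluated on relative coordinates. Before the MLP, Swin V2 log-spaces those coordinates, so that a model trained with small windows extrapolates better to larger ones. Below is a sketch of that transform along one axis. It assumes, following the Swin V2 paper, that offsets are first normalized by (window_size - 1) and scaled to [-8, 8]; the MLP itself is omitted.

```python
import math
from typing import List


def log_spaced_coords(window_size: int) -> List[float]:
    """Relative offsets -(W-1)..(W-1) mapped to sign(x) * log2(1 + |x|) / log2(8)."""
    coords = []
    for offset in range(-(window_size - 1), window_size):
        x = offset / (window_size - 1) * 8  # normalize to [-8, 8]
        coords.append(math.copysign(math.log2(abs(x) + 1.0), x) / math.log2(8))
    return coords


coords = log_spaced_coords(7)
print(len(coords))           # 13 relative offsets for a 7-wide window
print(round(coords[-1], 4))  # 1.0566, i.e. log2(9) / 3
```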
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor of shape (B, C, H, W).
- Returns:
Feature tensor of shape (B, H’, W’, C).
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classification head.
- Parameters:
x – Feature tensor of shape (B, H, W, C).
pre_logits – If True, return features before final linear layer.
- Returns:
Logits tensor of shape (B, num_classes) or pre-logits.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- get_classifier() → Module
Get the classifier head.
- Returns:
The classification head module.
- group_matcher(
- coarse: bool = False,
Create parameter group matcher for optimizer parameter groups.
- Parameters:
coarse – If True, use coarse grouping.
- Returns:
Dictionary mapping group names to regex patterns.
- init_weights(needs_reset: bool = True) → None
Initialize model weights.
- Parameters:
needs_reset – If True, call reset_parameters() on modules (default for after to_empty()). If False, skip reset_parameters() (for __init__ where modules already self-initialized).
- no_weight_decay() → Set[str]
Get parameter names that should not use weight decay.
- Returns:
Set of parameter names to exclude from weight decay.
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classification head.
- Parameters:
num_classes – Number of classes for new head.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – If True, enable gradient checkpointing.
- set_input_size(
- img_size: Tuple[int, int] | None = None,
- patch_size: Tuple[int, int] | None = None,
- window_size: Tuple[int, int] | None = None,
- window_ratio: int | None = 8,
- always_partition: bool | None = None,
Update the image resolution and window size, and with them the pair-wise relative positions.
- Parameters:
img_size (Optional[Tuple[int, int]]) – New input resolution. If None, the current resolution is used.
patch_size (Optional[Tuple[int, int]]) – New patch size. If None, the current patch size is used.
window_size (Optional[Tuple[int, int]]) – New window size. If None, it is derived as patch grid size // window_ratio.
window_ratio (int) – Divisor for calculating window size from patch grid size.
always_partition – Always partition / shift windows, even if the feature size is smaller than the window.
- class timm.models.swin_transformer_v2_cr.SwinTransformerV2Cr(
- img_size: Tuple[int, int] = (224, 224),
- patch_size: int = 4,
- window_size: int | None = None,
- window_ratio: int = 8,
- always_partition: bool = False,
- strict_img_size: bool = True,
- in_chans: int = 3,
- num_classes: int = 1000,
- embed_dim: int = 96,
- depths: Tuple[int, ...] = (2, 2, 6, 2),
- num_heads: Tuple[int, ...] = (3, 6, 12, 24),
- mlp_ratio: float = 4.0,
- init_values: float | None = 0.0,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
- extra_norm_period: int = 0,
- extra_norm_stage: bool = False,
- sequential_attn: bool = False,
- global_pool: str = 'avg',
- weight_init: str = 'reset',
- device=None,
- dtype=None,
- **kwargs: Any,
- Swin Transformer V2
- A PyTorch impl of Swin Transformer V2: Scaling Up Capacity and Resolution
- Parameters:
img_size – Input resolution.
window_size – Window size. If None, it is computed as patch grid size // window_ratio.
window_ratio – Window size to patch grid ratio.
patch_size – Patch size.
in_chans – Number of input channels.
depths – Depth of each stage (number of blocks per stage).
num_heads – Number of attention heads to be utilized.
embed_dim – Patch embedding dimension.
num_classes – Number of output classes.
mlp_ratio – Ratio of the hidden dimension in the FFN to the input channels.
drop_rate – Dropout rate.
proj_drop_rate – Projection dropout rate.
attn_drop_rate – Dropout rate of attention map.
drop_path_rate – Stochastic depth rate.
norm_layer – Type of normalization layer to be utilized.
extra_norm_period – Insert extra norm layer on main branch every N (period) blocks in stage
extra_norm_stage – End each stage with an extra norm layer in main branch
sequential_attn – If True, sequential self-attention is performed.
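`depths`, `num_heads`, `embed_dim` and `drop_path_rate` interact in a predictable way. In the Swin family, each stage after the first halves the spatial resolution (patch merging) and doubles the channel width. Stochastic depth is usually ramped linearly from 0 to `drop_path_rate` over all blocks. The plain-Python sketch below does that bookkeeping for the defaults above, with an example `drop_path_rate` of 0.1. Both conventions are assumptions, and the helper names are hypothetical.

```python
from typing import List, Tuple


def stage_shapes(img_size: int, patch_size: int, embed_dim: int,
                 depths: Tuple[int, ...], num_heads: Tuple[int, ...]) -> List[dict]:
    """Per-stage grid size, channel width and per-head dimension."""
    grid = img_size // patch_size
    shapes = []
    for i, (depth, heads) in enumerate(zip(depths, num_heads)):
        dim = embed_dim * 2 ** i  # width doubles at every stage
        shapes.append({"grid": grid // 2 ** i, "dim": dim, "depth": depth, "head_dim": dim // heads})
    return shapes


def drop_path_schedule(drop_path_rate: float, depths: Tuple[int, ...]) -> List[float]:
    """Linearly increasing stochastic-depth rate over all blocks (like torch.linspace)."""
    total = sum(depths)
    if total == 1:
        return [0.0]
    return [drop_path_rate * i / (total - 1) for i in range(total)]


shapes = stage_shapes(224, 4, 96, (2, 2, 6, 2), (3, 6, 12, 24))
print(shapes[0])   # {'grid': 56, 'dim': 96, 'depth': 2, 'head_dim': 32}
print(shapes[-1])  # {'grid': 7, 'dim': 768, 'depth': 2, 'head_dim': 32}
rates = drop_path_schedule(0.1, (2, 2, 6, 2))
print(len(rates))  # 12
```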
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- get_classifier() → Module
Return the classification head of the model.
- Returns:
Current classification head.
- Return type:
nn.Module
- init_weights(needs_reset: bool = True) → None
Initialize model weights.
- Parameters:
needs_reset – If True, call reset_parameters() on modules (default for after to_empty()). If False, skip reset_parameters() (for __init__ where modules already self-initialized).
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classification head.
- Parameters:
num_classes (int) – Number of classes to be predicted
global_pool (str) – Unused
- set_input_size(
- img_size: Tuple[int, int] | None = None,
- window_size: Tuple[int, int] | None = None,
- window_ratio: int = 8,
- always_partition: bool | None = None,
Update the image resolution and window size, and with them the pair-wise relative positions.
- Parameters:
img_size (Optional[Tuple[int, int]]) – New input resolution. If None, the current resolution is used.
window_size (Optional[Tuple[int, int]]) – New window size. If None, it is derived as patch grid size // window_ratio.
window_ratio (int) – Divisor for calculating window size from patch grid size.
always_partition – Always partition / shift windows, even if the feature size is smaller than the window.
- class timm.models.tiny_vit.TinyVit(
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
- depths: Tuple[int, ...] = (2, 2, 6, 2),
- num_heads: Tuple[int, ...] = (3, 6, 12, 24),
- window_sizes: Tuple[int, ...] = (7, 7, 14, 7),
- mlp_ratio: float = 4.0,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.1,
- use_checkpoint: bool = False,
- mbconv_expand_ratio: float = 4.0,
- local_conv_size: int = 3,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.GELU'>,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.tnt.TNT(
- img_size: int | ~typing.Tuple[int, int] = 224,
- patch_size: int | ~typing.Tuple[int, int] = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'token',
- embed_dim: int = 768,
- inner_dim: int = 48,
- depth: int = 12,
- num_heads_inner: int = 4,
- num_heads_outer: int = 12,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = False,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'torch.nn.modules.normalization.LayerNorm'>,
- first_stride: int = 4,
- legacy: bool = False,
- device=None,
- dtype=None,
Transformer in Transformer - https://arxiv.org/abs/2103.00112
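Transformer in Transformer runs two token sequences. The outer sequence holds patch embeddings (`embed_dim`), and inside each patch an inner sequence holds "pixel" embeddings (`inner_dim`). The sketch below computes the resulting token counts. It assumes, as in the TNT paper, that each patch is split into sub-patches with stride `first_stride`, which gives (16 / 4)^2 = 16 inner tokens per patch at the defaults. It also assumes the default `global_pool='token'` implies one class token. Both are assumptions about the design, not a transcription of timm's code.

```python
import math


def tnt_token_counts(img_size: int = 224, patch_size: int = 16, first_stride: int = 4,
                     class_token: bool = True) -> dict:
    """Outer (patch) and inner (pixel) sequence lengths for Transformer in Transformer."""
    num_patches = (img_size // patch_size) ** 2
    pixels_per_patch = math.ceil(patch_size / first_stride) ** 2
    return {
        "outer_tokens": num_patches + (1 if class_token else 0),
        "inner_tokens_per_patch": pixels_per_patch,
        "total_inner_tokens": num_patches * pixels_per_patch,
    }


print(tnt_token_counts())
# {'outer_tokens': 197, 'inner_tokens_per_patch': 16, 'total_inner_tokens': 3136}
```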
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- return_prefix_tokens: bool = False,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if an int; if a sequence, select by matching indices
return_prefix_tokens – Return both prefix and spatial intermediate tokens
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.tresnet.TResNet(
- layers: List[int],
- in_chans: int = 3,
- num_classes: int = 1000,
- width_factor: float = 1.0,
- v2: bool = False,
- global_pool: str = 'fast',
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.twins.Twins(
- img_size: int | ~typing.Tuple[int, int] = 224,
- patch_size: int = 4,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
- num_heads: Tuple[int, ...] = (1, 2, 4, 8),
- mlp_ratios: Tuple[float, ...] = (4, 4, 4, 4),
- depths: Tuple[int, ...] = (3, 4, 6, 3),
- sr_ratios: Tuple[int, ...] = (8, 4, 2, 1),
- wss: Tuple[int, ...] | None = None,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06),
- block_cls: Any = <class 'timm.models.twins.Block'>,
- device=None,
- dtype=None,
Twins Vision Transformer (Revisiting Spatial Attention)
Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
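`sr_ratios` comes from PVT's spatial-reduction attention. In each stage, keys and values are downsampled by `sr_ratio` along each spatial axis before attention, which cuts the attention cost by roughly `sr_ratio ** 2`. The sketch below computes per-stage query and key/value token counts for the defaults. It assumes a 224-pixel input, a stride-4 first patch embedding (`patch_size=4`) and 2x downsampling between the four stages. In Twins-SVT (when `wss` is set), blocks of this kind are presumably interleaved with locally grouped window attention, which the sketch ignores.

```python
from typing import List, Tuple


def kv_tokens_per_stage(img_size: int = 224, patch_size: int = 4,
                        sr_ratios: Tuple[int, ...] = (8, 4, 2, 1)) -> List[Tuple[int, int]]:
    """(query tokens, key/value tokens) per stage under spatial-reduction attention."""
    result = []
    grid = img_size // patch_size
    for stage, sr in enumerate(sr_ratios):
        side = grid // 2 ** stage      # stage resolution: 56, 28, 14, 7
        queries = side * side
        kv = (side // sr) ** 2         # keys/values after sr x sr reduction
        result.append((queries, kv))
    return result


print(kv_tokens_per_stage())  # [(3136, 49), (784, 49), (196, 49), (49, 49)]
```

With the default ratios, every stage attends to the same 7 x 7 set of reduced keys and values, which keeps the cost of the high-resolution early stages manageable.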
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.vgg.VGG(cfg: ~typing.List[~typing.Any], num_classes: int = 1000, in_chans: int = 3, output_stride: int = 32, mlp_ratio: float = 1.0, act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU'>, conv_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.conv.Conv2d'>, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] | None = None, global_pool: str = 'avg', drop_rate: float = 0.0, device=None, dtype=None)
VGG model architecture.
Based on Very Deep Convolutional Networks for Large-Scale Image Recognition - https://arxiv.org/abs/1409.1556
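The `cfg` list defines the VGG feature extractor. Integers are 3x3 convolution output channels, and 'M' marks a 2x2 max-pool. The sketch below summarizes such a list, using the VGG-11 ("A") layout from the paper as an example. How timm maps model names to cfg lists is not shown here.

```python
from typing import List, Union


def summarize_vgg_cfg(cfg: List[Union[int, str]]) -> dict:
    """Count conv layers and pools, and report the resulting stride and final width."""
    convs = [c for c in cfg if c != "M"]
    pools = cfg.count("M")
    return {"conv_layers": len(convs), "stride": 2 ** pools, "out_channels": convs[-1]}


VGG11 = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]
print(summarize_vgg_cfg(VGG11))  # {'conv_layers': 8, 'stride': 32, 'out_channels': 512}
```

The five pools give a total stride of 32, which matches the default `output_stride=32` in the signature above.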
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction layers.
- Parameters:
x – Input tensor.
- Returns:
Feature tensor.
- forward_head(x: Tensor, pre_logits: bool = False) → Tensor
Forward pass through head.
- Parameters:
x – Input features.
pre_logits – Return features before final linear layer.
- Returns:
Classification logits or features.
- get_classifier() → Module
Get the classifier module.
- Returns:
Classifier module.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Group matcher for parameter groups.
- Parameters:
coarse – Whether to use coarse grouping.
- Returns:
Dictionary of grouped parameters.
- reset_classifier(num_classes: int, global_pool: str | None = None) → None
Reset the classifier.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.visformer.Visformer(
- img_size: int = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- init_channels: int | None = 32,
- embed_dim: int = 384,
- depth: int | tuple = 12,
- num_heads: int = 6,
- mlp_ratio: float = 4.0,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- norm_layer: Type[Module] = <class 'timm.layers.norm.LayerNorm2d'>,
- attn_stage: str = '111',
- use_pos_embed: bool = True,
- spatial_conv: str = '111',
- vit_stem: bool = False,
- group: int = 8,
- global_pool: str = 'avg',
- conv_init: bool = False,
- embed_norm: Type[Module] | None = None,
- device=None,
- dtype=None,
- class timm.models.vision_transformer.VisionTransformer(
- img_size: int | ~typing.Tuple[int, int] = 224,
- patch_size: int | ~typing.Tuple[int, int] = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map', 'prr'] = 'token',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- scale_attn_norm: bool = False,
- scale_mlp_norm: bool = False,
- proj_bias: bool = True,
- init_values: float | None = None,
- class_token: bool = True,
- pos_embed: str = 'learn',
- no_embed_class: bool = False,
- reg_tokens: int = 0,
- pre_norm: bool = False,
- final_norm: bool = True,
- fc_norm: bool | None = None,
- pool_include_prefix: bool = False,
- dynamic_img_size: bool = False,
- dynamic_img_pad: bool = False,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- patch_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- weight_init: Literal['skip', 'reset', 'jax', 'jax_nlhb', 'moco', ''] = '',
- fix_init: bool = False,
- embed_layer: Callable = <class 'timm.layers.patch_embed.PatchEmbed'>,
- embed_norm_layer: str | Callable | Type[Module] | None = None,
- norm_layer: str | Callable | Type[Module] | None = None,
- act_layer: str | Callable | Type[Module] | None = None,
- block_fn: Type[Module] = <class 'timm.models.vision_transformer.Block'>,
- mlp_layer: Type[Module] = <class 'timm.layers.mlp.Mlp'>,
- attn_layer: str | Callable | Type[Module] = <class 'timm.layers.attention.Attention'>,
- device=None,
- dtype=None,
Vision Transformer
- A PyTorch impl of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
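For a plain ViT, the sequence length seen by the transformer blocks follows from `img_size`, `patch_size`, `class_token` and `reg_tokens`. The small sketch below shows that arithmetic; the helper is hypothetical. With the defaults, a 224-pixel image and 16-pixel patches give 14 x 14 = 196 patch tokens plus one class token.

```python
def vit_sequence_length(img_size: int = 224, patch_size: int = 16,
                        class_token: bool = True, reg_tokens: int = 0) -> int:
    """Number of tokens seen by the transformer blocks (prefix + patch tokens)."""
    if img_size % patch_size:
        raise ValueError("img_size must be divisible by patch_size in this sketch")
    num_patches = (img_size // patch_size) ** 2
    num_prefix_tokens = (1 if class_token else 0) + reg_tokens
    return num_prefix_tokens + num_patches


print(vit_sequence_length())                            # 197
print(vit_sequence_length(class_token=False))           # 196
print(vit_sequence_length(img_size=384, reg_tokens=4))  # 581
```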
- fix_init_weight() → None
Apply weight initialization fix (scaling w/ layer index).
- forward_features(
- x: Tensor,
- attn_mask: Tensor | None = None,
- is_causal: bool = False,
Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm).
- forward_head(
- x: Tensor,
- pre_logits: bool = False,
Forward pass through classifier head.
- Parameters:
x – Feature tensor.
pre_logits – Return features before final classifier.
- Returns:
Output tensor.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- return_prefix_tokens: bool = False,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
- output_dict: bool = False,
- attn_mask: Tensor | None = None,
- is_causal: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens – Return both prefix and spatial intermediate tokens
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
output_dict – Return outputs as a dictionary with ‘image_features’ and ‘image_intermediates’ keys
attn_mask – Optional attention mask for masked attention (e.g., for NaFlex)
is_causal – If True, use causal (autoregressive) masking in attention
- Returns:
A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing ‘image_features’ and ‘image_intermediates’ (and optionally ‘image_intermediates_prefix’)
- get_classifier() → Module
Get the classifier head.
- get_intermediate_layers(
- x: Tensor,
- n: int | List[int] | Tuple[int] = 1,
- reshape: bool = False,
- return_prefix_tokens: bool = False,
- norm: bool = False,
- attn_mask: Tensor | None = None,
Get intermediate layer outputs (DINO interface compatibility).
NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
- Parameters:
x – Input tensor.
n – Number or indices of layers.
reshape – Reshape to NCHW format.
return_prefix_tokens – Return prefix tokens.
norm – Apply normalization.
- Returns:
List of intermediate features.
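With `reshape=True` (or `output_fmt='NCHW'` in `forward_intermediates()`), a block's token sequence is turned back into a feature map. Prefix tokens are dropped, and the remaining patch tokens are laid out on the patch grid. Below is a plain-Python sketch for one image, assuming a square grid, a single class token and row-major patch order. The library works on tensors with a known grid size, so non-square inputs are handled there but not here.

```python
import math
from typing import List


def tokens_to_chw(tokens: List[List[float]], num_prefix_tokens: int = 1) -> List[List[List[float]]]:
    """Drop prefix tokens and arrange N patch tokens of width C as a C x H x W map."""
    patches = tokens[num_prefix_tokens:]
    side = math.isqrt(len(patches))
    if side * side != len(patches):
        raise ValueError("patch tokens do not form a square grid")
    channels = len(patches[0])
    # patches are stored row-major: token index = row * side + col
    return [[[patches[r * side + c][ch] for c in range(side)] for r in range(side)]
            for ch in range(channels)]


tokens = [[-1.0, -1.0]] + [[float(i), 10.0 + i] for i in range(4)]  # class token + 2x2 patches
chw = tokens_to_chw(tokens)
print(chw[0])  # [[0.0, 1.0], [2.0, 3.0]]
print(chw[1])  # [[10.0, 11.0], [12.0, 13.0]]
```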
- group_matcher(
- coarse: bool = False,
Create regex patterns for parameter grouping.
- Parameters:
coarse – Use coarse grouping.
- Returns:
Dictionary mapping group names to regex patterns.
- init_weights(mode: str = '', needs_reset: bool = True) → None
Initialize model weights.
- Parameters:
mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).
needs_reset – If True, call reset_parameters() on modules that have it. Set to False when modules have already self-initialized in __init__.
- load_pretrained(
- checkpoint_path: str,
- prefix: str = '',
Load pretrained weights.
- Parameters:
checkpoint_path – Path to checkpoint.
prefix – Prefix for state dict keys.
- no_weight_decay() → Set[str]
Set of parameters that should not use weight decay.
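`no_weight_decay()` returns parameter names that an optimizer should exempt from weight decay; for a ViT these are typically position embeddings and prefix tokens. A common way to use it is sketched below without PyTorch. Names in that set, plus 1-D tensors such as biases and norm weights, go into a zero-decay group. This is similar in spirit to timm's optimizer helpers. The parameter names and shapes are hypothetical stand-ins for `model.named_parameters()`.

```python
from typing import Dict, List, Set, Tuple


def split_decay_groups(param_shapes: Dict[str, Tuple[int, ...]], skip: Set[str],
                       weight_decay: float = 0.05) -> List[dict]:
    """Split parameter names into a zero-decay group and a decayed group."""
    decay, no_decay = [], []
    for name, shape in param_shapes.items():
        if name in skip or len(shape) <= 1 or name.endswith(".bias"):
            no_decay.append(name)
        else:
            decay.append(name)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


shapes = {
    "cls_token": (1, 1, 768),
    "pos_embed": (1, 197, 768),
    "patch_embed.proj.weight": (768, 3, 16, 16),
    "patch_embed.proj.bias": (768,),
    "blocks.0.norm1.weight": (768,),
    "blocks.0.attn.qkv.weight": (2304, 768),
}
groups = split_decay_groups(shapes, skip={"pos_embed", "cls_token"})
print(groups[0]["params"])  # ['cls_token', 'pos_embed', 'patch_embed.proj.bias', 'blocks.0.norm1.weight']
print(groups[1]["params"])  # ['patch_embed.proj.weight', 'blocks.0.attn.qkv.weight']
```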
- pool(
- x: Tensor,
- pool_type: str | None = None,
Apply pooling to feature tokens.
- Parameters:
x – Feature tensor.
pool_type – Pooling type override.
- Returns:
Pooled features.
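`pool()` reduces the token sequence to one vector per image according to `global_pool`. Below is a plain-Python sketch for a single image, assuming these semantics:

- 'token' takes the first (class) token.
- 'avg' and 'max' reduce over patch tokens only; prefix tokens are skipped unless `pool_include_prefix` is set.
- 'avgmax' averages the 'avg' and 'max' results.

The 'map' (attention pooling) and 'prr' options involve learned layers and are left out.

```python
from typing import List

Tokens = List[List[float]]  # N tokens, each a C-dim vector


def pool_tokens(tokens: Tokens, pool_type: str = "token", num_prefix_tokens: int = 1,
                include_prefix: bool = False) -> List[float]:
    """Reduce N token vectors to a single C-dim vector."""
    if pool_type == "token":
        return list(tokens[0])
    body = tokens if include_prefix else tokens[num_prefix_tokens:]
    columns = list(zip(*body))  # one tuple per channel
    avg = [sum(col) / len(col) for col in columns]
    mx = [max(col) for col in columns]
    if pool_type == "avg":
        return avg
    if pool_type == "max":
        return mx
    if pool_type == "avgmax":
        return [0.5 * (a + m) for a, m in zip(avg, mx)]
    raise ValueError(f"unsupported pool_type: {pool_type}")


tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 6.0]]  # class token + two patch tokens
print(pool_tokens(tokens, "token"))   # [9.0, 9.0]
print(pool_tokens(tokens, "avg"))     # [2.0, 4.0]
print(pool_tokens(tokens, "max"))     # [3.0, 6.0]
print(pool_tokens(tokens, "avgmax"))  # [2.5, 5.0]
```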
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune the classifier head.
- Returns:
List of indices that were kept.
- reset_classifier(
- num_classes: int,
- global_pool: str | None = None,
Reset the classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Enable or disable gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- set_input_size(
- img_size: Tuple[int, int] | None = None,
- patch_size: Tuple[int, int] | None = None,
Update the input image resolution and patch size.
- Parameters:
img_size – New input resolution, if None current resolution is used.
patch_size – New patch size, if None existing patch size is used.
- class timm.models.vision_transformer_relpos.VisionTransformerRelPos(
- img_size: int | ~typing.Tuple[int, int] = 224,
- patch_size: int | ~typing.Tuple[int, int] = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- init_values: float | None = 1e-06,
- class_token: bool = False,
- fc_norm: bool = False,
- rel_pos_type: str = 'mlp',
- rel_pos_dim: int | None = None,
- shared_rel_pos: bool = False,
- drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- weight_init: Literal['skip', 'reset', 'jax', 'moco', ''] = 'reset',
- fix_init: bool = False,
- embed_layer: Type[Module] = <class 'timm.layers.patch_embed.PatchEmbed'>,
- norm_layer: str | Callable | Type[Module] | None = None,
- act_layer: str | Callable | Type[Module] | None = None,
- block_fn: Type[Module] = <class 'timm.models.vision_transformer_relpos.RelPosBlock'>,
- device=None,
- dtype=None,
Vision Transformer w/ Relative Position Bias
- Differing from the classic ViT, this implementation:
uses a relative position index (Swin V1 / BEiT) or relative log coordinates + MLP (Swin V2) position embedding
defaults to no class token (can be enabled)
defaults to global avg pool for head (can be changed)
layer-scale (residual branch gain) enabled
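The "relative position index" variant (as in Swin V1 / BEiT) stores one learnable bias per possible (dy, dx) offset inside an H x W grid, i.e. a table of (2H - 1) x (2W - 1) entries. An index computed for every query/key pair looks up the bias in that table. Below is a plain-Python sketch of the index computation. It illustrates the technique rather than transcribing timm's code; BEiT-style models, for example, reserve extra table entries for the class token.

```python
from typing import List


def relative_position_index(height: int, width: int) -> List[List[int]]:
    """Index into a (2H-1)*(2W-1) bias table for every (query, key) pair in an H x W grid."""
    coords = [(r, c) for r in range(height) for c in range(width)]
    index = []
    for qr, qc in coords:
        row = []
        for kr, kc in coords:
            dy = qr - kr + height - 1  # shift offsets so they start at 0
            dx = qc - kc + width - 1
            row.append(dy * (2 * width - 1) + dx)
        index.append(row)
    return index


idx = relative_position_index(2, 2)
print(len(idx), len(idx[0]))  # 4 4
print(idx[0])                 # [4, 3, 1, 0]
```

Every query/key pair with the same offset maps to the same table entry. This sharing is what makes the bias translation-invariant within the grid.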
- fix_init_weight() → None
Apply weight initialization fix (scaling w/ layer index).
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- return_prefix_tokens: bool = False,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens – Return both prefix and spatial intermediate tokens
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- init_weights(
- mode: str = '',
- needs_reset: bool = True,
Initialize model weights.
- Parameters:
mode – Weight initialization mode (‘jax’, ‘jax_nlhb’, ‘moco’, or ‘’).
needs_reset – If True, call reset_parameters() on modules (default for after to_empty()). If False, skip reset_parameters() (for __init__ where modules already self-initialized).
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.vision_transformer_sam.VisionTransformerSAM(img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, num_classes: int = 768, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, init_values: float | None = None, pre_norm: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, weight_init: str = '', embed_layer: ~typing.Type[~torch.nn.modules.module.Module] = functools.partial(<class 'timm.layers.patch_embed.PatchEmbed'>, output_fmt=<Format.NHWC: 'NHWC'>, strict_img_size=False), norm_layer: ~typing.Type[~torch.nn.modules.module.Module] | None = <class 'torch.nn.modules.normalization.LayerNorm'>, act_layer: ~typing.Type[~torch.nn.modules.module.Module] | None = <class 'torch.nn.modules.activation.GELU'>, block_fn: ~typing.Type[~torch.nn.modules.module.Module] = <class 'timm.models.vision_transformer_sam.Block'>, mlp_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'timm.layers.mlp.Mlp'>, use_abs_pos: bool = True, use_rel_pos: bool = False, use_rope: bool = False, window_size: int = 14, global_attn_indexes: ~typing.Tuple[int, ...] = (), neck_chans: int = 256, global_pool: str = 'avg', head_hidden_size: int | None = None, ref_feat_shape: ~typing.Tuple[~typing.Tuple[int, int], ~typing.Tuple[int, int]] | None = None, device=None, dtype=None)
Vision Transformer for Segment Anything Model (SAM)
- A PyTorch impl of Exploring Plain Vision Transformer Backbones for Object Detection, or Segment Anything Model (SAM)
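In this backbone, blocks use windowed attention of size `window_size`, while blocks listed in `global_attn_indexes` attend globally. The token grid need not be a multiple of the window size, so windows are formed after padding. The sketch below works through that bookkeeping for the defaults (1024-pixel input, 16-pixel patches, window 14). It assumes the ViTDet convention of padding the bottom/right edges up to a multiple of the window. The global block indices (2, 5, 8, 11) are only an illustrative choice, since `global_attn_indexes` defaults to an empty tuple.

```python
from typing import List, Tuple


def window_layout(img_size: int = 1024, patch_size: int = 16,
                  window_size: int = 14) -> Tuple[int, int, int]:
    """Return (grid size, padded grid size, number of windows) for windowed attention."""
    grid = img_size // patch_size
    pad = (window_size - grid % window_size) % window_size
    padded = grid + pad
    windows_per_side = padded // window_size
    return grid, padded, windows_per_side ** 2


def attention_kinds(depth: int, global_attn_indexes: Tuple[int, ...]) -> List[str]:
    """Label each block as using 'global' or 'window' attention."""
    return ["global" if i in global_attn_indexes else "window" for i in range(depth)]


print(window_layout())  # (64, 70, 25)
print(attention_kinds(12, (2, 5, 8, 11)))
```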
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Forward features that returns intermediates.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
List of intermediate features or tuple of (final features, intermediates).
- prune_intermediate_layers(
- indices: int | List[int] | None = None,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- class timm.models.volo.VOLO(layers: ~typing.List[int], img_size: int = 224, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'token', patch_size: int = 8, stem_hidden_dim: int = 64, embed_dims: ~typing.List[int] | None = None, num_heads: ~typing.List[int] | None = None, downsamples: ~typing.Tuple[bool, ...] = (True, False, False, False), outlook_attention: ~typing.Tuple[bool, ...] = (True, False, False, False), mlp_ratio: float = 3.0, qkv_bias: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.normalization.LayerNorm'>, post_layers: ~typing.Tuple[str, ...] | None = ('ca', 'ca'), use_aux_head: bool = True, use_mix_token: bool = False, pooling_scale: int = 2, device=None, dtype=None)
Vision Outlooker (VOLO) model.
- forward_cls(x: Tensor) → Tensor
Forward pass through class attention blocks.
- Parameters:
x – Input token tensor of shape (B, N, C).
- Returns:
Output tensor with class token of shape (B, N+1, C).
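The shape change documented for `forward_cls`, from (B, N, C) to (B, N+1, C), comes from prepending one learnable class token to each sample's sequence before the class-attention blocks run. The sketch below shows only that bookkeeping, with nested Python lists standing in for tensors; the class-attention computation itself is omitted, and `prepend_cls_token` is a hypothetical name.

```python
from typing import List

Token = List[float]


def prepend_cls_token(x: List[List[Token]], cls_token: Token) -> List[List[Token]]:
    """Prepend one copy of cls_token to every sequence: (B, N, C) -> (B, N+1, C)."""
    if any(len(tok) != len(cls_token) for seq in x for tok in seq):
        raise ValueError("channel dimension C must match the class token")
    # Copy the class token per sample so later in-place edits don't alias.
    return [[list(cls_token)] + [list(tok) for tok in seq] for seq in x]


B, N, C = 2, 4, 3
x = [[[0.0] * C for _ in range(N)] for _ in range(B)]
out = prepend_cls_token(x, cls_token=[1.0] * C)
print(len(out), len(out[0]), len(out[0][0]))  # 2 5 3
```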
- forward_features(x: Tensor) → Tensor
Forward pass through feature extraction.
- Parameters:
x – Input tensor of shape (B, C, H, W).
- Returns:
Feature tensor.
- forward_head(x: Tensor, pre_logits: bool = False) → Tensor
Forward pass through classification head.
- Parameters:
x – Input feature tensor.
pre_logits – Whether to return pre-logits features.
- Returns:
Classification logits or pre-logits features.
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Run the forward pass through the feature layers, returning intermediate outputs.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- forward_tokens(x: Tensor) → Tensor
Forward pass through token processing stages.
- Parameters:
x – Input tensor of shape (B, H, W, C).
- Returns:
Token tensor of shape (B, N, C).
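The step from (B, H, W, C) to (B, N, C) in `forward_tokens` is a row-major flatten of the spatial grid of the last stage, so N is that stage's height times width and the token at row h, column w lands at index h * W + w. The sketch below reproduces only that reshape with nested lists; `flatten_tokens` is a hypothetical helper, and the transformer stages that run before it are omitted.

```python
from typing import List

Grid = List[List[List[float]]]  # (H, W, C) for one sample


def flatten_tokens(x: List[Grid]) -> List[List[List[float]]]:
    """Row-major flatten (B, H, W, C) -> (B, H*W, C), like reshape(B, -1, C)."""
    return [[pixel for row in sample for pixel in row] for sample in x]


B, H, W, C = 2, 3, 4, 5
# Fill each position with its row-major index so the mapping is visible.
x = [[[[float(h * W + w)] * C for w in range(W)] for h in range(H)] for _ in range(B)]
tokens = flatten_tokens(x)
print(len(tokens), len(tokens[0]), len(tokens[0][0]))  # 2 12 5
print(tokens[0][5][0])  # 5.0, i.e. h=1, w=1
```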
- forward_train(
- x: Tensor,
Forward pass for training with mix token support.
- Parameters:
x – Input tensor of shape (B, C, H, W).
- Returns:
If training with mix_token: tuple of (class_token, aux_tokens, bbox). Otherwise: class_token tensor.
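Because `forward_train` returns a tuple when mix-token training is active and a single tensor otherwise, calling code has to branch on the result. The sketch below is a hypothetical caller-side helper that normalises both cases into a fixed 3-tuple; it treats the outputs as opaque objects and does not reproduce how VOLO computes or mixes the tokens.

```python
from typing import Any, Optional, Tuple


def unpack_train_output(out: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Return (class_token, aux_tokens, bbox); the last two are None without mix_token."""
    if isinstance(out, tuple) and len(out) == 3:
        class_token, aux_tokens, bbox = out
        return class_token, aux_tokens, bbox
    # Plain class-token output: no auxiliary tokens or mix box.
    return out, None, None


print(unpack_train_output("cls"))                         # ('cls', None, None)
print(unpack_train_output(("cls", "aux", (0, 0, 4, 4))))  # ('cls', 'aux', (0, 0, 4, 4))
```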
- get_classifier() → Module
Get classifier module.
- Returns:
The classifier head module.
- group_matcher(coarse: bool = False) → Dict[str, Any]
Get parameter grouping for optimizer.
- Parameters:
coarse – Whether to use coarse grouping.
- Returns:
Parameter grouping dictionary.
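In timm, the dictionary returned by `group_matcher` maps group names to regular expressions over parameter names, which optimizer utilities (such as layer-wise learning-rate decay) use to bucket parameters. The patterns below are hypothetical and are not VOLO's actual ones, and this sketch handles only plain string patterns; it simply shows how such a dictionary can be applied to a list of parameter names.

```python
import re
from typing import Dict, List

# Hypothetical grouping in the style of a group_matcher result: group -> regex.
GROUPS: Dict[str, str] = {
    "stem": r"^(patch_embed|cls_token|pos_embed)",
    "blocks": r"^(network|post_network)\.",
    "head": r"^(norm|head|aux_head)",
}


def assign_groups(param_names: List[str], groups: Dict[str, str]) -> Dict[str, List[str]]:
    """Bucket parameter names by the first group whose pattern matches."""
    buckets: Dict[str, List[str]] = {name: [] for name in groups}
    buckets["other"] = []
    for param in param_names:
        for group, pattern in groups.items():
            if re.match(pattern, param):
                buckets[group].append(param)
                break
        else:  # no pattern matched
            buckets["other"].append(param)
    return buckets


names = ["patch_embed.conv.weight", "network.0.0.attn.v.weight", "head.bias", "foo"]
print(assign_groups(names, GROUPS))
```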
- no_weight_decay() → set
Get set of parameters that should not have weight decay.
- Returns:
Set of parameter names.
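A typical use of the set returned by `no_weight_decay()` is to build two optimizer parameter groups, one with weight decay and one without. The sketch below performs that split on parameter names only; the names and the decay value are hypothetical, entries are assumed to match names exactly, and in real code the names would be paired with tensors from `model.named_parameters()` before being passed to an optimizer.

```python
from typing import Dict, Iterable, List, Set


def split_weight_decay(
    param_names: Iterable[str], no_decay: Set[str], weight_decay: float = 0.05
) -> List[Dict[str, object]]:
    """Build two optimizer-style parameter groups from parameter names."""
    decay: List[str] = []
    skip: List[str] = []
    for name in param_names:
        # Exact-name match against the no-weight-decay set (an assumption).
        (skip if name in no_decay else decay).append(name)
    return [
        {"params": skip, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


groups = split_weight_decay(
    ["cls_token", "pos_embed", "head.weight"], no_decay={"cls_token", "pos_embed"}
)
print(groups)
```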
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
- Parameters:
indices – Indices of intermediate layers to keep.
prune_norm – Whether to prune normalization layer.
prune_head – Whether to prune classification head.
- Returns:
List of kept intermediate indices.
- reset_classifier(num_classes: int, global_pool: str | None = None) → None
Reset classifier head.
- Parameters:
num_classes – Number of classes for new classifier.
global_pool – Global pooling type.
- set_grad_checkpointing(enable: bool = True) → None
Set gradient checkpointing.
- Parameters:
enable – Whether to enable gradient checkpointing.
- class timm.models.vovnet.VovNet(
- cfg: dict,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'avg',
- output_stride: int = 32,
- norm_layer: Type[Module] = <class 'timm.layers.norm_act.BatchNormAct2d'>,
- act_layer: Type[Module] = <class 'torch.nn.modules.activation.ReLU'>,
- drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- device=None,
- dtype=None,
- **kwargs,
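`output_stride` is the total downsampling factor between the input image and the final feature map. For an input side that divides evenly, the final map side is the input side divided by the stride; the hypothetical helper below shows only that arithmetic and ignores details such as padding behaviour at sizes that do not divide evenly.

```python
def feature_map_side(input_side: int, output_stride: int = 32) -> int:
    """Side of the final feature map for an input evenly divisible by output_stride."""
    if input_side % output_stride != 0:
        raise ValueError("this sketch assumes input_side is a multiple of output_stride")
    return input_side // output_stride


for side in (224, 256, 512):
    print(side, "->", feature_map_side(side))  # 224 -> 7, 256 -> 8, 512 -> 16
```

With `global_pool='avg'`, that final map is then averaged down to a single vector per channel before the classifier.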
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Run the forward pass through the feature layers, returning intermediate outputs.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to compatible intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.
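The `output_fmt` argument sets the layout of the returned intermediates: 'NCHW' places channels before the spatial dimensions, while 'NHWC' (the format the SAM patch embedding above is configured with) places them last. Converting between the two is a fixed axis permutation. The sketch below applies it to shape tuples only; with PyTorch tensors the NHWC to NCHW equivalent is `x.permute(0, 3, 1, 2)`. The helper and table names are hypothetical.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]

# Axis order to apply when converting between layouts (N=batch, C=channel, H/W=spatial).
PERMUTE = {
    ("NHWC", "NCHW"): (0, 3, 1, 2),
    ("NCHW", "NHWC"): (0, 2, 3, 1),
}


def convert_shape(shape: Shape, src: str, dst: str) -> Shape:
    """Reorder a 4-D shape from layout src to layout dst."""
    if src == dst:
        return shape
    order = PERMUTE[(src, dst)]
    return tuple(shape[i] for i in order)  # type: ignore[return-value]


print(convert_shape((8, 14, 14, 768), "NHWC", "NCHW"))  # (8, 768, 14, 14)
```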
- class timm.models.xception.Xception(
- num_classes: int = 1000,
- in_chans: int = 3,
- drop_rate: float = 0.0,
- global_pool: str = 'avg',
- device=None,
- dtype=None,
Xception optimized for the ImageNet dataset, as specified in https://arxiv.org/pdf/1610.02357.pdf
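The Xception paper linked above builds the network from depthwise separable convolutions: a k × k depthwise convolution applied to each input channel separately, followed by a 1 × 1 pointwise convolution that mixes channels. The back-of-the-envelope arithmetic below (biases ignored) shows why this factorisation is cheaper than a standard convolution, using 728 channels, the width of Xception's middle flow. It is not a parameter count for timm's `Xception` class, and the helper names are hypothetical.

```python
def conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights of a standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out


def separable_conv_params(k: int, c_in: int, c_out: int) -> int:
    """Depthwise k x k (one filter per input channel) plus pointwise 1 x 1."""
    return k * k * c_in + c_in * c_out


std = conv_params(3, 728, 728)
sep = separable_conv_params(3, 728, 728)
print(std, sep, round(std / sep, 2))  # 4769856 536536 8.89
```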
- class timm.models.xception_aligned.XceptionAligned(block_cfg: ~typing.List[~typing.Dict], num_classes: int = 1000, in_chans: int = 3, output_stride: int = 32, preact: bool = False, act_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.ReLU'>, norm_layer: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, drop_rate: float = 0.0, drop_path_rate: float = 0.0, global_pool: str = 'avg', device=None, dtype=None)
Modified Aligned Xception
- class timm.models.xcit.Xcit(
- img_size: int | Tuple[int, int] = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- num_classes: int = 1000,
- global_pool: str = 'token',
- embed_dim: int = 768,
- depth: int = 12,
- num_heads: int = 12,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- drop_rate: float = 0.0,
- pos_drop_rate: float = 0.0,
- proj_drop_rate: float = 0.0,
- attn_drop_rate: float = 0.0,
- drop_path_rate: float = 0.0,
- act_layer: Type[Module] | None = None,
- norm_layer: Type[Module] | None = None,
- cls_attn_layers: int = 2,
- use_pos_embed: bool = True,
- eta: float = 1.0,
- tokens_norm: bool = False,
- device=None,
- dtype=None,
Based on the timm and DeiT code bases: https://github.com/rwightman/pytorch-image-models/tree/master/timm and https://github.com/facebookresearch/deit/
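XCiT (cross-covariance image transformer) replaces token-to-token self-attention with attention computed across feature channels, so each head's attention map is head_dim × head_dim rather than N × N. That detail comes from the XCiT paper, not from the signature above. The arithmetic below uses the default `img_size=224`, `patch_size=16`, `embed_dim=768` and `num_heads=12`, assumes non-overlapping patches, and compares map sizes at two resolutions; the helper names are hypothetical.

```python
def num_patches(img_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch tokens for a square image."""
    side = img_size // patch_size
    return side * side


def attention_map_entries(n_tokens: int, embed_dim: int, num_heads: int) -> dict:
    """Per-head map sizes: token attention (N x N) vs cross-covariance (d x d)."""
    head_dim = embed_dim // num_heads
    return {"token_attention": n_tokens**2, "cross_covariance": head_dim**2}


for img in (224, 448):
    n = num_patches(img, 16)
    print(img, n, attention_map_entries(n, embed_dim=768, num_heads=12))
# 224 196 {'token_attention': 38416, 'cross_covariance': 4096}
# 448 784 {'token_attention': 614656, 'cross_covariance': 4096}
```

Doubling the image side quadruples N and grows the token-attention map sixteenfold, while the cross-covariance map stays fixed at 64 × 64 per head.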
- forward_intermediates(
- x: Tensor,
- indices: int | List[int] | None = None,
- norm: bool = False,
- stop_early: bool = False,
- output_fmt: str = 'NCHW',
- intermediates_only: bool = False,
Run the forward pass through the feature layers, returning intermediate outputs.
- Parameters:
x – Input image tensor
indices – Take last n blocks if int, all if None, select matching indices if sequence
norm – Apply norm layer to all intermediates
stop_early – Stop iterating over blocks when last desired intermediate hit
output_fmt – Shape of intermediate feature outputs
intermediates_only – Only return intermediate features
Returns:
- prune_intermediate_layers(
- indices: int | List[int] = 1,
- prune_norm: bool = False,
- prune_head: bool = True,
Prune layers not required for specified intermediates.