diff --git a/tests/test_models.py b/tests/test_models.py index 7f696dc1..34bf0af4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -49,8 +49,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper FEAT_INTER_FILTERS = [ - 'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', - 'cait_*', 'xcit_*', 'volo_*', + 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', + 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', + 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. @@ -388,13 +389,12 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_intermediates_features(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode""" - model = create_model(model_name, pretrained=False, features_only=True) + model = create_model(model_name, pretrained=False, features_only=True, feature_cls='getter') model.eval() - print(model.feature_info.out_indices) expected_channels = model.feature_info.channels() expected_reduction = model.feature_info.reduction() @@ -420,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_intermediates(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode""" @@ -429,18 +429,19 @@ def test_model_forward_intermediates(model_name, batch_size): feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info)) expected_channels = feature_info.channels() expected_reduction = feature_info.reduction() - assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 + assert len(expected_channels) >= 3 # all models here should have at least 3 feature levels input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) if max(input_size) > MAX_FFEAT_SIZE: pytest.skip("Fixed input size model > limit.") - output_fmt = getattr(model, 'output_fmt', 'NCHW') + output_fmt = 'NCHW' # NOTE output_fmt determined by forward_intermediates() arg, not model attribute feat_axis = get_channel_dim(output_fmt) spatial_axis = get_spatial_dim(output_fmt) import math output, intermediates = model.forward_intermediates( torch.randn((batch_size, *input_size)), + output_fmt=output_fmt, ) assert len(expected_channels) == len(intermediates) spatial_size = input_size[-2:] diff --git a/timm/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py index 16af4afd..d0dd58d9 100644 --- a/timm/layers/adaptive_avgmax_pool.py +++ 
b/timm/layers/adaptive_avgmax_pool.py @@ -134,6 +134,7 @@ class SelectAdaptivePool2d(nn.Module): super(SelectAdaptivePool2d, self).__init__() assert input_fmt in ('NCHW', 'NHWC') self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + pool_type = pool_type.lower() if not pool_type: self.pool = nn.Identity() # pass through self.flatten = nn.Flatten(1) if flatten else nn.Identity() @@ -145,8 +146,10 @@ class SelectAdaptivePool2d(nn.Module): self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt) elif pool_type.endswith('max'): self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt) - else: + elif pool_type == 'fast' or pool_type.endswith('avg'): self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt) + else: + assert False, 'Invalid pool type: %s' % pool_type self.flatten = nn.Identity() else: assert input_fmt == 'NCHW' @@ -156,8 +159,10 @@ class SelectAdaptivePool2d(nn.Module): self.pool = AdaptiveCatAvgMaxPool2d(output_size) elif pool_type == 'max': self.pool = nn.AdaptiveMaxPool2d(output_size) - else: + elif pool_type == 'avg': self.pool = nn.AdaptiveAvgPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type self.flatten = nn.Flatten(1) if flatten else nn.Identity() def is_identity(self): diff --git a/timm/models/_builder.py b/timm/models/_builder.py index c1ad5c2d..f248fbd3 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -2,7 +2,7 @@ import dataclasses import logging import os from copy import deepcopy -from typing import Optional, Dict, Callable, Any, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from torch import nn as nn from torch.hub import load_state_dict_from_url @@ -359,15 +359,15 @@ def build_model_with_cfg( * pruning config / model adaptation Args: - model_cls (nn.Module): model class - variant (str): model variant name - pretrained (bool): load pretrained weights - pretrained_cfg (dict): model's pretrained weight/task config - model_cfg (Optional[Dict]): model's architecture config - feature_cfg (Optional[Dict]: feature extraction adapter config - pretrained_strict (bool): load pretrained weights strictly - pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights - kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + model_cls: model class + variant: model variant name + pretrained: load pretrained weights + pretrained_cfg: model's pretrained weight/task config + model_cfg: model's architecture config + feature_cfg: feature extraction adapter config + pretrained_strict: load pretrained weights strictly + pretrained_filter_fn: filter callable for pretrained weights + kwargs_filter: kwargs to filter before passing to model **kwargs: model args passed through to model __init__ """ pruned = kwargs.pop('pruned', False) @@ -392,6 +392,8 @@ def build_model_with_cfg( feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) if 'out_indices' in kwargs: feature_cfg['out_indices'] = kwargs.pop('out_indices') + if 'feature_cls' in kwargs: + feature_cfg['feature_cls'] = kwargs.pop('feature_cls') # Instantiate the model if model_cfg is None: @@ -418,24 +420,36 @@ def build_model_with_cfg( # Wrap the model in a feature extraction module if enabled if features: - feature_cls = FeatureListNet - output_fmt = getattr(model, 'output_fmt', None) - if output_fmt is not None: - feature_cfg.setdefault('output_fmt', output_fmt) + use_getter = False if 'feature_cls' in feature_cfg: feature_cls = 
feature_cfg.pop('feature_cls') if isinstance(feature_cls, str): feature_cls = feature_cls.lower() + + # flatten_sequential only valid for some feature extractors + if feature_cls not in ('dict', 'list', 'hook'): + feature_cfg.pop('flatten_sequential', None) + if 'hook' in feature_cls: feature_cls = FeatureHookNet + elif feature_cls == 'list': + feature_cls = FeatureListNet elif feature_cls == 'dict': feature_cls = FeatureDictNet elif feature_cls == 'fx': feature_cls = FeatureGraphNet elif feature_cls == 'getter': + use_getter = True feature_cls = FeatureGetterNet else: assert False, f'Unknown feature class {feature_cls}' + else: + feature_cls = FeatureListNet + + output_fmt = getattr(model, 'output_fmt', None) + if output_fmt is not None and not use_getter: # don't set default for intermediate feat getter + feature_cfg.setdefault('output_fmt', output_fmt) + model = feature_cls(model, **feature_cfg) model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg) diff --git a/timm/models/_features.py b/timm/models/_features.py index 9dbac1cd..5bd6f1fb 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -363,7 +363,7 @@ class FeatureHookNet(nn.ModuleDict): out_map: Optional[Sequence[Union[int, str]]] = None, return_dict: bool = False, output_fmt: str = 'NCHW', - no_rewrite: bool = False, + no_rewrite: Optional[bool] = None, flatten_sequential: bool = False, default_hook_type: str = 'forward', ): @@ -385,7 +385,8 @@ class FeatureHookNet(nn.ModuleDict): self.return_dict = return_dict self.output_fmt = Format(output_fmt) self.grad_checkpointing = False - + if no_rewrite is None: + no_rewrite = not flatten_sequential layers = OrderedDict() hooks = [] if no_rewrite: @@ -467,7 +468,7 @@ class FeatureGetterNet(nn.ModuleDict): self.out_indices = out_indices self.out_map = out_map self.return_dict = return_dict - self.output_fmt = output_fmt + self.output_fmt = Format(output_fmt) self.norm = norm def forward(self, x): diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index e67d1f25..3b5891e6 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -15,7 +15,7 @@ except ImportError: has_fx_feature_extraction = False # Layers we went to treat as leaf modules -from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame +from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format from timm.layers.non_local_attn import BilinearAttnTransform from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame from timm.layers.norm_act import ( @@ -108,12 +108,14 @@ class FeatureGraphNet(nn.Module): model: nn.Module, out_indices: Tuple[int, ...], out_map: Optional[Dict] = None, + output_fmt: str = 'NCHW', ): super().__init__() assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) + self.output_fmt = Format(output_fmt) return_nodes = _get_return_layers(self.feature_info, out_map) self.graph_module = create_feature_extractor(model, return_nodes) diff --git a/timm/models/_registry.py b/timm/models/_registry.py index a129e1af..fde8bac7 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -184,7 +184,7 @@ def _expand_filter(filter: str): def list_models( filter: Union[str, 
List[str]] = '', - module: str = '', + module: Union[str, List[str]] = '', pretrained: bool = False, exclude_filters: Union[str, List[str]] = '', name_matches_cfg: bool = False, @@ -217,7 +217,16 @@ def list_models( # FIXME should this be default behaviour? or default to include_tags=True? include_tags = pretrained - all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys()) + if not module: + all_models: Set[str] = set(_model_entrypoints.keys()) + else: + if isinstance(module, str): + all_models: Set[str] = _module_to_models[module] + else: + assert isinstance(module, Sequence) + all_models: Set[str] = set() + for m in module: + all_models.update(_module_to_models[m]) all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings if include_tags: diff --git a/timm/models/beit.py b/timm/models/beit.py index 19bf2c58..43048285 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -407,7 +407,7 @@ class Beit(nn.Module): indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -424,7 +424,7 @@ class Beit(nn.Module): Returns: """ - assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) @@ -436,6 +436,7 @@ class Beit(nn.Module): if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks @@ -469,19 +470,19 @@ class Beit(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.fc_norm = nn.Identity() - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x): diff --git a/timm/models/cait.py b/timm/models/cait.py index 50148405..bf649076 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -343,7 +343,7 @@ class Cait(nn.Module): x: torch.Tensor, indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -357,7 +357,7 @@ class Cait(nn.Module): output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features """ - assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' 
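# Illustrative sketch, not part of the patch: how the 'NCHW' branch below folds NLC patch tokens
# back onto the patch grid (shapes assumed for a 16x16-patch model at 256x256 input).
import torch
B, H, W, C = 2, 16, 16, 384
tokens = torch.randn(B, H * W, C)                                   # NLC, prefix tokens already removed
fmap = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()  # -> (B, C, H, W) feature map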
reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) @@ -367,6 +367,7 @@ class Cait(nn.Module): x = self.patch_embed(x) x = x + self.pos_embed x = self.pos_drop(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks else: @@ -397,19 +398,19 @@ class Cait(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.blocks_token_only = nn.ModuleList() # prune token blocks with head - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index ce7fd20b..76bb8136 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -39,7 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W from collections import OrderedDict from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -49,6 +49,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -407,6 +408,71 @@ class ConvNeXt(nn.Module): def reset_classifier(self, num_classes=0, global_pool=None): self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
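# Illustrative sketch, not part of the patch: the index selection assumed from the docstring above
# ("last n if int, all if None, matching indices if sequence"), with stem + 4 stages = 5 feature levels.
num_levels = 5
def _take(indices):
    if indices is None:
        return list(range(num_levels))
    if isinstance(indices, int):
        return list(range(num_levels - indices, num_levels))  # take the last n feature levels
    return list(indices)                                      # take explicitly listed levels
assert _take(2) == [3, 4]
assert _take((0, 3)) == [0, 3]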
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + + # forward pass + feat_idx = 0 # stem is index 0 + x = self.stem(x) + if feat_idx in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + for stage in stages: + feat_idx += 1 + x = stage(x) + if feat_idx in take_indices: + # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.norm_pre(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm_pre = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 727fcac3..fb6ff8ec 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -12,7 +12,7 @@ Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientF Modifications and timm support by / Copyright 2022, Ross Wightman """ -from typing import Dict +from typing import Dict, List, Tuple, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -382,16 +383,19 @@ class EfficientFormer(nn.Module): prev_dim = embed_dims[0] # stochastic depth decay rule + self.num_stages = len(depths) + last_stage = self.num_stages - 1 dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] - downsamples = downsamples or (False,) + (True,) * (len(depths) - 1) + downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1) stages = [] - for i in range(len(depths)): + self.feature_info = [] + for i in range(self.num_stages): stage = EfficientFormerStage( prev_dim, embed_dims[i], depths[i], downsample=downsamples[i], - num_vit=num_vit if i == 3 else 0, + num_vit=num_vit if i == last_stage else 0, pool_size=pool_size, mlp_ratio=mlp_ratios, act_layer=act_layer, @@ -403,7 +407,7 @@ class EfficientFormer(nn.Module): ) prev_dim = embed_dims[i] stages.append(stage) - + self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')] self.stages = nn.Sequential(*stages) # Classifier head @@ -456,6 +460,76 @@ class EfficientFormer(nn.Module): def set_distilled_training(self, enable=True): self.distilled_training = enable + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns 
intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + B, C, H, W = x.shape + + last_idx = self.num_stages - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + feat_idx = 0 + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx < last_idx: + B, C, H, W = x.shape + if feat_idx in take_indices: + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2)) + else: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) @@ -534,13 +608,13 @@ default_cfgs = generate_default_cfgs({ def _create_efficientformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for EfficientFormer models.') - + out_indices = kwargs.pop('out_indices', 4) model = build_model_with_cfg( EfficientFormer, variant, pretrained, pretrained_filter_fn=_checkpoint_filter_fn, - **kwargs) + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 6e61d1bf..fb04fb2c 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -36,7 +36,7 @@ the models and weights open source! 
Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import List +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -49,7 +49,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from ._features import FeatureInfo, FeatureHooks +from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -118,6 +118,7 @@ class EfficientNet(nn.Module): ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features + self.stage_ends = [f['stage'] for f in self.feature_info] head_chs = builder.in_chs # Head + Pooling @@ -158,6 +159,86 @@ class EfficientNet(nn.Module): self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + extra_blocks: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + if stop_early: + assert intermediates_only, 'Must use intermediates_only for early stopping.' + intermediates = [] + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + # forward pass + feat_idx = 0 # stem is index 0 + x = self.conv_stem(x) + x = self.bn1(x) + if feat_idx in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index] + for blk in blocks: + feat_idx += 1 + x = blk(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.conv_head(x) + x = self.bn2(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + extra_blocks: bool = False, + ): + """ Prune layers not required for specified intermediates. 
+ """ + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm or max_index < len(self.blocks): + self.conv_head = nn.Identity() + self.bn2 = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) @@ -272,7 +353,7 @@ def _create_effnet(variant, pretrained=False, **kwargs): model_cls = EfficientNet kwargs_filter = None if kwargs.pop('features_only', False): - if 'feature_cfg' in kwargs: + if 'feature_cfg' in kwargs or 'feature_cls' in kwargs: features_mode = 'cfg' else: kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') diff --git a/timm/models/eva.py b/timm/models/eva.py index 416bc951..15702cf2 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -53,6 +53,7 @@ class EvaAttention(nn.Module): num_heads: int = 8, qkv_bias: bool = True, qkv_fused: bool = True, + num_prefix_tokens: int = 1, attn_drop: float = 0., proj_drop: float = 0., attn_head_dim: Optional[int] = None, @@ -77,6 +78,7 @@ class EvaAttention(nn.Module): head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = head_dim ** -0.5 + self.num_prefix_tokens = num_prefix_tokens self.fused_attn = use_fused_attn() if qkv_fused: @@ -119,8 +121,9 @@ class EvaAttention(nn.Module): v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) if rope is not None: - q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v) - k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v) + npt = self.num_prefix_tokens + q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v) + k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v) if self.fused_attn: x = F.scaled_dot_product_attention( @@ -157,6 +160,7 @@ class EvaBlock(nn.Module): swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, + num_prefix_tokens: int = 1, proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., @@ -191,6 +195,7 @@ class EvaBlock(nn.Module): num_heads=num_heads, qkv_bias=qkv_bias, qkv_fused=qkv_fused, + num_prefix_tokens=num_prefix_tokens, attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, @@ -253,6 +258,7 @@ class EvaBlockPostNorm(nn.Module): swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, + num_prefix_tokens: int = 1, proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., @@ -286,6 +292,7 @@ class EvaBlockPostNorm(nn.Module): num_heads=num_heads, qkv_bias=qkv_bias, qkv_fused=qkv_fused, + num_prefix_tokens=num_prefix_tokens, attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, @@ -364,6 +371,7 @@ class Eva(nn.Module): norm_layer: Callable = LayerNorm, init_values: Optional[float] = None, class_token: bool = True, + num_reg_tokens: int = 0, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, @@ -407,7 +415,7 @@ class Eva(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_prefix_tokens = 1 if class_token else 0 + 
self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False @@ -427,6 +435,8 @@ class Eva(nn.Module): r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None + self.cls_embed = class_token and self.reg_token is None self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None @@ -463,6 +473,7 @@ class Eva(nn.Module): swiglu_mlp=swiglu_mlp, scale_mlp=scale_mlp, scale_attn_inner=scale_attn_inner, + num_prefix_tokens=self.num_prefix_tokens, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], @@ -484,6 +495,8 @@ class Eva(nn.Module): trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: trunc_normal_(self.cls_token, std=.02) + if self.reg_token is not None: + trunc_normal_(self.reg_token, std=.02) self.fix_init_weight() if isinstance(self.head, nn.Linear): @@ -551,8 +564,17 @@ class Eva(nn.Module): if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if pos_embed is not None: x = x + pos_embed + + if self.reg_token is not None: + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + x = torch.cat(to_cat + [x], dim=1) + x = self.pos_drop(x) # obtain shared rotary position embedding and apply patch dropout @@ -568,7 +590,7 @@ class Eva(nn.Module): indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -622,19 +644,19 @@ class Eva(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. 
""" - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.fc_norm = nn.Identity() - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x): @@ -695,6 +717,12 @@ def checkpoint_filter_fn( # fixed embedding no need to load buffer from checkpoint continue + # FIXME here while import new weights, to remove + # if k == 'cls_token': + # print('DEBUG: cls token -> reg') + # k = 'reg_token' + # #v = v + state_dict['pos_embed'][0, :] + if 'patch_embed.proj.weight' in k: _, _, H, W = model.patch_embed.proj.weight.shape if v.shape[-1] != W or v.shape[-2] != H: @@ -923,6 +951,29 @@ default_cfgs = generate_default_cfgs({ num_classes=0, ), + 'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg( + #hf_hub_id='timm/', + file='vit_medium_gap1_rope-in1k-20230920-5.pth', + input_size=(3, 256, 256), crop_pct=0.95, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) + ), + 'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg( + #hf_hub_id='timm/', + file='vit_mediumd_gap1_rope-in1k-20230926-5.pth', + input_size=(3, 256, 256), crop_pct=0.95, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) + ), + 'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg( + #hf_hub_id='timm/', + file='vit_betwixt_gap4_rope-in1k-20231005-5.pth', + input_size=(3, 256, 256), crop_pct=0.95, + ), + 'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg( + # hf_hub_id='timm/', + file='vit_base_gap1_rope-in1k-20230930-5.pth', + input_size=(3, 256, 256), crop_pct=0.95, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) + ), }) @@ -1185,3 +1236,87 @@ def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva: ) model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + + +@register_model +def vit_medium_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: + model_args = dict( + img_size=256, + patch_size=16, + embed_dim=512, + depth=12, + num_heads=8, + qkv_fused=True, + qkv_bias=True, + init_values=1e-5, + class_token=False, + num_reg_tokens=1, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('vit_medium_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_mediumd_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: + model_args = dict( + img_size=256, + patch_size=16, + embed_dim=512, + depth=20, + num_heads=8, + qkv_fused=True, + qkv_bias=False, + init_values=1e-5, + class_token=False, + num_reg_tokens=1, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('vit_mediumd_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch16_rope_reg4_gap_256(pretrained=False, **kwargs) -> Eva: + model_args = dict( + img_size=256, + patch_size=16, + embed_dim=640, + depth=12, + num_heads=10, + qkv_fused=True, + qkv_bias=True, + init_values=1e-5, + class_token=False, + num_reg_tokens=4, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('vit_betwixt_patch16_rope_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def 
vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva: + model_args = dict( + img_size=256, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + qkv_fused=True, + qkv_bias=True, + init_values=1e-5, + class_token=False, + num_reg_tokens=1, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + ref_feat_shape=(16, 16), # 224/14 + ) + model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model diff --git a/timm/models/levit.py b/timm/models/levit.py index ca0708bd..023f131b 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -25,7 +25,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Copyright 2020 Ross Wightman, Apache-2.0 License from collections import OrderedDict from functools import partial -from typing import Dict +from typing import Dict, List, Tuple, Union import torch import torch.nn as nn @@ -33,6 +33,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -634,6 +635,70 @@ class Levit(nn.Module): self.head = NormLinear( self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + B, C, H, W = x.shape + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if self.use_conv: + intermediates.append(x) + else: + intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2)) + H = (H + 2 - 1) // 2 + W = (W + 2 - 1) // 2 + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if not self.use_conv: @@ -746,9 +811,8 @@ model_cfgs = dict( def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs): is_conv = '_conv' in variant out_indices = kwargs.pop('out_indices', (0, 1, 2)) - if kwargs.get('features_only', None): - if not is_conv: - raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.') + if kwargs.get('features_only', False) and not is_conv: + kwargs.setdefault('feature_cls', 'getter') if cfg_variant is None: if variant in model_cfgs: cfg_variant = variant diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3019fecc..86eed72a 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -50,6 +50,7 @@ from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -1251,6 +1252,75 @@ class MaxxVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + + # forward pass + feat_idx = 0 # stem is index 0 + x = self.stem(x) + if feat_idx in take_indices: + intermediates.append(x) + + last_idx = len(self.stages) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + for stage in stages: + feat_idx += 1 + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.head = self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 2d197a9d..3976f0db 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,7 +7,7 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,7 +20,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from ._features import FeatureInfo, FeatureHooks +from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -109,6 +109,7 @@ class MobileNetV3(nn.Module): ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features + self.stage_ends = [f['stage'] for f in self.feature_info] head_chs = builder.in_chs # Head + Pooling @@ -150,6 +151,84 @@ class MobileNetV3(nn.Module): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + extra_blocks: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + if stop_early: + assert intermediates_only, 'Must use intermediates_only for early stopping.' 
+ intermediates = [] + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + + # forward pass + feat_idx = 0 # stem is index 0 + x = self.conv_stem(x) + x = self.bn1(x) + if feat_idx in take_indices: + intermediates.append(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index] + for blk in blocks: + feat_idx += 1 + x = blk(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + extra_blocks: bool = False, + ): + """ Prune layers not required for specified intermediates. + """ + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 + if max_index < len(self.blocks): + self.conv_head = nn.Identity() + if prune_head: + self.conv_head = nn.Identity() + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_stem(x) x = self.bn1(x) @@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV model_cls = MobileNetV3 kwargs_filter = None if kwargs.pop('features_only', False): - if 'feature_cfg' in kwargs: + if 'feature_cfg' in kwargs or 'feature_cls' in kwargs: features_mode = 'cfg' else: kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 5d146468..c8afe470 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -839,7 +839,7 @@ class MultiScaleVit(nn.Module): x: torch.Tensor, indices: Union[int, List[int], Tuple[int]] = None, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -855,13 +855,12 @@ class MultiScaleVit(nn.Module): Returns: """ - assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.' + assert output_fmt in ('NCHW', 'NLC'), 'Output shape must be NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.stages), indices) # FIXME slice block/pos_block if < max - # forward pass x, feat_size = self.patch_embed(x) B = x.shape[0] @@ -870,6 +869,7 @@ class MultiScaleVit(nn.Module): x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: x = x + self.pos_embed + for i, stage in enumerate(self.stages): x, feat_size = stage(x, feat_size) if i in take_indices: @@ -891,6 +891,23 @@ class MultiScaleVit(nn.Module): return x, intermediates + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + # FIXME add stage pruning + # self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x, feat_size = self.patch_embed(x) B, N, C = x.shape diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 69e28946..a2e303ca 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ get_attn, get_act_layer, get_norm_layer, create_classifier from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -295,8 +296,8 @@ def drop_blocks(drop_prob: float = 0.): def make_blocks( block_fn: Union[BasicBlock, Bottleneck], - channels: List[int], - block_repeats: List[int], + channels: Tuple[int, ...], + block_repeats: Tuple[int, ...], inplanes: int, reduce_first: int = 1, output_stride: int = 32, @@ -394,7 +395,7 @@ class ResNet(nn.Module): def __init__( self, block: Union[BasicBlock, Bottleneck], - layers: List[int], + layers: Tuple[int, ...], num_classes: int = 1000, in_chans: int = 3, output_stride: int = 32, @@ -497,7 +498,7 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks - channels = [64, 128, 256, 512] + channels = (64, 128, 256, 512) stage_modules, stage_feature_info = make_blocks( block, channels, @@ -553,6 +554,73 @@ class ResNet(nn.Module): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + if stop_early: + assert intermediates_only, 'Must use intermediates_only for early stopping.' 
+ intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + if feat_idx in take_indices: + intermediates.append(x) + x = self.maxpool(x) + + layer_names = ('layer1', 'layer2', 'layer3', 'layer4') + if stop_early: + layer_names = layer_names[:max_index] + for n in layer_names: + feat_idx += 1 + x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + layer_names = ('layer1', 'layer2', 'layer3', 'layer4') + layer_names = layer_names[max_index:] + for n in layer_names: + setattr(self, n, nn.Identity()) + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) x = self.bn1(x) @@ -1246,7 +1314,7 @@ default_cfgs = generate_default_cfgs({ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-10-T model. """ - model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) + model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True) return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @@ -1254,7 +1322,7 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet: def resnet14t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-14-T model. """ - model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) + model_args = dict(block=Bottleneck, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True) return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs)) @@ -1262,7 +1330,7 @@ def resnet14t(pretrained: bool = False, **kwargs) -> ResNet: def resnet18(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-18 model. """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) + model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2)) return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs)) @@ -1270,7 +1338,7 @@ def resnet18(pretrained: bool = False, **kwargs) -> ResNet: def resnet18d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-18-D model. """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs)) @@ -1278,7 +1346,7 @@ def resnet18d(pretrained: bool = False, **kwargs) -> ResNet: def resnet34(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-34 model. 
""" - model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3]) + model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3)) return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs)) @@ -1286,7 +1354,7 @@ def resnet34(pretrained: bool = False, **kwargs) -> ResNet: def resnet34d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-34-D model. """ - model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs)) @@ -1294,7 +1362,7 @@ def resnet34d(pretrained: bool = False, **kwargs) -> ResNet: def resnet26(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-26 model. """ - model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2]) + model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2)) return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs)) @@ -1302,7 +1370,7 @@ def resnet26(pretrained: bool = False, **kwargs) -> ResNet: def resnet26t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-26-T model. """ - model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True) + model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep_tiered', avg_down=True) return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs)) @@ -1310,7 +1378,7 @@ def resnet26t(pretrained: bool = False, **kwargs) -> ResNet: def resnet26d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-26-D model. """ - model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs)) @@ -1318,7 +1386,7 @@ def resnet26d(pretrained: bool = False, **kwargs) -> ResNet: def resnet50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3]) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3)) return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs)) @@ -1326,7 +1394,7 @@ def resnet50(pretrained: bool = False, **kwargs) -> ResNet: def resnet50c(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-C model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep') + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep') return _create_resnet('resnet50c', pretrained, **dict(model_args, **kwargs)) @@ -1334,7 +1402,7 @@ def resnet50c(pretrained: bool = False, **kwargs) -> ResNet: def resnet50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs)) @@ -1342,7 +1410,7 @@ def resnet50d(pretrained: bool = False, **kwargs) -> ResNet: def resnet50s(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-S model. 
""" - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep') + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=64, stem_type='deep') return _create_resnet('resnet50s', pretrained, **dict(model_args, **kwargs)) @@ -1350,7 +1418,7 @@ def resnet50s(pretrained: bool = False, **kwargs) -> ResNet: def resnet50t(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-T model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True) return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs)) @@ -1358,7 +1426,7 @@ def resnet50t(pretrained: bool = False, **kwargs) -> ResNet: def resnet101(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101 model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3]) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3)) return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs)) @@ -1366,7 +1434,7 @@ def resnet101(pretrained: bool = False, **kwargs) -> ResNet: def resnet101c(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-C model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep') + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep') return _create_resnet('resnet101c', pretrained, **dict(model_args, **kwargs)) @@ -1374,7 +1442,7 @@ def resnet101c(pretrained: bool = False, **kwargs) -> ResNet: def resnet101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs)) @@ -1382,7 +1450,7 @@ def resnet101d(pretrained: bool = False, **kwargs) -> ResNet: def resnet101s(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-S model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep') + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=64, stem_type='deep') return _create_resnet('resnet101s', pretrained, **dict(model_args, **kwargs)) @@ -1390,7 +1458,7 @@ def resnet101s(pretrained: bool = False, **kwargs) -> ResNet: def resnet152(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152 model. """ - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3]) + model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3)) return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs)) @@ -1398,7 +1466,7 @@ def resnet152(pretrained: bool = False, **kwargs) -> ResNet: def resnet152c(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152-C model. """ - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep') + model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep') return _create_resnet('resnet152c', pretrained, **dict(model_args, **kwargs)) @@ -1406,7 +1474,7 @@ def resnet152c(pretrained: bool = False, **kwargs) -> ResNet: def resnet152d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152-D model. 
""" - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs)) @@ -1414,7 +1482,7 @@ def resnet152d(pretrained: bool = False, **kwargs) -> ResNet: def resnet152s(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-152-S model. """ - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep') + model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=64, stem_type='deep') return _create_resnet('resnet152s', pretrained, **dict(model_args, **kwargs)) @@ -1422,7 +1490,7 @@ def resnet152s(pretrained: bool = False, **kwargs) -> ResNet: def resnet200(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200 model. """ - model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3]) + model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3)) return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs)) @@ -1430,7 +1498,7 @@ def resnet200(pretrained: bool = False, **kwargs) -> ResNet: def resnet200d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200-D model. """ - model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True) + model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs)) @@ -1442,7 +1510,7 @@ def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet: convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), base_width=128) return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs)) @@ -1453,7 +1521,7 @@ def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet: which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), base_width=128) return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs)) @@ -1461,7 +1529,7 @@ def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet: def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model w/ GroupNorm """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], norm_layer='groupnorm') + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), norm_layer='groupnorm') return _create_resnet('resnet50_gn', pretrained, **dict(model_args, **kwargs)) @@ -1469,7 +1537,7 @@ def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet: def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt50-32x4d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4) return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1478,7 +1546,7 @@ def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt50d-32x4d model. 
ResNext50 w/ deep stem & avg pool downsample """ model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1487,7 +1555,7 @@ def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x4d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4) return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1495,7 +1563,7 @@ def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x8d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8) return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @@ -1503,7 +1571,7 @@ def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x16d model """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=16) return _create_resnet('resnext101_32x16d', pretrained, **dict(model_args, **kwargs)) @@ -1511,7 +1579,7 @@ def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet: def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt-101 32x32d model """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=32) return _create_resnet('resnext101_32x32d', pretrained, **dict(model_args, **kwargs)) @@ -1519,7 +1587,7 @@ def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet: def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNeXt101-64x4d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4) return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs)) @@ -1530,7 +1598,7 @@ def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet: in the deep stem and ECA attn. """ model_args = dict( - block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, + block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs)) @@ -1540,7 +1608,7 @@ def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model with eca. 
""" model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs)) @@ -1551,7 +1619,7 @@ def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet: The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs)) @@ -1562,7 +1630,7 @@ def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet: Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn. """ model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, + block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs)) @@ -1572,7 +1640,7 @@ def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D light model with eca. """ model_args = dict( - block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, + block=Bottleneck, layers=(1, 1, 11, 3), stem_width=32, avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs)) @@ -1582,7 +1650,7 @@ def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model with eca. """ model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs)) @@ -1593,7 +1661,7 @@ def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet: The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs)) @@ -1603,7 +1671,7 @@ def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200-D model with ECA. """ model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs)) @@ -1613,7 +1681,7 @@ def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-269-D model with ECA. 
""" model_args = dict( - block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, + block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs)) @@ -1625,7 +1693,7 @@ def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: in the deep stem. This model replaces SE module with the ECA module """ model_args = dict( - block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1637,53 +1705,53 @@ def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: in the deep stem. This model replaces SE module with the ECA module """ model_args = dict( - block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet18(pretrained: bool = False, **kwargs) -> ResNet: - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se')) + model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), block_args=dict(attn_layer='se')) return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet34(pretrained: bool = False, **kwargs) -> ResNet: - model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se')) + model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se')) return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet50(pretrained: bool = False, **kwargs) -> ResNet: - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se')) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se')) return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', + block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet101(pretrained: bool = False, **kwargs) -> ResNet: - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se')) + model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), block_args=dict(attn_layer='se')) return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet152(pretrained: bool = False, **kwargs) -> ResNet: - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se')) + model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), block_args=dict(attn_layer='se')) return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet: 
model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', + block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs)) @@ -1693,7 +1761,7 @@ def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-200-D model with SE attn. """ model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', + block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs)) @@ -1703,7 +1771,7 @@ def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-269-D model with SE attn. """ model_args = dict( - block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', + block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs)) @@ -1715,7 +1783,7 @@ def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet: combination of deep stem and avg_pool in downsample. """ model_args = dict( - block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1727,7 +1795,7 @@ def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: in the deep stem. 
""" model_args = dict( - block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32, stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1735,7 +1803,7 @@ def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet: @register_model def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4, block_args=dict(attn_layer='se')) return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1743,7 +1811,7 @@ def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet: @register_model def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4, block_args=dict(attn_layer='se')) return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs)) @@ -1751,7 +1819,7 @@ def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet: @register_model def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8, block_args=dict(attn_layer='se')) return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @@ -1759,7 +1827,7 @@ def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet: @register_model def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8, stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs)) @@ -1768,7 +1836,7 @@ def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet: @register_model def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, + block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4, block_args=dict(attn_layer='se')) return _create_resnet('seresnext101_64x4d', pretrained, **dict(model_args, **kwargs)) @@ -1776,7 +1844,7 @@ def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet: @register_model def senet154(pretrained: bool = False, **kwargs) -> ResNet: model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + block=Bottleneck, layers=(3, 8, 36, 3), cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se')) return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs)) @@ -1785,7 +1853,7 @@ def senet154(pretrained: bool = False, **kwargs) -> ResNet: def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-18 model with blur anti-aliasing """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d) + model_args = dict(block=BasicBlock, 
layers=(2, 2, 2, 2), aa_layer=BlurPool2d) return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs)) @@ -1793,7 +1861,7 @@ def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet: def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model with blur anti-aliasing """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d) return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs)) @@ -1802,7 +1870,7 @@ def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model with blur anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, + block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d, stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs)) @@ -1812,7 +1880,7 @@ def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model with blur anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, + block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=BlurPool2d, stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs)) @@ -1822,7 +1890,7 @@ def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-34-D model w/ avgpool anti-aliasing """ model_args = dict( - block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True) + block=BasicBlock, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs)) @@ -1830,7 +1898,7 @@ def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet: def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50 model with avgpool anti-aliasing """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d) + model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d) return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs)) @@ -1839,7 +1907,7 @@ def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model with avgpool anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs)) @@ -1849,7 +1917,7 @@ def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a ResNet-101-D model with avgpool anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, + block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True) return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs)) @@ -1859,7 +1927,7 @@ def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True, 
block_args=dict(attn_layer='se')) return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs)) @@ -1869,7 +1937,7 @@ def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet: """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8, stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d, block_args=dict(attn_layer='se')) return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs)) @@ -1880,7 +1948,7 @@ def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs): """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing """ model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 4], cardinality=32, base_width=8, + block=Bottleneck, layers=(3, 24, 36, 4), cardinality=32, base_width=8, stem_width=64, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d, block_args=dict(attn_layer='se')) return _create_resnet('seresnextaa201d_32x8d', pretrained, **dict(model_args, **kwargs)) @@ -1894,7 +1962,7 @@ def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs)) @@ -1907,7 +1975,7 @@ def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs)) @@ -1920,7 +1988,7 @@ def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs)) @@ -1933,7 +2001,7 @@ def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs)) @@ -1946,7 +2014,7 @@ def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(4, 29, 53, 4), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return 
_create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs)) @@ -1960,7 +2028,7 @@ def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(4, 36, 72, 4), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs)) @@ -1973,7 +2041,7 @@ def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet: """ attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( - block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + block=Bottleneck, layers=(4, 44, 87, 4), stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer)) return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index bb3f9508..5c0f7b4f 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -26,6 +26,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -607,6 +608,72 @@ class SwinTransformer(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
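The forward_intermediates() implementations added in this patch (ResNet above, the Swin variants below) share one calling convention: an int for indices selects the last n feature levels, a sequence selects specific ones. A minimal usage sketch, using resnet50 purely as an example:

import torch
import timm

# any model gaining forward_intermediates() in this patch works the same way; resnet50 is arbitrary
model = timm.create_model('resnet50', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)

# final features plus the last three of the five ResNet feature levels
final_feat, intermediates = model.forward_intermediates(x, indices=3)

# or just the selected intermediates, skipping the final-feature part of the return
feats = model.forward_intermediates(x, indices=(1, 3), intermediates_only=True)
for f in feats:
    print(f.shape)  # NCHW feature maps at increasing reduction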
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.layers), indices) + + # forward pass + x = self.patch_embed(x) + + num_stages = len(self.layers) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.layers + else: + stages = self.layers[:max_index + 1] + for i, stage in enumerate(stages): + x = stage(x) + if i in take_indices: + if norm and i == num_stages - 1: + x_inter = self.norm(x) # applying final norm last intermediate + else: + x_inter = x + x_inter = x_inter.permute(0, 3, 1, 2).contiguous() + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.layers), indices) + self.layers = self.layers[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.layers(x) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index e8a5a9cf..a6ebb664 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # Written by Ze Liu # -------------------------------------------------------- import math -from typing import Callable, Optional, Tuple, Union, Set, Dict +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,6 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\ resample_patch_embed, ndgrid, get_act_layer, LayerType from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -608,6 +609,72 @@ class SwinTransformerV2(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
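As the permute in these Swin forward_intermediates() methods shows, the stages compute in NHWC but the returned intermediates are converted to NCHW so callers see a consistent layout. A small sketch, with swinv2_tiny_window8_256 chosen only for illustration:

import torch
import timm

model = timm.create_model('swinv2_tiny_window8_256', pretrained=False)
model.eval()

feats = model.forward_intermediates(
    torch.randn(1, 3, 256, 256),
    indices=4,                 # all four stages
    norm=True,                 # per the code above, final norm is applied to the last intermediate only
    intermediates_only=True,
)
for f in feats:
    print(f.shape)  # (1, C, H, W); H and W halve at each stage after the first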
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.layers), indices) + + # forward pass + x = self.patch_embed(x) + + num_stages = len(self.layers) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.layers + else: + stages = self.layers[:max_index + 1] + for i, stage in enumerate(stages): + x = stage(x) + if i in take_indices: + if norm and i == num_stages - 1: + x_inter = self.norm(x) # applying final norm last intermediate + else: + x_inter = x + x_inter = x_inter.permute(0, 3, 1, 2).contiguous() + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.layers), indices) + self.layers = self.layers[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.layers(x) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 1aae8645..58cfcd36 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -39,6 +39,7 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import named_apply from ._registry import generate_default_cfgs, register_model @@ -718,6 +719,62 @@ class SwinTransformerV2Cr(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + for i, stage in enumerate(stages): + x = stage(x) + if i in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) x = self.stages(x) diff --git a/timm/models/twins.py b/timm/models/twins.py index 24f7b801..b87a9c79 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -409,7 +409,7 @@ class Twins(nn.Module): x: torch.Tensor, indices: Union[int, List[int], Tuple[int]] = None, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -459,6 +459,22 @@ class Twins(nn.Module): return x, intermediates + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + # FIXME add block pruning + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): B = x.shape[0] for i, (embed, drop, blocks, pos_blk) in enumerate( diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index c20564ba..e8fba8a9 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -638,7 +638,7 @@ class VisionTransformer(nn.Module): indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -655,7 +655,7 @@ class VisionTransformer(nn.Module): Returns: """ - assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) @@ -666,6 +666,7 @@ class VisionTransformer(nn.Module): x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks else: @@ -698,21 +699,19 @@ class VisionTransformer(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. 
""" - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: - if self.attn_pool is not None: - self.attn_pool = None self.fc_norm = nn.Identity() - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def get_intermediate_layers( @@ -1791,12 +1790,27 @@ default_cfgs = { license='mit', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), - 'vit_medium_patch16_reg4_256': _cfg( - input_size=(3, 256, 256)), + 'vit_wee_patch16_reg1_gap_256': _cfg( + file='', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_little_patch16_reg4_gap_256': _cfg( + file='', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_reg1_gap_256': _cfg( + file='vit_medium_gap1-in1k-20231118-8.pth', + input_size=(3, 256, 256), crop_pct=0.95), 'vit_medium_patch16_reg4_gap_256': _cfg( - input_size=(3, 256, 256)), + file='vit_medium_gap4-in1k-20231115-8.pth', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg1_gap_256': _cfg( + file='vit_betwixt_gap1-in1k-20231121-8.pth', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256': _cfg( + file='vit_betwixt_gap4-in1k-20231106-8.pth', + input_size=(3, 256, 256), crop_pct=0.95), 'vit_base_patch16_reg4_gap_256': _cfg( input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_gap_256': _cfg( input_size=(3, 256, 256)), 'vit_so150m_patch16_reg4_map_256': _cfg( @@ -2083,6 +2097,18 @@ def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTran return model +@register_model +def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256 + """ + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224 @@ -2714,21 +2740,54 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT return model +# @register_model +# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer: +# model_args = dict( +# patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, +# no_embed_class=True, reg_tokens=4, +# ) +# model = _create_vision_transformer( +# 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs)) +# return model + + @register_model -def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer: +def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( - patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, - no_embed_class=True, reg_tokens=4, + patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock, ) model = _create_vision_transformer( - 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_medium_patch16_reg1_gap_256', 
pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( - patch_size=16, embed_dim=512, depth=12, num_heads=8, + patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5, class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', ) model = _create_vision_transformer( @@ -2736,6 +2795,28 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio return model +@register_model +def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 8461ada3..cd477009 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -394,7 +394,7 @@ class VisionTransformerRelPos(nn.Module): indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -411,7 +411,7 @@ class VisionTransformerRelPos(nn.Module): Returns: """ - assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' 
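For the ViT-style models (vision_transformer and the relpos variant above, SAM below), forward_intermediates() also supports token-shaped output via output_fmt='NLC', with prefix/register tokens split off from the returned features by default. A brief sketch, taking vit_base_patch16_224 as an arbitrary example:

import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.eval()

# token-format intermediates from the last four blocks, run through the final norm layer
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=4,
    norm=True,
    output_fmt='NLC',
    intermediates_only=True,
)
# each entry is (1, num_patches, embed_dim); pass return_prefix_tokens=True to also
# receive the class/register tokens alongside the patch tokens
print([tuple(f.shape) for f in feats])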
reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) @@ -455,19 +455,19 @@ class VisionTransformerRelPos(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.fc_norm = nn.Identity() - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x): diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index d4d974a0..7bf6363f 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -545,7 +545,7 @@ class VisionTransformerSAM(nn.Module): x: torch.Tensor, indices: Union[int, List[int], Tuple[int]] = None, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -573,6 +573,7 @@ class VisionTransformerSAM(nn.Module): x = self.pos_drop(x) x = self.patch_drop(x) x = self.norm_pre(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks else: @@ -597,19 +598,19 @@ class VisionTransformerSAM(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = None, + indices: Union[int, List[int], Tuple[int]] = None, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: # neck is being treated as equivalent to final norm here self.neck = nn.Identity() if prune_head: - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x): diff --git a/timm/models/volo.py b/timm/models/volo.py index d997f909..a9ff905c 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -713,7 +713,7 @@ class VOLO(nn.Module): x: torch.Tensor, indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -751,8 +751,11 @@ class VOLO(nn.Module): x = self.pos_drop(x) x = block(x) if idx in take_indices: - # normalize intermediates with final norm layer if enabled - intermediates.append(x.permute(0, 3, 1, 2)) + if norm and idx >= 2: + x_inter = self.norm(x) + else: + x_inter = x + intermediates.append(x_inter.permute(0, 3, 1, 2)) if intermediates_only: return intermediates @@ -769,20 +772,20 @@ class VOLO(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. 
""" - take_indices, max_index = feature_take_indices(len(self.stage_ends), n) + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) max_index = self.stage_ends[max_index] self.network = self.network[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.post_network = nn.ModuleList() # prune token blocks with head - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x): diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 941f7bf2..0e6e118e 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -444,7 +444,7 @@ class Xcit(nn.Module): x: torch.Tensor, indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, - stop_early: bool = True, + stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -460,7 +460,7 @@ class Xcit(nn.Module): Returns: """ - assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) @@ -468,7 +468,6 @@ class Xcit(nn.Module): # forward pass B, _, height, width = x.shape x, (Hp, Wp) = self.patch_embed(x) - if self.pos_embed is not None: # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C) pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) @@ -503,19 +502,19 @@ class Xcit(nn.Module): def prune_intermediate_layers( self, - n: Union[int, List[int], Tuple[int]] = 1, + indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.blocks), n) + take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[:max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.cls_attn_blocks = nn.ModuleList() # prune token blocks with head - self.head = nn.Identity() + self.reset_classifier(0, '') return take_indices def forward_features(self, x):