Merge pull request #2165 from huggingface/more_vits_better_getter

Add forward_intermediates() support to coatnet, maxvit, swin*. Tweak feature extraction interfaces. Prep new vit weights while testing.
Ross Wightman 2024-05-03 19:26:01 -07:00 committed by GitHub
commit 3e7ab12af9
26 changed files with 1171 additions and 192 deletions
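
A minimal usage sketch of the two interfaces this PR touches (model names, indices, and input sizes below are illustrative examples, not part of the commit):

import torch
import timm

x = torch.randn(1, 3, 224, 224)

# features_only with the FeatureGetterNet wrapper selected via the new feature_cls kwarg
model = timm.create_model('resnet50', pretrained=False, features_only=True, feature_cls='getter')
feats = model(x)  # list of feature maps described by model.feature_info

# calling forward_intermediates() directly on an unwrapped model
model = timm.create_model('convnext_tiny', pretrained=False)
final, intermediates = model.forward_intermediates(x, indices=3, output_fmt='NCHW')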

tests/test_models.py

@ -49,8 +49,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
FEAT_INTER_FILTERS = [
'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
'cait_*', 'xcit_*', 'volo_*',
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@ -388,13 +389,12 @@ def test_model_forward_features(model_name, batch_size):
@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False, features_only=True)
model = create_model(model_name, pretrained=False, features_only=True, feature_cls='getter')
model.eval()
print(model.feature_info.out_indices)
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
@ -420,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):
@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_intermediates(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
@ -429,18 +429,19 @@ def test_model_forward_intermediates(model_name, batch_size):
feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
expected_channels = feature_info.channels()
expected_reduction = feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
assert len(expected_channels) >= 3 # all models here should have at least 3 feature levels
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
output_fmt = 'NCHW' # NOTE output_fmt determined by forward_intermediates() arg, not model attribute
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math
output, intermediates = model.forward_intermediates(
torch.randn((batch_size, *input_size)),
output_fmt=output_fmt,
)
assert len(expected_channels) == len(intermediates)
spatial_size = input_size[-2:]
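
A condensed, self-contained version of the check this test performs (the model name is an arbitrary example):

import torch
import timm

model = timm.create_model('resnet18', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)
final, intermediates = model.forward_intermediates(x, output_fmt='NCHW')
# each intermediate should match the channel count reported in feature_info
for feat, info in zip(intermediates, model.feature_info):
    assert feat.shape[1] == info['num_chs']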

timm/layers/adaptive_avgmax_pool.py

@ -134,6 +134,7 @@ class SelectAdaptivePool2d(nn.Module):
super(SelectAdaptivePool2d, self).__init__()
assert input_fmt in ('NCHW', 'NHWC')
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
pool_type = pool_type.lower()
if not pool_type:
self.pool = nn.Identity() # pass through
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
@ -145,8 +146,10 @@ class SelectAdaptivePool2d(nn.Module):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else:
elif pool_type == 'fast' or pool_type.endswith('avg'):
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
else:
assert False, 'Invalid pool type: %s' % pool_type
self.flatten = nn.Identity()
else:
assert input_fmt == 'NCHW'
@ -156,8 +159,10 @@ class SelectAdaptivePool2d(nn.Module):
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
def is_identity(self):
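
Behavioural sketch of the change: pool_type is now lowercased and unknown values trip the assert instead of silently falling back to average pooling (constructor argument names assumed from timm.layers):

import torch
from timm.layers import SelectAdaptivePool2d

pool = SelectAdaptivePool2d(pool_type='avg', flatten=True)
print(pool(torch.randn(2, 512, 7, 7)).shape)   # torch.Size([2, 512])

pool = SelectAdaptivePool2d(pool_type='catavgmax')
print(pool(torch.randn(2, 512, 7, 7)).shape)   # torch.Size([2, 1024, 1, 1]), avg and max concatenated

SelectAdaptivePool2d(pool_type='bogus')        # now raises AssertionError: Invalid pool type: bogus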

timm/models/_builder.py

@ -2,7 +2,7 @@ import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from torch import nn as nn
from torch.hub import load_state_dict_from_url
@ -359,15 +359,15 @@ def build_model_with_cfg(
* pruning config / model adaptation
Args:
model_cls (nn.Module): model class
variant (str): model variant name
pretrained (bool): load pretrained weights
pretrained_cfg (dict): model's pretrained weight/task config
model_cfg (Optional[Dict]): model's architecture config
feature_cfg (Optional[Dict]: feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
model_cls: model class
variant: model variant name
pretrained: load pretrained weights
pretrained_cfg: model's pretrained weight/task config
model_cfg: model's architecture config
feature_cfg: feature extraction adapter config
pretrained_strict: load pretrained weights strictly
pretrained_filter_fn: filter callable for pretrained weights
kwargs_filter: kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
pruned = kwargs.pop('pruned', False)
@ -392,6 +392,8 @@ def build_model_with_cfg(
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
if 'feature_cls' in kwargs:
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
# Instantiate the model
if model_cfg is None:
@ -418,24 +420,36 @@ def build_model_with_cfg(
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None:
feature_cfg.setdefault('output_fmt', output_fmt)
use_getter = False
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
# flatten_sequential only valid for some feature extractors
if feature_cls not in ('dict', 'list', 'hook'):
feature_cfg.pop('flatten_sequential', None)
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'list':
feature_cls = FeatureListNet
elif feature_cls == 'dict':
feature_cls = FeatureDictNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
elif feature_cls == 'getter':
use_getter = True
feature_cls = FeatureGetterNet
else:
assert False, f'Unknown feature class {feature_cls}'
else:
feature_cls = FeatureListNet
output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None and not use_getter: # don't set default for intermediate feat getter
feature_cfg.setdefault('output_fmt', output_fmt)
model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)
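
For reference, the string values accepted by the dispatch above as they would be passed through create_model (model names are examples):

import timm

# 'list' (default), 'dict', 'hook', 'fx' and the new 'getter' map to
# FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGraphNet, FeatureGetterNet
m_list = timm.create_model('resnet50', features_only=True)                        # FeatureListNet
m_dict = timm.create_model('resnet50', features_only=True, feature_cls='dict')    # FeatureDictNet
m_get  = timm.create_model('resnet50', features_only=True, feature_cls='getter')  # FeatureGetterNet, uses forward_intermediates()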

timm/models/_features.py

@ -363,7 +363,7 @@ class FeatureHookNet(nn.ModuleDict):
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
no_rewrite: Optional[bool] = None,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',
):
@ -385,7 +385,8 @@ class FeatureHookNet(nn.ModuleDict):
self.return_dict = return_dict
self.output_fmt = Format(output_fmt)
self.grad_checkpointing = False
if no_rewrite is None:
no_rewrite = not flatten_sequential
layers = OrderedDict()
hooks = []
if no_rewrite:
@ -467,7 +468,7 @@ class FeatureGetterNet(nn.ModuleDict):
self.out_indices = out_indices
self.out_map = out_map
self.return_dict = return_dict
self.output_fmt = output_fmt
self.output_fmt = Format(output_fmt)
self.norm = norm
def forward(self, x):

timm/models/_features_fx.py

@ -15,7 +15,7 @@ except ImportError:
has_fx_feature_extraction = False
# Layers we want to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
@ -108,12 +108,14 @@ class FeatureGraphNet(nn.Module):
model: nn.Module,
out_indices: Tuple[int, ...],
out_map: Optional[Dict] = None,
output_fmt: str = 'NCHW',
):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
self.output_fmt = Format(output_fmt)
return_nodes = _get_return_layers(self.feature_info, out_map)
self.graph_module = create_feature_extractor(model, return_nodes)

timm/models/_registry.py

@ -184,7 +184,7 @@ def _expand_filter(filter: str):
def list_models(
filter: Union[str, List[str]] = '',
module: str = '',
module: Union[str, List[str]] = '',
pretrained: bool = False,
exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False,
@ -217,7 +217,16 @@ def list_models(
# FIXME should this be default behaviour? or default to include_tags=True?
include_tags = pretrained
all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
if not module:
all_models: Set[str] = set(_model_entrypoints.keys())
else:
if isinstance(module, str):
all_models: Set[str] = _module_to_models[module]
else:
assert isinstance(module, Sequence)
all_models: Set[str] = set()
for m in module:
all_models.update(_module_to_models[m])
all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings
if include_tags:
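
A short sketch of the extended signature (module names below are examples taken from the registry):

import timm

# module may now be a list/tuple of module names; the result is the union of those modules' models
vit_like = timm.list_models(module=['vision_transformer', 'beit', 'eva'])
resnets  = timm.list_models(module='resnet', pretrained=True)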

timm/models/beit.py

@ -407,7 +407,7 @@ class Beit(nn.Module):
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -424,7 +424,7 @@ class Beit(nn.Module):
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -436,6 +436,7 @@ class Beit(nn.Module):
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
@ -469,19 +470,19 @@ class Beit(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):

timm/models/cait.py

@ -343,7 +343,7 @@ class Cait(nn.Module):
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -357,7 +357,7 @@ class Cait(nn.Module):
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -367,6 +367,7 @@ class Cait(nn.Module):
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
@ -397,19 +398,19 @@ class Cait(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.blocks_token_only = nn.ModuleList() # prune token blocks with head
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):

timm/models/convnext.py

@ -39,7 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
from collections import OrderedDict
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -49,6 +49,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -407,6 +408,71 @@ class ConvNeXt(nn.Module):
def reset_classifier(self, num_classes=0, global_pool=None):
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
x = stage(x)
if feat_idx in take_indices:
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
intermediates.append(x)
if intermediates_only:
return intermediates
x = self.norm_pre(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
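
Usage sketch for the ConvNeXt implementation above: the stem counts as feature index 0 and the four stages are indices 1-4 (model name is an example):

import torch
import timm

model = timm.create_model('convnext_tiny', pretrained=False)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=[1, 2, 3, 4],        # skip the stem, take all four stage outputs
    intermediates_only=True,
)
print([tuple(f.shape) for f in feats])  # channels 96/192/384/768 at strides 4/8/16/32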

timm/models/efficientformer.py

@ -12,7 +12,7 @@ Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientF
Modifications and timm support by / Copyright 2022, Ross Wightman
"""
from typing import Dict
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
@ -20,6 +20,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@ -382,16 +383,19 @@ class EfficientFormer(nn.Module):
prev_dim = embed_dims[0]
# stochastic depth decay rule
self.num_stages = len(depths)
last_stage = self.num_stages - 1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
stages = []
for i in range(len(depths)):
self.feature_info = []
for i in range(self.num_stages):
stage = EfficientFormerStage(
prev_dim,
embed_dims[i],
depths[i],
downsample=downsamples[i],
num_vit=num_vit if i == 3 else 0,
num_vit=num_vit if i == last_stage else 0,
pool_size=pool_size,
mlp_ratio=mlp_ratios,
act_layer=act_layer,
@ -403,7 +407,7 @@ class EfficientFormer(nn.Module):
)
prev_dim = embed_dims[i]
stages.append(stage)
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
# Classifier head
@ -456,6 +460,76 @@ class EfficientFormer(nn.Module):
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
B, C, H, W = x.shape
last_idx = self.num_stages - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
feat_idx = 0
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx < last_idx:
B, C, H, W = x.shape
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2))
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
@ -534,13 +608,13 @@ default_cfgs = generate_default_cfgs({
def _create_efficientformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for EfficientFormer models.')
out_indices = kwargs.pop('out_indices', 4)
model = build_model_with_cfg(
EfficientFormer, variant, pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
**kwargs)
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
return model
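
With the getter-based feature_cfg above, features_only now works for EfficientFormer instead of raising RuntimeError (model name is an example):

import torch
import timm

model = timm.create_model('efficientformer_l1', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])  # four NCHW stage features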

timm/models/efficientnet.py

@ -36,7 +36,7 @@ the models and weights open source!
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -49,7 +49,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -118,6 +118,7 @@ class EfficientNet(nn.Module):
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
self.stage_ends = [f['stage'] for f in self.feature_info]
head_chs = builder.in_chs
# Head + Pooling
@ -158,6 +159,86 @@ class EfficientNet(nn.Module):
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
extra_blocks: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
if stop_early:
assert intermediates_only, 'Must use intermediates_only for early stopping.'
intermediates = []
if extra_blocks:
take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
else:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
# forward pass
feat_idx = 0 # stem is index 0
x = self.conv_stem(x)
x = self.bn1(x)
if feat_idx in take_indices:
intermediates.append(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index]
for blk in blocks:
feat_idx += 1
x = blk(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
x = self.conv_head(x)
x = self.bn2(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,
):
""" Prune layers not required for specified intermediates.
"""
if extra_blocks:
take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
else:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
if prune_norm or max_index < len(self.blocks):
self.conv_head = nn.Identity()
self.bn2 = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
@ -272,7 +353,7 @@ def _create_effnet(variant, pretrained=False, **kwargs):
model_cls = EfficientNet
kwargs_filter = None
if kwargs.pop('features_only', False):
if 'feature_cfg' in kwargs:
if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
features_mode = 'cfg'
else:
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
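
Sketch of the new extra_blocks flag: by default, indices select the stage_ends-aligned feature levels (matching feature_info); with extra_blocks=True they index the stem plus every block group instead (model name is an example):

import torch
import timm

model = timm.create_model('efficientnet_b0', pretrained=False)
x = torch.randn(1, 3, 224, 224)
stage_feats = model.forward_intermediates(x, intermediates_only=True)
block_feats = model.forward_intermediates(x, intermediates_only=True, extra_blocks=True)
print(len(stage_feats), len(block_feats))  # feature_info-aligned count vs stem + every block group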

timm/models/eva.py

@ -53,6 +53,7 @@ class EvaAttention(nn.Module):
num_heads: int = 8,
qkv_bias: bool = True,
qkv_fused: bool = True,
num_prefix_tokens: int = 1,
attn_drop: float = 0.,
proj_drop: float = 0.,
attn_head_dim: Optional[int] = None,
@ -77,6 +78,7 @@ class EvaAttention(nn.Module):
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.num_prefix_tokens = num_prefix_tokens
self.fused_attn = use_fused_attn()
if qkv_fused:
@ -119,8 +121,9 @@ class EvaAttention(nn.Module):
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
if rope is not None:
q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
npt = self.num_prefix_tokens
q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
if self.fused_attn:
x = F.scaled_dot_product_attention(
@ -157,6 +160,7 @@ class EvaBlock(nn.Module):
swiglu_mlp: bool = False,
scale_mlp: bool = False,
scale_attn_inner: bool = False,
num_prefix_tokens: int = 1,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
@ -191,6 +195,7 @@ class EvaBlock(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qkv_fused=qkv_fused,
num_prefix_tokens=num_prefix_tokens,
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
@ -253,6 +258,7 @@ class EvaBlockPostNorm(nn.Module):
swiglu_mlp: bool = False,
scale_mlp: bool = False,
scale_attn_inner: bool = False,
num_prefix_tokens: int = 1,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
@ -286,6 +292,7 @@ class EvaBlockPostNorm(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qkv_fused=qkv_fused,
num_prefix_tokens=num_prefix_tokens,
attn_drop=attn_drop,
proj_drop=proj_drop,
attn_head_dim=attn_head_dim,
@ -364,6 +371,7 @@ class Eva(nn.Module):
norm_layer: Callable = LayerNorm,
init_values: Optional[float] = None,
class_token: bool = True,
num_reg_tokens: int = 0,
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
use_post_norm: bool = False,
@ -407,7 +415,7 @@ class Eva(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
@ -427,6 +435,8 @@ class Eva(nn.Module):
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None
self.cls_embed = class_token and self.reg_token is None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None
@ -463,6 +473,7 @@ class Eva(nn.Module):
swiglu_mlp=swiglu_mlp,
scale_mlp=scale_mlp,
scale_attn_inner=scale_attn_inner,
num_prefix_tokens=self.num_prefix_tokens,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
@ -484,6 +495,8 @@ class Eva(nn.Module):
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
trunc_normal_(self.cls_token, std=.02)
if self.reg_token is not None:
trunc_normal_(self.reg_token, std=.02)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
@ -551,8 +564,17 @@ class Eva(nn.Module):
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if pos_embed is not None:
x = x + pos_embed
if self.reg_token is not None:
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
x = torch.cat(to_cat + [x], dim=1)
x = self.pos_drop(x)
# obtain shared rotary position embedding and apply patch dropout
@ -568,7 +590,7 @@ class Eva(nn.Module):
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -622,19 +644,19 @@ class Eva(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
@ -695,6 +717,12 @@ def checkpoint_filter_fn(
# fixed embedding no need to load buffer from checkpoint
continue
# FIXME here while importing new weights, to remove
# if k == 'cls_token':
# print('DEBUG: cls token -> reg')
# k = 'reg_token'
# #v = v + state_dict['pos_embed'][0, :]
if 'patch_embed.proj.weight' in k:
_, _, H, W = model.patch_embed.proj.weight.shape
if v.shape[-1] != W or v.shape[-2] != H:
@ -923,6 +951,29 @@ default_cfgs = generate_default_cfgs({
num_classes=0,
),
'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
file='vit_medium_gap1_rope-in1k-20230920-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
#hf_hub_id='timm/',
file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
),
'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
# hf_hub_id='timm/',
file='vit_base_gap1_rope-in1k-20230930-5.pth',
input_size=(3, 256, 256), crop_pct=0.95,
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
),
})
@ -1185,3 +1236,87 @@ def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
)
model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_medium_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
model_args = dict(
img_size=256,
patch_size=16,
embed_dim=512,
depth=12,
num_heads=8,
qkv_fused=True,
qkv_bias=True,
init_values=1e-5,
class_token=False,
num_reg_tokens=1,
use_rot_pos_emb=True,
use_abs_pos_emb=False,
ref_feat_shape=(16, 16),  # 256/16
)
model = _create_eva('vit_medium_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_mediumd_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
model_args = dict(
img_size=256,
patch_size=16,
embed_dim=512,
depth=20,
num_heads=8,
qkv_fused=True,
qkv_bias=False,
init_values=1e-5,
class_token=False,
num_reg_tokens=1,
use_rot_pos_emb=True,
use_abs_pos_emb=False,
ref_feat_shape=(16, 16),  # 256/16
)
model = _create_eva('vit_mediumd_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_betwixt_patch16_rope_reg4_gap_256(pretrained=False, **kwargs) -> Eva:
model_args = dict(
img_size=256,
patch_size=16,
embed_dim=640,
depth=12,
num_heads=10,
qkv_fused=True,
qkv_bias=True,
init_values=1e-5,
class_token=False,
num_reg_tokens=4,
use_rot_pos_emb=True,
use_abs_pos_emb=False,
ref_feat_shape=(16, 16),  # 256/16
)
model = _create_eva('vit_betwixt_patch16_rope_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
model_args = dict(
img_size=256,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
qkv_fused=True,
qkv_bias=True,
init_values=1e-5,
class_token=False,
num_reg_tokens=1,
use_rot_pos_emb=True,
use_abs_pos_emb=False,
ref_feat_shape=(16, 16),  # 256/16
)
model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
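
The four registrations above reuse the Eva block (rotary position embedding, register tokens, GAP head). Their hf_hub_id entries are still commented out while weights are being prepped, so a quick smoke test would instantiate without pretrained weights:

import torch
import timm

model = timm.create_model('vit_medium_patch16_rope_reg1_gap_256', pretrained=False)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000])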

timm/models/levit.py

@ -25,7 +25,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# Copyright 2020 Ross Wightman, Apache-2.0 License
from collections import OrderedDict
from functools import partial
from typing import Dict
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
@ -33,6 +33,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@ -634,6 +635,70 @@ class Levit(nn.Module):
self.head = NormLinear(
self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
B, C, H, W = x.shape
if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
if self.use_conv:
intermediates.append(x)
else:
intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2))
H = (H + 2 - 1) // 2
W = (W + 2 - 1) // 2
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
if not self.use_conv:
@ -746,9 +811,8 @@ model_cfgs = dict(
def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
is_conv = '_conv' in variant
out_indices = kwargs.pop('out_indices', (0, 1, 2))
if kwargs.get('features_only', None):
if not is_conv:
raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.')
if kwargs.get('features_only', False) and not is_conv:
kwargs.setdefault('feature_cls', 'getter')
if cfg_variant is None:
if variant in model_cfgs:
cfg_variant = variant

timm/models/maxxvit.py

@ -50,6 +50,7 @@ from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model
@ -1251,6 +1252,75 @@ class MaxxVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
# forward pass
feat_idx = 0 # stem is index 0
x = self.stem(x)
if feat_idx in take_indices:
intermediates.append(x)
last_idx = len(self.stages)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index]
for stage in stages:
feat_idx += 1
x = stage(x)
if feat_idx in take_indices:
if norm and feat_idx == last_idx:
x_inter = self.norm(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.head = self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
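
coatnet_* and maxvit_* variants share this MaxxVit implementation, so they pick up the same interface (model name and indices are examples):

import torch
import timm

model = timm.create_model('maxvit_tiny_rw_224', pretrained=False)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=4,                   # last four indices, i.e. the four stages after the stem
    intermediates_only=True,
)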

timm/models/mobilenetv3.py

@ -7,7 +7,7 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -20,7 +20,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -109,6 +109,7 @@ class MobileNetV3(nn.Module):
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
self.stage_ends = [f['stage'] for f in self.feature_info]
head_chs = builder.in_chs
# Head + Pooling
@ -150,6 +151,84 @@ class MobileNetV3(nn.Module):
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
extra_blocks: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
if stop_early:
assert intermediates_only, 'Must use intermediates_only for early stopping.'
intermediates = []
if extra_blocks:
take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
else:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
take_indices = [self.stage_ends[i] for i in take_indices]
max_index = self.stage_ends[max_index]
# forward pass
feat_idx = 0 # stem is index 0
x = self.conv_stem(x)
x = self.bn1(x)
if feat_idx in take_indices:
intermediates.append(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
blocks = self.blocks[:max_index]
for blk in blocks:
feat_idx += 1
x = blk(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
extra_blocks: bool = False,
):
""" Prune layers not required for specified intermediates.
"""
if extra_blocks:
take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
else:
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
if max_index < len(self.blocks):
self.conv_head = nn.Identity()
if prune_head:
self.conv_head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_stem(x)
x = self.bn1(x)
@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
model_cls = MobileNetV3
kwargs_filter = None
if kwargs.pop('features_only', False):
if 'feature_cfg' in kwargs:
if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
features_mode = 'cfg'
else:
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')

timm/models/mvitv2.py

@ -839,7 +839,7 @@ class MultiScaleVit(nn.Module):
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -855,13 +855,12 @@ class MultiScaleVit(nn.Module):
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output shape must be NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# FIXME slice block/pos_block if < max
# forward pass
x, feat_size = self.patch_embed(x)
B = x.shape[0]
@ -870,6 +869,7 @@ class MultiScaleVit(nn.Module):
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
for i, stage in enumerate(self.stages):
x, feat_size = stage(x, feat_size)
if i in take_indices:
@ -891,6 +891,23 @@ class MultiScaleVit(nn.Module):
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# FIXME add stage pruning
# self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x, feat_size = self.patch_embed(x)
B, N, C = x.shape

timm/models/resnet.py

@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
get_attn, get_act_layer, get_norm_layer, create_classifier
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@ -295,8 +296,8 @@ def drop_blocks(drop_prob: float = 0.):
def make_blocks(
block_fn: Union[BasicBlock, Bottleneck],
channels: List[int],
block_repeats: List[int],
channels: Tuple[int, ...],
block_repeats: Tuple[int, ...],
inplanes: int,
reduce_first: int = 1,
output_stride: int = 32,
@ -394,7 +395,7 @@ class ResNet(nn.Module):
def __init__(
self,
block: Union[BasicBlock, Bottleneck],
layers: List[int],
layers: Tuple[int, ...],
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
@ -497,7 +498,7 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks
channels = [64, 128, 256, 512]
channels = (64, 128, 256, 512)
stage_modules, stage_feature_info = make_blocks(
block,
channels,
@ -553,6 +554,73 @@ class ResNet(nn.Module):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
if stop_early:
assert intermediates_only, 'Must use intermediates_only for early stopping.'
intermediates = []
take_indices, max_index = feature_take_indices(5, indices)
# forward pass
feat_idx = 0
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
if feat_idx in take_indices:
intermediates.append(x)
x = self.maxpool(x)
layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
if stop_early:
layer_names = layer_names[:max_index]
for n in layer_names:
feat_idx += 1
x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(5, indices)
layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
layer_names = layer_names[max_index:]
for n in layer_names:
setattr(self, n, nn.Identity())
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.bn1(x)
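
Usage sketch for the ResNet additions above: index 0 is the stem activation and indices 1-4 are layer1-layer4; prune_intermediate_layers() replaces layers past the last requested index with nn.Identity() (model name is an example):

import torch
import timm

model = timm.create_model('resnet50', pretrained=False)
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=[1, 2, 3, 4],
    intermediates_only=True,
)
print([f.shape[1] for f in feats])  # [256, 512, 1024, 2048]

model.prune_intermediate_layers(indices=[1, 2])  # layer3/layer4 -> nn.Identity(), classifier reset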
@ -1246,7 +1314,7 @@ default_cfgs = generate_default_cfgs({
def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-10-T model.
"""
model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
@ -1254,7 +1322,7 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-14-T model.
"""
model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
model_args = dict(block=Bottleneck, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs))
@ -1262,7 +1330,7 @@ def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-18 model.
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2))
return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
@ -1270,7 +1338,7 @@ def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-18-D model.
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs))
@ -1278,7 +1346,7 @@ def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-34 model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3))
return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs))
@ -1286,7 +1354,7 @@ def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-34-D model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs))
@ -1294,7 +1362,7 @@ def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-26 model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2))
return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs))
@ -1302,7 +1370,7 @@ def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-26-T model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True)
model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs))
@ -1310,7 +1378,7 @@ def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-26-D model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs))
@ -1318,7 +1386,7 @@ def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3))
return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))
@ -1326,7 +1394,7 @@ def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-C model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep')
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep')
return _create_resnet('resnet50c', pretrained, **dict(model_args, **kwargs))
@ -1334,7 +1402,7 @@ def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs))
@ -1342,7 +1410,7 @@ def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-S model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep')
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=64, stem_type='deep')
return _create_resnet('resnet50s', pretrained, **dict(model_args, **kwargs))
@ -1350,7 +1418,7 @@ def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-T model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True)
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs))
@ -1358,7 +1426,7 @@ def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3))
return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs))
@ -1366,7 +1434,7 @@ def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-C model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep')
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep')
return _create_resnet('resnet101c', pretrained, **dict(model_args, **kwargs))
@ -1374,7 +1442,7 @@ def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs))
@ -1382,7 +1450,7 @@ def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-S model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep')
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=64, stem_type='deep')
return _create_resnet('resnet101s', pretrained, **dict(model_args, **kwargs))
@ -1390,7 +1458,7 @@ def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3))
return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs))
@ -1398,7 +1466,7 @@ def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152-C model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep')
model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep')
return _create_resnet('resnet152c', pretrained, **dict(model_args, **kwargs))
@ -1406,7 +1474,7 @@ def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs))
@ -1414,7 +1482,7 @@ def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-152-S model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep')
model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=64, stem_type='deep')
return _create_resnet('resnet152s', pretrained, **dict(model_args, **kwargs))
@ -1422,7 +1490,7 @@ def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3])
model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3))
return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs))
@ -1430,7 +1498,7 @@ def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
def resnet200d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs))
@ -1442,7 +1510,7 @@ def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet:
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), base_width=128)
return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs))
@ -1453,7 +1521,7 @@ def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), base_width=128)
return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs))
@ -1461,7 +1529,7 @@ def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model w/ GroupNorm
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], norm_layer='groupnorm')
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), norm_layer='groupnorm')
return _create_resnet('resnet50_gn', pretrained, **dict(model_args, **kwargs))
@ -1469,7 +1537,7 @@ def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt50-32x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4)
return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1478,7 +1546,7 @@ def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1487,7 +1555,7 @@ def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4)
return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1495,7 +1563,7 @@ def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x8d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8)
return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@ -1503,7 +1571,7 @@ def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x16d model
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=16)
return _create_resnet('resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
@ -1511,7 +1579,7 @@ def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt-101 32x32d model
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=32)
return _create_resnet('resnext101_32x32d', pretrained, **dict(model_args, **kwargs))
@ -1519,7 +1587,7 @@ def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNeXt101-64x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4)
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4)
return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs))
@ -1530,7 +1598,7 @@ def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet:
in the deep stem and ECA attn.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs))
@ -1540,7 +1608,7 @@ def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model with eca.
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs))
@ -1551,7 +1619,7 @@ def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
@ -1562,7 +1630,7 @@ def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32,
block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs))
@ -1572,7 +1640,7 @@ def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D light model with eca.
"""
model_args = dict(
block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
block=Bottleneck, layers=(1, 1, 11, 3), stem_width=32, avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs))
@ -1582,7 +1650,7 @@ def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model with eca.
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs))
@ -1593,7 +1661,7 @@ def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
@ -1603,7 +1671,7 @@ def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200-D model with ECA.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs))
@ -1613,7 +1681,7 @@ def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-269-D model with ECA.
"""
model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs))
@ -1625,7 +1693,7 @@ def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
in the deep stem. This model replaces SE module with the ECA module
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1637,53 +1705,53 @@ def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
in the deep stem. This model replaces SE module with the ECA module
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet18(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'))
model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), block_args=dict(attn_layer='se'))
return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet34(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet50(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered',
block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet101(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'))
model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), block_args=dict(attn_layer='se'))
return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet152(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'))
model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))
@ -1693,7 +1761,7 @@ def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-200-D model with SE attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep',
block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))
@ -1703,7 +1771,7 @@ def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-269-D model with SE attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep',
block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs))
@ -1715,7 +1783,7 @@ def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
combination of deep stem and avg_pool in downsample.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1727,7 +1795,7 @@ def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
in the deep stem.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1735,7 +1803,7 @@ def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1743,7 +1811,7 @@ def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@ -1751,7 +1819,7 @@ def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@ -1759,7 +1827,7 @@ def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs))
@ -1768,7 +1836,7 @@ def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4,
block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_64x4d', pretrained, **dict(model_args, **kwargs))
@ -1776,7 +1844,7 @@ def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def senet154(pretrained: bool = False, **kwargs) -> ResNet:
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
block=Bottleneck, layers=(3, 8, 36, 3), cardinality=64, base_width=4, stem_type='deep',
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs))
@ -1785,7 +1853,7 @@ def senet154(pretrained: bool = False, **kwargs) -> ResNet:
def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d)
model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), aa_layer=BlurPool2d)
return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs))
@ -1793,7 +1861,7 @@ def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d)
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d)
return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))
@ -1802,7 +1870,7 @@ def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))
@ -1812,7 +1880,7 @@ def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))
@ -1822,7 +1890,7 @@ def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-34-D model w/ avgpool anti-aliasing
"""
model_args = dict(
block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
block=BasicBlock, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))
@ -1830,7 +1898,7 @@ def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50 model with avgpool anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d)
model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d)
return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))
@ -1839,7 +1907,7 @@ def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))
@ -1849,7 +1917,7 @@ def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))
@ -1859,7 +1927,7 @@ def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))
@ -1869,7 +1937,7 @@ def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
"""Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))
@ -1880,7 +1948,7 @@ def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs):
"""Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 4], cardinality=32, base_width=8,
block=Bottleneck, layers=(3, 24, 36, 4), cardinality=32, base_width=8,
stem_width=64, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
block_args=dict(attn_layer='se'))
return _create_resnet('seresnextaa201d_32x8d', pretrained, **dict(model_args, **kwargs))
@ -1894,7 +1962,7 @@ def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))
@ -1907,7 +1975,7 @@ def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))
@ -1920,7 +1988,7 @@ def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))
@ -1933,7 +2001,7 @@ def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))
@ -1946,7 +2014,7 @@ def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(4, 29, 53, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))
@ -1960,7 +2028,7 @@ def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(4, 36, 72, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))
@ -1973,7 +2041,7 @@ def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
block=Bottleneck, layers=(4, 44, 87, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
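Usage sketch (not part of the diff): with resnet added to the FEAT_INTER_FILTERS list, the getter-based features_only wrapper covered by the updated tests can be driven against any of the constructors above; the model name and input size below are arbitrary examples.

import torch
import timm

# Wrap a ResNet variant with the FeatureGetterNet-based features_only path.
model = timm.create_model('resnet50d', pretrained=False, features_only=True, feature_cls='getter')
model.eval()
feats = model(torch.randn(1, 3, 224, 224))
# Outputs line up with the reported feature metadata.
for f, c in zip(feats, model.feature_info.channels()):
    assert f.shape[1] == c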

View File

@ -26,6 +26,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
_assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -607,6 +608,72 @@ class SwinTransformer(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, pool_type=global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.layers), indices)
# forward pass
x = self.patch_embed(x)
num_stages = len(self.layers)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.layers
else:
stages = self.layers[:max_index + 1]
for i, stage in enumerate(stages):
x = stage(x)
if i in take_indices:
if norm and i == num_stages - 1:
x_inter = self.norm(x)  # apply final norm to the last intermediate
else:
x_inter = x
x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
intermediates.append(x_inter)
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.layers), indices)
self.layers = self.layers[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.layers(x)
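Minimal usage sketch for the forward_intermediates() method added above (the model name is an arbitrary example, not something this PR pins down):

import torch
import timm

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
# indices=2 -> take the last two stages; intermediates are permuted to NCHW,
# the final feature map stays in the model's internal NHWC layout.
final, intermediates = model.forward_intermediates(x, indices=2)
# list-only form, e.g. for pure feature extraction
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)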

View File

@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# Written by Ze Liu
# --------------------------------------------------------
import math
from typing import Callable, Optional, Tuple, Union, Set, Dict
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -24,6 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
resample_patch_embed, ndgrid, get_act_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -608,6 +609,72 @@ class SwinTransformerV2(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.layers), indices)
# forward pass
x = self.patch_embed(x)
num_stages = len(self.layers)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.layers
else:
stages = self.layers[:max_index + 1]
for i, stage in enumerate(stages):
x = stage(x)
if i in take_indices:
if norm and i == num_stages - 1:
x_inter = self.norm(x)  # apply final norm to the last intermediate
else:
x_inter = x
x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
intermediates.append(x_inter)
if intermediates_only:
return intermediates
x = self.norm(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.layers), indices)
self.layers = self.layers[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
x = self.patch_embed(x)
x = self.layers(x)
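Same addition as SwinTransformer above; the detail worth illustrating is that norm=True only runs the final norm over the last stage (the i == num_stages - 1 check), since earlier stages use narrower embedding widths than self.norm expects. Sketch with an arbitrary model choice:

import torch
import timm

model = timm.create_model('swinv2_tiny_window8_256', pretrained=False).eval()
inter = model.forward_intermediates(
    torch.randn(1, 3, 256, 256), indices=4, norm=True, intermediates_only=True)
# only inter[-1] has passed through model.norm; earlier maps come back as-is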

View File

@ -39,6 +39,7 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model
@ -718,6 +719,62 @@ class SwinTransformerV2Cr(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.patch_embed(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
for i, stage in enumerate(stages):
x = stage(x)
if i in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = self.stages(x)
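Sketch of the pruning helper added here (arbitrary model choice). Two differences from the Swin/SwinV2 versions above: intermediates are appended without a permute, since this variant already carries NCHW feature maps between stages, and prune_norm is accepted for interface consistency but is a no-op.

import torch
import timm

model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)
kept = model.prune_intermediate_layers(indices=(0, 1), prune_head=True)  # keep the first two stages
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=kept, intermediates_only=True)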

View File

@ -409,7 +409,7 @@ class Twins(nn.Module):
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -459,6 +459,22 @@ class Twins(nn.Module):
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
# FIXME add block pruning
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
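The stop_early default flips from True to False here and in the ViT/VOLO/XCiT changes below, presumably so the default eager behaviour matches torchscript, which cannot slice the block sequence. Early exit stays available as an opt-in; a sketch with an arbitrary model:

import torch
import timm

model = timm.create_model('twins_pcpvt_small', pretrained=False).eval()
# eager-only shortcut: stop iterating once the last requested stage is produced
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=(0, 1), stop_early=True, intermediates_only=True)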

View File

@ -638,7 +638,7 @@ class VisionTransformer(nn.Module):
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -655,7 +655,7 @@ class VisionTransformer(nn.Module):
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -666,6 +666,7 @@ class VisionTransformer(nn.Module):
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
@ -698,21 +699,19 @@ class VisionTransformer(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
if self.attn_pool is not None:
self.attn_pool = None
self.fc_norm = nn.Identity()
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def get_intermediate_layers(
@ -1791,12 +1790,27 @@ default_cfgs = {
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_medium_patch16_reg4_256': _cfg(
input_size=(3, 256, 256)),
'vit_wee_patch16_reg1_gap_256': _cfg(
file='',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_little_patch16_reg4_gap_256': _cfg(
file='',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg1_gap_256': _cfg(
file='vit_medium_gap1-in1k-20231118-8.pth',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
file='vit_medium_gap4-in1k-20231115-8.pth',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg1_gap_256': _cfg(
file='vit_betwixt_gap1-in1k-20231121-8.pth',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256': _cfg(
file='vit_betwixt_gap4-in1k-20231106-8.pth',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
'vit_so150m_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
'vit_so150m_patch16_reg4_map_256': _cfg(
@ -2083,6 +2097,18 @@ def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTran
return model
@register_model
def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256
"""
model_args = dict(
patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
@ -2714,21 +2740,54 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
return model
# @register_model
# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
# model_args = dict(
# patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
# no_embed_class=True, reg_tokens=4,
# )
# model = _create_vision_transformer(
# 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
# return model
@register_model
def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
no_embed_class=True, reg_tokens=4,
patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
)
model = _create_vision_transformer(
'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
)
model = _create_vision_transformer(
'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
)
model = _create_vision_transformer(
'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8,
patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
)
model = _create_vision_transformer(
@ -2736,6 +2795,28 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
return model
@register_model
def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
)
model = _create_vision_transformer(
'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
)
model = _create_vision_transformer(
'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
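Illustration of the new register-token GAP ViT definitions together with the NCHW/NLC switch in forward_intermediates. Weights are only being prepped in this PR (the file= entries above), so pretrained=False here:

import torch
import timm

model = timm.create_model('vit_betwixt_patch16_reg4_gap_256', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)
# token-format (NLC) intermediates from the last three blocks, with final norm applied
final, inter = model.forward_intermediates(x, indices=3, output_fmt='NLC', norm=True)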

View File

@ -394,7 +394,7 @@ class VisionTransformerRelPos(nn.Module):
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -411,7 +411,7 @@ class VisionTransformerRelPos(nn.Module):
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -455,19 +455,19 @@ class VisionTransformerRelPos(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.fc_norm = nn.Identity()
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
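The prune_intermediate_layers signature across the ViT family now takes indices (previously n) and removes the head via reset_classifier(0, '') rather than swapping modules to nn.Identity in place; callers using the old keyword need the rename. Sketch with an arbitrary model:

import timm

model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False)
# blocks beyond the last requested intermediate are truncated; the head is reset
kept = model.prune_intermediate_layers(indices=2, prune_norm=False, prune_head=True)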

View File

@ -545,7 +545,7 @@ class VisionTransformerSAM(nn.Module):
x: torch.Tensor,
indices: Union[int, List[int], Tuple[int]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -573,6 +573,7 @@ class VisionTransformerSAM(nn.Module):
x = self.pos_drop(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
blocks = self.blocks
else:
@ -597,19 +598,19 @@ class VisionTransformerSAM(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = None,
indices: Union[int, List[int], Tuple[int]] = None,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
# neck is being treated as equivalent to final norm here
self.neck = nn.Identity()
if prune_head:
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
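For the SAM ViT, prune_norm removes the neck, which this code treats as the final-norm equivalent. Sketch (construction only, since samvit defaults to 1024x1024 inputs):

import timm
import torch.nn as nn

model = timm.create_model('samvit_base_patch16', pretrained=False)
kept = model.prune_intermediate_layers(indices=1, prune_norm=True, prune_head=True)
assert isinstance(model.neck, nn.Identity)  # neck pruned in place of a final norm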

View File

@ -713,7 +713,7 @@ class VOLO(nn.Module):
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -751,8 +751,11 @@ class VOLO(nn.Module):
x = self.pos_drop(x)
x = block(x)
if idx in take_indices:
# normalize intermediates with final norm layer if enabled
intermediates.append(x.permute(0, 3, 1, 2))
if norm and idx >= 2:
x_inter = self.norm(x)
else:
x_inter = x
intermediates.append(x_inter.permute(0, 3, 1, 2))
if intermediates_only:
return intermediates
@ -769,20 +772,20 @@ class VOLO(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
max_index = self.stage_ends[max_index]
self.network = self.network[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.post_network = nn.ModuleList() # prune token blocks with head
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
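In VOLO, norm=True now applies the final norm to intermediates, but only from network index 2 onward (the idx >= 2 check above), presumably because the earlier outlooker stage and downsample run at a width the final norm does not match. Sketch with an arbitrary model:

import torch
import timm

model = timm.create_model('volo_d1_224', pretrained=False).eval()
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224), indices=2, norm=True, intermediates_only=True)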

View File

@ -444,7 +444,7 @@ class Xcit(nn.Module):
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = True,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -460,7 +460,7 @@ class Xcit(nn.Module):
Returns:
"""
assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
reshape = output_fmt == 'NCHW'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -468,7 +468,6 @@ class Xcit(nn.Module):
# forward pass
B, _, height, width = x.shape
x, (Hp, Wp) = self.patch_embed(x)
if self.pos_embed is not None:
# `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
@ -503,19 +502,19 @@ class Xcit(nn.Module):
def prune_intermediate_layers(
self,
n: Union[int, List[int], Tuple[int]] = 1,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.blocks), n)
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
self.blocks = self.blocks[:max_index + 1] # truncate blocks
if prune_norm:
self.norm = nn.Identity()
if prune_head:
self.cls_attn_blocks = nn.ModuleList() # prune token blocks with head
self.head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x):
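XCiT's prune_head clears the class-attention blocks together with the classifier, mirroring the VOLO change above (those blocks only feed the head). Sketch with an arbitrary model:

import timm

model = timm.create_model('xcit_tiny_12_p16_224', pretrained=False)
kept = model.prune_intermediate_layers(indices=4, prune_norm=True, prune_head=True)
assert len(model.cls_attn_blocks) == 0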