diff --git a/tests/test_models.py b/tests/test_models.py index 34bf0af4..9ff64c3b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -51,7 +51,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): FEAT_INTER_FILTERS = [ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', - 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet' + 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index f4cc8c07..de077797 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -47,7 +47,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvNormAct -from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace +from .space_to_depth import SpaceToDepth, DepthToSpace from .split_attn import SplitAttn from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame diff --git a/timm/layers/activations_jit.py b/timm/layers/activations_jit.py deleted file mode 100644 index b4a51653..00000000 --- a/timm/layers/activations_jit.py +++ /dev/null @@ -1,90 +0,0 @@ -""" Activations - -A collection of jit-scripted activations fn and modules with a common interface so that they can -easily be swapped. All have an `inplace` arg even if not used. - -All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not -currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted -versions if they contain in-place ops. - -Hacked together by / Copyright 2020 Ross Wightman -""" - -import torch -from torch import nn as nn -from torch.nn import functional as F - - -@torch.jit.script -def swish_jit(x, inplace: bool = False): - """Swish - Described in: https://arxiv.org/abs/1710.05941 - """ - return x.mul(x.sigmoid()) - - -@torch.jit.script -def mish_jit(x, _inplace: bool = False): - """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 - """ - return x.mul(F.softplus(x).tanh()) - - -class SwishJit(nn.Module): - def __init__(self, inplace: bool = False): - super(SwishJit, self).__init__() - - def forward(self, x): - return swish_jit(x) - - -class MishJit(nn.Module): - def __init__(self, inplace: bool = False): - super(MishJit, self).__init__() - - def forward(self, x): - return mish_jit(x) - - -@torch.jit.script -def hard_sigmoid_jit(x, inplace: bool = False): - # return F.relu6(x + 3.) / 6. - return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? - - -class HardSigmoidJit(nn.Module): - def __init__(self, inplace: bool = False): - super(HardSigmoidJit, self).__init__() - - def forward(self, x): - return hard_sigmoid_jit(x) - - -@torch.jit.script -def hard_swish_jit(x, inplace: bool = False): - # return x * (F.relu6(x + 3.) / 6) - return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? - - -class HardSwishJit(nn.Module): - def __init__(self, inplace: bool = False): - super(HardSwishJit, self).__init__() - - def forward(self, x): - return hard_swish_jit(x) - - -@torch.jit.script -def hard_mish_jit(x, inplace: bool = False): - """ Hard Mish - Experimental, based on notes by Mish author Diganta Misra at - https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md - """ - return 0.5 * x * (x + 2).clamp(min=0, max=2) - - -class HardMishJit(nn.Module): - def __init__(self, inplace: bool = False): - super(HardMishJit, self).__init__() - - def forward(self, x): - return hard_mish_jit(x) diff --git a/timm/layers/activations_me.py b/timm/layers/activations_me.py index 9a12bb7e..b0ddd5cb 100644 --- a/timm/layers/activations_me.py +++ b/timm/layers/activations_me.py @@ -3,8 +3,8 @@ A collection of activations fn and modules with a common interface so that they can easily be swapped. All have an `inplace` arg even if not used. -These activations are not compatible with jit scripting or ONNX export of the model, please use either -the JIT or basic versions of the activations. +These activations are not compatible with jit scripting or ONNX export of the model, please use +basic versions of the activations. Hacked together by / Copyright 2020 Ross Wightman """ @@ -14,19 +14,17 @@ from torch import nn as nn from torch.nn import functional as F -@torch.jit.script -def swish_jit_fwd(x): +def swish_fwd(x): return x.mul(torch.sigmoid(x)) -@torch.jit.script -def swish_jit_bwd(x, grad_output): +def swish_bwd(x, grad_output): x_sigmoid = torch.sigmoid(x) return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) -class SwishJitAutoFn(torch.autograd.Function): - """ torch.jit.script optimised Swish w/ memory-efficient checkpoint +class SwishAutoFn(torch.autograd.Function): + """ optimised Swish w/ memory-efficient checkpoint Inspired by conversation btw Jeremy Howard & Adam Pazske https://twitter.com/jeremyphoward/status/1188251041835315200 """ @@ -37,16 +35,16 @@ class SwishJitAutoFn(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) - return swish_jit_fwd(x) + return swish_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] - return swish_jit_bwd(x, grad_output) + return swish_bwd(x, grad_output) def swish_me(x, inplace=False): - return SwishJitAutoFn.apply(x) + return SwishAutoFn.apply(x) class SwishMe(nn.Module): @@ -54,38 +52,36 @@ class SwishMe(nn.Module): super(SwishMe, self).__init__() def forward(self, x): - return SwishJitAutoFn.apply(x) + return SwishAutoFn.apply(x) -@torch.jit.script -def mish_jit_fwd(x): +def mish_fwd(x): return x.mul(torch.tanh(F.softplus(x))) -@torch.jit.script -def mish_jit_bwd(x, grad_output): +def mish_bwd(x, grad_output): x_sigmoid = torch.sigmoid(x) x_tanh_sp = F.softplus(x).tanh() return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) -class MishJitAutoFn(torch.autograd.Function): +class MishAutoFn(torch.autograd.Function): """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 - A memory efficient, jit scripted variant of Mish + A memory efficient variant of Mish """ @staticmethod def forward(ctx, x): ctx.save_for_backward(x) - return mish_jit_fwd(x) + return mish_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] - return mish_jit_bwd(x, grad_output) + return mish_bwd(x, grad_output) def mish_me(x, inplace=False): - return MishJitAutoFn.apply(x) + return MishAutoFn.apply(x) class MishMe(nn.Module): @@ -93,34 +89,32 @@ class MishMe(nn.Module): super(MishMe, self).__init__() def forward(self, x): - return MishJitAutoFn.apply(x) + return MishAutoFn.apply(x) -@torch.jit.script -def hard_sigmoid_jit_fwd(x, inplace: bool = False): +def hard_sigmoid_fwd(x, inplace: bool = False): return (x + 3).clamp(min=0, max=6).div(6.) -@torch.jit.script -def hard_sigmoid_jit_bwd(x, grad_output): +def hard_sigmoid_bwd(x, grad_output): m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. return grad_output * m -class HardSigmoidJitAutoFn(torch.autograd.Function): +class HardSigmoidAutoFn(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) - return hard_sigmoid_jit_fwd(x) + return hard_sigmoid_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] - return hard_sigmoid_jit_bwd(x, grad_output) + return hard_sigmoid_bwd(x, grad_output) def hard_sigmoid_me(x, inplace: bool = False): - return HardSigmoidJitAutoFn.apply(x) + return HardSigmoidAutoFn.apply(x) class HardSigmoidMe(nn.Module): @@ -128,32 +122,30 @@ class HardSigmoidMe(nn.Module): super(HardSigmoidMe, self).__init__() def forward(self, x): - return HardSigmoidJitAutoFn.apply(x) + return HardSigmoidAutoFn.apply(x) -@torch.jit.script -def hard_swish_jit_fwd(x): +def hard_swish_fwd(x): return x * (x + 3).clamp(min=0, max=6).div(6.) -@torch.jit.script -def hard_swish_jit_bwd(x, grad_output): +def hard_swish_bwd(x, grad_output): m = torch.ones_like(x) * (x >= 3.) m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) return grad_output * m -class HardSwishJitAutoFn(torch.autograd.Function): - """A memory efficient, jit-scripted HardSwish activation""" +class HardSwishAutoFn(torch.autograd.Function): + """A memory efficient HardSwish activation""" @staticmethod def forward(ctx, x): ctx.save_for_backward(x) - return hard_swish_jit_fwd(x) + return hard_swish_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] - return hard_swish_jit_bwd(x, grad_output) + return hard_swish_bwd(x, grad_output) @staticmethod def symbolic(g, self): @@ -164,7 +156,7 @@ class HardSwishJitAutoFn(torch.autograd.Function): def hard_swish_me(x, inplace=False): - return HardSwishJitAutoFn.apply(x) + return HardSwishAutoFn.apply(x) class HardSwishMe(nn.Module): @@ -172,39 +164,37 @@ class HardSwishMe(nn.Module): super(HardSwishMe, self).__init__() def forward(self, x): - return HardSwishJitAutoFn.apply(x) + return HardSwishAutoFn.apply(x) -@torch.jit.script -def hard_mish_jit_fwd(x): +def hard_mish_fwd(x): return 0.5 * x * (x + 2).clamp(min=0, max=2) -@torch.jit.script -def hard_mish_jit_bwd(x, grad_output): +def hard_mish_bwd(x, grad_output): m = torch.ones_like(x) * (x >= -2.) m = torch.where((x >= -2.) & (x <= 0.), x + 1., m) return grad_output * m -class HardMishJitAutoFn(torch.autograd.Function): - """ A memory efficient, jit scripted variant of Hard Mish +class HardMishAutoFn(torch.autograd.Function): + """ A memory efficient variant of Hard Mish Experimental, based on notes by Mish author Diganta Misra at https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md """ @staticmethod def forward(ctx, x): ctx.save_for_backward(x) - return hard_mish_jit_fwd(x) + return hard_mish_fwd(x) @staticmethod def backward(ctx, grad_output): x = ctx.saved_tensors[0] - return hard_mish_jit_bwd(x, grad_output) + return hard_mish_bwd(x, grad_output) def hard_mish_me(x, inplace: bool = False): - return HardMishJitAutoFn.apply(x) + return HardMishAutoFn.apply(x) class HardMishMe(nn.Module): @@ -212,7 +202,7 @@ class HardMishMe(nn.Module): super(HardMishMe, self).__init__() def forward(self, x): - return HardMishJitAutoFn.apply(x) + return HardMishAutoFn.apply(x) diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py index 93bcbf0e..6bbbc14b 100644 --- a/timm/layers/create_act.py +++ b/timm/layers/create_act.py @@ -4,9 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman from typing import Union, Callable, Type from .activations import * -from .activations_jit import * from .activations_me import * -from .config import is_exportable, is_scriptable, is_no_jit +from .config import is_exportable, is_scriptable # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. @@ -37,15 +36,6 @@ _ACT_FN_DEFAULT = dict( hard_mish=hard_mish, ) -_ACT_FN_JIT = dict( - silu=F.silu if _has_silu else swish_jit, - swish=F.silu if _has_silu else swish_jit, - mish=F.mish if _has_mish else mish_jit, - hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, - hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, - hard_mish=hard_mish_jit, -) - _ACT_FN_ME = dict( silu=F.silu if _has_silu else swish_me, swish=F.silu if _has_silu else swish_me, @@ -55,7 +45,7 @@ _ACT_FN_ME = dict( hard_mish=hard_mish_me, ) -_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT) for a in _ACT_FNS: a.setdefault('hardsigmoid', a.get('hard_sigmoid')) a.setdefault('hardswish', a.get('hard_swish')) @@ -83,15 +73,6 @@ _ACT_LAYER_DEFAULT = dict( identity=nn.Identity, ) -_ACT_LAYER_JIT = dict( - silu=nn.SiLU if _has_silu else SwishJit, - swish=nn.SiLU if _has_silu else SwishJit, - mish=nn.Mish if _has_mish else MishJit, - hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, - hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, - hard_mish=HardMishJit, -) - _ACT_LAYER_ME = dict( silu=nn.SiLU if _has_silu else SwishMe, swish=nn.SiLU if _has_silu else SwishMe, @@ -101,7 +82,7 @@ _ACT_LAYER_ME = dict( hard_mish=HardMishMe, ) -_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT) for a in _ACT_LAYERS: a.setdefault('hardsigmoid', a.get('hard_sigmoid')) a.setdefault('hardswish', a.get('hard_swish')) @@ -116,14 +97,11 @@ def get_act_fn(name: Union[Callable, str] = 'relu'): return None if isinstance(name, Callable): return name - if not (is_no_jit() or is_exportable() or is_scriptable()): + if not (is_exportable() or is_scriptable()): # If not exporting or scripting the model, first look for a memory-efficient version with # custom autograd, then fallback if name in _ACT_FN_ME: return _ACT_FN_ME[name] - if not (is_no_jit() or is_exportable()): - if name in _ACT_FN_JIT: - return _ACT_FN_JIT[name] return _ACT_FN_DEFAULT[name] @@ -139,12 +117,9 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): return name if not name: return None - if not (is_no_jit() or is_exportable() or is_scriptable()): + if not (is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] - if not (is_no_jit() or is_exportable()): - if name in _ACT_LAYER_JIT: - return _ACT_LAYER_JIT[name] return _ACT_LAYER_DEFAULT[name] diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 3f828c6d..cd7d5062 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -73,7 +73,6 @@ class TransformerDecoderLayerOptimal(nn.Module): return tgt -# @torch.jit.script # class ExtrapClasses(object): # def __init__(self, num_queries: int, group_size: int): # self.num_queries = num_queries @@ -88,24 +87,13 @@ class TransformerDecoderLayerOptimal(nn.Module): # out = out.view((h.shape[0], self.group_size * self.num_queries)) # return out -@torch.jit.script -class GroupFC(object): - def __init__(self, embed_len_decoder: int): - self.embed_len_decoder = embed_len_decoder - - def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): - for i in range(self.embed_len_decoder): - h_i = h[:, i, :] - w_i = duplicate_pooling[i, :, :] - out_extrap[:, i, :] = torch.matmul(h_i, w_i) - - class MLDecoder(nn.Module): def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): super(MLDecoder, self).__init__() embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups if embed_len_decoder > num_classes: embed_len_decoder = num_classes + self.embed_len_decoder = embed_len_decoder # switching to 768 initial embeddings decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding @@ -131,7 +119,6 @@ class MLDecoder(nn.Module): self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) torch.nn.init.xavier_normal_(self.duplicate_pooling) torch.nn.init.constant_(self.duplicate_pooling_bias, 0) - self.group_fc = GroupFC(embed_len_decoder) def forward(self, x): if len(x.shape) == 4: # [bs,2048, 7,7] @@ -149,7 +136,10 @@ class MLDecoder(nn.Module): h = h.transpose(0, 1) out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) - self.group_fc(h, self.duplicate_pooling, out_extrap) + for i in range(self.embed_len_decoder): # group FC + h_i = h[:, i, :] + w_i = self.duplicate_pooling[i, :, :] + out_extrap[:, i, :] = torch.matmul(h_i, w_i) h_out = out_extrap.flatten(1)[:, :self.num_classes] h_out += self.duplicate_pooling_bias logits = h_out diff --git a/timm/layers/norm.py b/timm/layers/norm.py index 504060c7..4b81dcef 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -82,7 +82,6 @@ def _is_contiguous(tensor: torch.Tensor) -> bool: return tensor.is_contiguous(memory_format=torch.contiguous_format) -@torch.jit.script def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) x = (x - u) * torch.rsqrt(s + eps) diff --git a/timm/layers/space_to_depth.py b/timm/layers/space_to_depth.py index 5867456c..45268154 100644 --- a/timm/layers/space_to_depth.py +++ b/timm/layers/space_to_depth.py @@ -18,29 +18,6 @@ class SpaceToDepth(nn.Module): return x -@torch.jit.script -class SpaceToDepthJit: - def __call__(self, x: torch.Tensor): - # assuming hard-coded that block_size==4 for acceleration - N, C, H, W = x.size() - x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) - x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) - x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) - return x - - -class SpaceToDepthModule(nn.Module): - def __init__(self, no_jit=False): - super().__init__() - if not no_jit: - self.op = SpaceToDepthJit() - else: - self.op = SpaceToDepth() - - def forward(self, x): - return self.op(x) - - class DepthToSpace(nn.Module): def __init__(self, block_size): diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 9d09efac..e558c1a6 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -80,7 +80,7 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai set_pretrained_download_progress, set_pretrained_check_hash from ._factory import create_model, parse_model_name, safe_model_name from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet -from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ +from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \ register_notrace_module, is_notrace_module, get_notrace_modules, \ register_notrace_function, is_notrace_function, get_notrace_functions from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint diff --git a/timm/models/_features.py b/timm/models/_features.py index 5bd6f1fb..12f0ab37 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -158,7 +158,7 @@ class FeatureHooks: def __init__( self, - hooks: Sequence[str], + hooks: Sequence[Union[str, Dict]], named_modules: dict, out_map: Sequence[Union[int, str]] = None, default_hook_type: str = 'forward', @@ -168,11 +168,13 @@ class FeatureHooks: self._handles = [] modules = {k: v for k, v in named_modules} for i, h in enumerate(hooks): - hook_name = h['module'] + hook_name = h if isinstance(h, str) else h['module'] m = modules[hook_name] hook_id = out_map[i] if out_map else hook_name hook_fn = partial(self._collect_output_hook, hook_id) - hook_type = h.get('hook_type', default_hook_type) + hook_type = default_hook_type + if isinstance(h, dict): + hook_type = h.get('hook_type', default_hook_type) if hook_type == 'forward_pre': handle = m.register_forward_pre_hook(hook_fn) elif hook_type == 'forward': diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 3b5891e6..b775871c 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -9,7 +9,9 @@ from torch import nn from ._features import _get_feature_info, _get_return_layers try: + # NOTE we wrap torchvision fns to use timm leaf / no trace definitions from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor + from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names has_fx_feature_extraction = True except ImportError: has_fx_feature_extraction = False @@ -30,7 +32,7 @@ from timm.layers.norm_act import ( __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules', 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions', - 'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet'] + 'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet'] # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here @@ -92,6 +94,13 @@ def get_notrace_functions(): return list(_autowrap_functions) +def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]: + return _get_graph_node_names( + model, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} + ) + + def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' return _create_feature_extractor( diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index a2ff0095..a2b44e1a 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -948,25 +949,37 @@ class Stem(nn.Sequential): stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act prev_chs = in_chs curr_stride = 1 + last_feat_idx = -1 for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)): layer_fn = layers.conv_norm_act if na else create_conv2d conv_name = f'conv{i + 1}' if i > 0 and s > 1: - self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + last_feat_idx = i - 1 + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0)) self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s)) prev_chs = ch curr_stride *= s prev_feat = conv_name if pool and 'max' in pool.lower(): - self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + last_feat_idx = i + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0)) self.add_module('pool', nn.MaxPool2d(3, 2, 1)) curr_stride *= 2 prev_feat = 'pool' - self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0)) assert curr_stride == stride + def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + intermediate: Optional[torch.Tensor] = None + for i, m in enumerate(self): + x = m(x) + if self.last_feat_idx is not None and i == self.last_feat_idx: + intermediate = x + return x, intermediate + def create_byob_stem( in_chs: int, @@ -1008,7 +1021,7 @@ def create_byob_stem( if isinstance(stem, Stem): feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info] else: - feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)] + feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)] return stem, feature_info @@ -1122,7 +1135,7 @@ def create_byob_stages( feat_size = reduce_feat_size(feat_size, stride) stages += [nn.Sequential(*blocks)] - prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1) feature_info.append(prev_feat) return nn.Sequential(*stages), feature_info @@ -1198,6 +1211,7 @@ class ByobNet(nn.Module): feat_size=feat_size, ) self.feature_info.extend(stage_feat[:-1]) + reduction = stage_feat[-1]['reduction'] prev_chs = stage_feat[-1]['num_chs'] if cfg.num_features: @@ -1207,7 +1221,8 @@ class ByobNet(nn.Module): self.num_features = prev_chs self.final_conv = nn.Identity() self.feature_info += [ - dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')] + dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))] + self.stage_ends = [f['stage'] for f in self.feature_info] self.head = ClassifierHead( self.num_features, @@ -1241,6 +1256,83 @@ class ByobNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + exclude_final_conv: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + exclude_final_conv: Exclude final_conv from last intermediate + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + # forward pass + feat_idx = 0 # stem is index 0 + if hasattr(self.stem, 'forward_intermediates'): + # returns last intermediate features in stem (before final stride in stride > 2 stems) + x, x_inter = self.stem.forward_intermediates(x) + else: + x, x_inter = self.stem(x), None + if feat_idx in take_indices: + intermediates.append(x if x_inter is None else x_inter) + last_idx = self.stage_ends[-1] + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + for stage in stages: + feat_idx += 1 + x = stage(x) + if not exclude_final_conv and feat_idx == last_idx: + # default feature_info for this model uses final_conv as the last feature output (if present) + x = self.final_conv(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if exclude_final_conv and feat_idx == last_idx: + x = self.final_conv(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if max_index < self.stage_ends[-1]: + self.final_conv = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 76bb8136..f10f6c7b 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -411,7 +411,7 @@ class ConvNeXt(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/davit.py b/timm/models/davit.py index d4d6ad69..442ca620 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -126,9 +126,9 @@ class ChannelAttention(nn.Module): q, k, v = qkv.unbind(0) k = k * self.scale - attention = k.transpose(-1, -2) @ v - attention = attention.softmax(dim=-1) - x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + attn = k.transpose(-1, -2) @ v + attn = attn.softmax(dim=-1) + x = (attn @ q.transpose(-1, -2)).transpose(-1, -2) x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index fb6ff8ec..798f6435 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -12,7 +12,7 @@ Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientF Modifications and timm support by / Copyright 2022, Ross Wightman """ -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -463,7 +463,7 @@ class EfficientFormer(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index fb04fb2c..44a77506 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -162,7 +162,7 @@ class EfficientNet(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -183,8 +183,6 @@ class EfficientNet(nn.Module): """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' - if stop_early: - assert intermediates_only, 'Must use intermediates_only for early stopping.' intermediates = [] if extra_blocks: take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) @@ -212,8 +210,9 @@ class EfficientNet(nn.Module): if intermediates_only: return intermediates - x = self.conv_head(x) - x = self.bn2(x) + if feat_idx == self.stage_ends[-1]: + x = self.conv_head(x) + x = self.bn2(x) return x, intermediates diff --git a/timm/models/eva.py b/timm/models/eva.py index e2eeed60..d7763fe1 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -717,12 +717,6 @@ def checkpoint_filter_fn( # fixed embedding no need to load buffer from checkpoint continue - # FIXME here while import new weights, to remove - # if k == 'cls_token': - # print('DEBUG: cls token -> reg') - # k = 'reg_token' - # #v = v + state_dict['pos_embed'][0, :] - if 'patch_embed.proj.weight' in k: _, _, H, W = model.patch_embed.proj.weight.shape if v.shape[-1] != W or v.shape[-2] != H: @@ -951,26 +945,22 @@ default_cfgs = generate_default_cfgs({ num_classes=0, ), - 'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg( - #hf_hub_id='timm/', - #file='vit_medium_gap1_rope-in1k-20230920-5.pth', + 'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) ), - 'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg( - #hf_hub_id='timm/', - #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth', + 'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) ), - 'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg( - #hf_hub_id='timm/', - #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth', + 'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95, ), - 'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg( - #hf_hub_id='timm/', - #file='vit_base_gap1_rope-in1k-20230930-5.pth', + 'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) ), diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 50324597..705ebd25 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -37,7 +37,6 @@ from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from timm.layers.selective_kernel import SelectiveKernel from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct -from timm.layers.space_to_depth import SpaceToDepthModule from timm.layers.split_attn import SplitAttn from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame diff --git a/timm/models/levit.py b/timm/models/levit.py index 023f131b..037cae6e 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -25,7 +25,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Copyright 2020 Ross Wightman, Apache-2.0 License from collections import OrderedDict from functools import partial -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -638,7 +638,7 @@ class Levit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 86eed72a..0be7b9b3 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1255,7 +1255,7 @@ class MaxxVit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index f7d64349..b775b736 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -40,6 +40,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ import math from functools import partial +from typing import List, Optional, Union, Tuple import torch import torch.nn as nn @@ -47,6 +48,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -211,6 +213,7 @@ class MlpMixer(nn.Module): embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None, ) + reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size # FIXME drop_path (stochastic depth scaling rule or all the same?) self.blocks = nn.Sequential(*[ block_layer( @@ -224,6 +227,8 @@ class MlpMixer(nn.Module): drop_path=drop_path_rate, ) for _ in range(num_blocks)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)] self.norm = norm_layer(embed_dim) self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() @@ -257,6 +262,76 @@ class MlpMixer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.stem(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if reshape: + # reshape to BCHW output format + H, W = self.stem.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model): def _create_mixer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for MLP-Mixer models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( MlpMixer, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 3976f0db..a9e3a1a8 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -154,7 +154,7 @@ class MobileNetV3(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index c8afe470..5ad013e4 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -837,7 +837,7 @@ class MultiScaleVit(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/regnet.py b/timm/models/regnet.py index bc73f540..12187378 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -26,7 +26,7 @@ Hacked together by / Copyright 2020 Ross Wightman import math from dataclasses import dataclass, replace from functools import partial -from typing import Optional, Union, Callable +from typing import Callable, List, Optional, Union, Tuple import numpy as np import torch @@ -36,6 +36,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq, named_apply from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -515,6 +516,73 @@ class RegNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.stem(x) + if feat_idx in take_indices: + intermediates.append(x) + + layer_names = ('s1', 's2', 's3', 's4') + if stop_early: + layer_names = layer_names[:max_index] + for n in layer_names: + feat_idx += 1 + x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == 4: + x = self.final_conv(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + layer_names = ('s1', 's2', 's3', 's4') + layer_names = layer_names[max_index:] + for n in layer_names: + setattr(self, n, nn.Identity()) + if max_index < 4: + self.final_conv = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.s1(x) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index a2e303ca..53dfab9c 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -557,7 +557,7 @@ class ResNet(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -576,8 +576,6 @@ class ResNet(nn.Module): """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' - if stop_early: - assert intermediates_only, 'Must use intermediates_only for early stopping.' intermediates = [] take_indices, max_index = feature_take_indices(5, indices) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 5c0f7b4f..6614e4ad 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -611,7 +611,7 @@ class SwinTransformer(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index a6ebb664..6bf2d767 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -612,7 +612,7 @@ class SwinTransformerV2(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 58cfcd36..4a33c803 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -722,7 +722,7 @@ class SwinTransformerV2Cr(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/twins.py b/timm/models/twins.py index b87a9c79..8e898f9f 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -407,7 +407,7 @@ class Twins(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index fb848b07..5fbabb59 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -489,7 +489,7 @@ class VisionTransformer(nn.Module): **embed_args, ) num_patches = self.patch_embed.num_patches - r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size + reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None @@ -523,7 +523,7 @@ class VisionTransformer(nn.Module): ) for i in range(depth)]) self.feature_info = [ - dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)] self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head @@ -1790,23 +1790,39 @@ default_cfgs = { license='mit', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), - 'vit_wee_patch16_reg1_gap_256': _cfg( - #file='', + 'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg( input_size=(3, 256, 256), crop_pct=0.95), - 'vit_little_patch16_reg4_gap_256': _cfg( - #file='', + 'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_medium_patch16_reg1_gap_256': _cfg( - #file='vit_medium_gap1-in1k-20231118-8.pth', + 'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_medium_patch16_reg4_gap_256': _cfg( - #file='vit_medium_gap4-in1k-20231115-8.pth', + 'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_betwixt_patch16_reg1_gap_256': _cfg( - #file='vit_betwixt_gap1-in1k-20231121-8.pth', + 'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_betwixt_patch16_reg4_gap_256': _cfg( - #file='vit_betwixt_gap4-in1k-20231106-8.pth', + 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), 'vit_base_patch16_reg4_gap_256': _cfg( input_size=(3, 256, 256)), @@ -2755,10 +2771,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5, class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock, ) model = _create_vision_transformer( - 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2769,7 +2796,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', ) model = _create_vision_transformer( - 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2795,6 +2822,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio return model +@register_model +def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 7bf6363f..fcab4252 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -543,7 +543,7 @@ class VisionTransformerSAM(nn.Module): def forward_intermediates( self, x: torch.Tensor, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', @@ -598,7 +598,7 @@ class VisionTransformerSAM(nn.Module): def prune_intermediate_layers( self, - indices: Union[int, List[int], Tuple[int]] = None, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, prune_norm: bool = False, prune_head: bool = True, ): diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 4c6a00ca..9093b75a 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -1,4 +1,5 @@ from .agc import adaptive_clip_grad +from .attention_extract import AttentionExtract from .checkpoint_saver import CheckpointSaver from .clip_grad import dispatch_clip_grad from .cuda import ApexScaler, NativeScaler diff --git a/timm/utils/attention_extract.py b/timm/utils/attention_extract.py new file mode 100644 index 00000000..90021018 --- /dev/null +++ b/timm/utils/attention_extract.py @@ -0,0 +1,79 @@ +import fnmatch +from collections import OrderedDict +from typing import Union, Optional, List + +import torch + + +class AttentionExtract(torch.nn.Module): + # defaults should cover a significant number of timm models with attention maps. + default_node_names = ['*attn.softmax'] + default_module_names = ['*attn_drop'] + + def __init__( + self, + model: Union[torch.nn.Module], + names: Optional[List[str]] = None, + mode: str = 'eval', + method: str = 'fx', + hook_type: str = 'forward', + ): + """ Extract attention maps (or other activations) from a model by name. + + Args: + model: Instantiated model to extract from. + names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks. + mode: 'train' or 'eval' model mode. + method: 'fx' or 'hook' extraction method. + hook_type: 'forward' or 'forward_pre' hooks used. + """ + super().__init__() + assert mode in ('train', 'eval') + if mode == 'train': + model = model.train() + else: + model = model.eval() + + assert method in ('fx', 'hook') + if method == 'fx': + # names are activation node names + from timm.models._features_fx import get_graph_node_names, GraphExtractNet + + node_names = get_graph_node_names(model)[0 if mode == 'train' else 1] + matched = [] + names = names or self.default_node_names + for n in names: + matched.extend(fnmatch.filter(node_names, n)) + if not matched: + raise RuntimeError(f'No node names found matching {names}.') + + self.model = GraphExtractNet(model, matched) + self.hooks = None + else: + # names are module names + assert hook_type in ('forward', 'forward_pre') + from timm.models._features import FeatureHooks + + module_names = [n for n, m in model.named_modules()] + matched = [] + names = names or self.default_module_names + for n in names: + matched.extend(fnmatch.filter(module_names, n)) + if not matched: + raise RuntimeError(f'No module names found matching {names}.') + + self.model = model + self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type) + + self.names = matched + self.mode = mode + self.method = method + + def forward(self, x): + if self.hooks is not None: + self.model(x) + output = self.hooks.get_output(device=x.device) + else: + output = self.model(x) + output = OrderedDict(zip(self.names, output)) + return output diff --git a/timm/version.py b/timm/version.py index 899e700f..c6092d3e 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '1.0.0.dev0' +__version__ = '1.0.1.dev0'