diff --git a/README.md b/README.md index 425e02c2..30be3cb1 100644 --- a/README.md +++ b/README.md @@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md) The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome. -[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. +[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. [timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. diff --git a/tests/test_models.py b/tests/test_models.py index 35585a88..a5d41dfd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -54,6 +54,9 @@ FEAT_INTER_FILTERS = [ 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt', + 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', + 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', + 'davit', 'rdnet', 'convnext', 'pit' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. @@ -508,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size): spatial_axis = get_spatial_dim(output_fmt) import math + inpt = torch.randn((batch_size, *input_size)) output, intermediates = model.forward_intermediates( - torch.randn((batch_size, *input_size)), + inpt, output_fmt=output_fmt, ) assert len(expected_channels) == len(intermediates) @@ -521,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size): assert o.shape[0] == batch_size assert not torch.isnan(o).any() + output2 = model.forward_features(inpt) + assert torch.allclose(output, output2) + def _create_fx_model(model, train=False): # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 021d50be..43488eb2 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -144,6 +144,7 @@ def create_dataset( use_train = split in _TRAIN_SYNONYM ds = QMNIST(train=use_train, **torch_kwargs) elif name == 'imagenet': + torch_kwargs.pop('download') assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.' if split in _EVAL_SYNONYM: split = 'val' diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 47f2bf87..e2eb48d3 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -452,29 +452,29 @@ class ConvNeXt(nn.Module): """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] - take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) + take_indices, max_index = feature_take_indices(len(self.stages), indices) # forward pass - feat_idx = 0 # stem is index 0 x = self.stem(x) - if feat_idx in take_indices: - intermediates.append(x) + last_idx = len(self.stages) - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.stages else: - stages = self.stages[:max_index] - for stage in stages: - feat_idx += 1 + stages = self.stages[:max_index + 1] + for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled - intermediates.append(x) + if norm and feat_idx == last_idx: + intermediates.append(self.norm_pre(x)) + else: + intermediates.append(x) if intermediates_only: return intermediates - x = self.norm_pre(x) + if feat_idx == last_idx: + x = self.norm_pre(x) return x, intermediates @@ -486,8 +486,8 @@ class ConvNeXt(nn.Module): ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) - self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm_pre = nn.Identity() if prune_head: diff --git a/timm/models/davit.py b/timm/models/davit.py index 65009888..f538ecca 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -12,7 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # All rights reserved. # This source code is licensed under the MIT license from functools import partial -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -23,6 +23,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -636,6 +637,72 @@ class DaVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm_pre(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm_pre(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm_pre = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index d768b1dc..e21be971 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -9,7 +9,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman """ import math from functools import partial -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \ NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import named_apply, checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -418,6 +419,72 @@ class EdgeNeXt(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm_pre(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm_pre(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm_pre = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index dcf64995..5bdc473f 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2023, Ross Wightman """ import math from functools import partial -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -625,6 +626,73 @@ class EfficientFormerV2(nn.Module): def set_distilled_training(self, enable=True): self.distilled_training = enable + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 34be806b..27872310 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -7,7 +7,7 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit """ __all__ = ['EfficientVit', 'EfficientVitLarge'] -from typing import List, Optional +from typing import List, Optional, Tuple, Union from functools import partial import torch @@ -17,6 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -754,6 +755,63 @@ class EfficientVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -851,6 +909,63 @@ class EfficientVitLarge(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 7e5c09a4..91caaa5a 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -9,7 +9,7 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic __all__ = ['EfficientVitMsra'] import itertools from collections import OrderedDict -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -17,6 +17,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -475,6 +476,63 @@ class EfficientVitMsra(nn.Module): self.head = NormLinear( self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 5608facb..ec7cd1cf 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -18,7 +18,7 @@ This impl is/has: # Written by Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- from functools import partial -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -26,6 +26,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint from ._registry import generate_default_cfgs, register_model @@ -458,6 +459,72 @@ class FocalNet(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.layers), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.layers + else: + stages = self.layers[:max_index + 1] + + last_idx = len(self.layers) - 1 + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.layers), indices) + self.layers = self.layers[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.layers(x) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index b31b5768..c862dc4a 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -30,6 +30,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint from ._registry import register_model, generate_default_cfgs @@ -397,7 +398,7 @@ class GlobalContextVit(nn.Module): act_layer = get_act_layer(act_layer) norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) - + self.feature_info = [] img_size = to_2tuple(img_size) feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4 self.global_pool = global_pool @@ -441,6 +442,7 @@ class GlobalContextVit(nn.Module): norm_layer=norm_layer, norm_layer_cl=norm_layer_cl, )) + self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')] self.stages = nn.Sequential(*stages) # Classifier head @@ -494,6 +496,62 @@ class GlobalContextVit(nn.Module): global_pool = self.head.global_pool.pool_type self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) x = self.stages(x) @@ -509,9 +567,11 @@ class GlobalContextVit(nn.Module): def _create_gcvit(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs) + model = build_model_with_cfg( + GlobalContextVit, variant, pretrained, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs + ) return model diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index ea0a92d9..3e44c9dc 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -6,7 +6,7 @@ The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blo PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py """ -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -15,6 +15,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs from ._manipulate import checkpoint_seq @@ -508,6 +509,62 @@ class HighPerfGpuNet(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, 'avg') + return take_indices + def forward_features(self, x): x = self.stem(x) return self.stages(x) diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index 3c4906aa..2fcf123f 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -4,7 +4,7 @@ Original implementation & weights from: https://github.com/sail-sg/inceptionnext """ from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -12,6 +12,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -349,6 +350,62 @@ class MetaNeXt(nn.Module): def no_weight_decay(self): return set() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, 'avg') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index a435533f..315328a2 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -3,6 +3,7 @@ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -10,6 +11,7 @@ import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import create_classifier, ConvNormAct from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs __all__ = ['InceptionV4'] @@ -285,6 +287,66 @@ class InceptionV4(nn.Module): self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.features + else: + stages = self.features[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): return self.features(x) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index f53a9cdf..71d12fe6 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer), InceptionNeXt (https://github.com/sail-sg/inceptionnext) """ from collections import OrderedDict -from typing import Optional +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -14,6 +14,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -417,6 +418,67 @@ class MambaOut(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.' + channel_first = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if channel_first: + # reshape to BCHW output format + intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates] + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e4375b34..b7d4e7e4 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module): if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 23ef3724..490852cf 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below from collections import OrderedDict from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \ use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -597,6 +598,62 @@ class MetaFormer(nn.Module): final = nn.Identity() self.head.fc = final + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_head(self, x: Tensor, pre_logits: bool = False): # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( x = self.head.global_pool(x) diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index f790fd0d..c048a072 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -870,10 +870,11 @@ class MultiScaleVit(nn.Module): if self.pos_embed is not None: x = x + self.pos_embed - for i, stage in enumerate(self.stages): + last_idx = len(self.stages) - 1 + for feat_idx, stage in enumerate(self.stages): x, feat_size = stage(x, feat_size) - if i in take_indices: - if norm and i == (len(self.stages) - 1): + if feat_idx in take_indices: + if norm and feat_idx == last_idx: x_inter = self.norm(x) # applying final norm last intermediate else: x_inter = x @@ -887,7 +888,8 @@ class MultiScaleVit(nn.Module): if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates diff --git a/timm/models/nest.py b/timm/models/nest.py index 1d9c7521..9a423a97 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -19,6 +19,7 @@ import collections.abc import logging import math from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -420,6 +422,73 @@ class Nest(nn.Module): self.global_pool, self.head = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.levels), indices) + + # forward pass + x = self.patch_embed(x) + last_idx = len(self.num_blocks) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.levels + else: + stages = self.levels[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if norm and feat_idx == last_idx: + x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + intermediates.append(x_inter) + else: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + # Layer norm done over channel dim only (to NHWC and back) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.levels), indices) + self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.levels(x) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 01e63fce..2f232e29 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V """ # Copyright (c) ByteDance Inc. All rights reserved. from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn from timm.layers import ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -560,6 +561,72 @@ class NextViT(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = len(self.stages) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/pit.py b/timm/models/pit.py index 3a1090b8..1d5386a9 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -14,7 +14,7 @@ Modifications for timm by / Copyright 2020 Ross Wightman import math import re from functools import partial -from typing import Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Union import torch from torch import nn @@ -22,6 +22,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs from .vision_transformer import Block @@ -254,6 +255,71 @@ class PoolingVisionTransformer(nn.Module): if self.head_dist is not None: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.transformers), indices) + + # forward pass + x = self.patch_embed(x) + x = self.pos_drop(x + self.pos_embed) + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + + last_idx = len(self.transformers) - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.transformers + else: + stages = self.transformers[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x, cls_tokens = stage((x, cls_tokens)) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + cls_tokens = self.norm(cls_tokens) + + return cls_tokens, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.transformers), indices) + self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.pos_drop(x + self.pos_embed) @@ -314,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs): variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices), + feature_cfg=dict(feature_cls='hook', out_indices=out_indices), **kwargs, ) return model diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 6977350a..8cd42fe8 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman """ import math -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,6 +25,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint from ._registry import register_model, generate_default_cfgs @@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.stages(x) diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index 393dc97b..a3a205ff 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -302,29 +302,33 @@ class RDNet(nn.Module): """ assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] - take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices) + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] # forward pass - feat_idx = 0 # stem is index 0 x = self.stem(x) - if feat_idx in take_indices: - intermediates.append(x) + last_idx = len(self.dense_stages) - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript dense_stages = self.dense_stages else: - dense_stages = self.dense_stages[:max_index] - for stage in dense_stages: - feat_idx += 1 + dense_stages = self.dense_stages[:max_index + 1] + for feat_idx, stage in enumerate(dense_stages): x = stage(x) if feat_idx in take_indices: - # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled - intermediates.append(x) + if norm and feat_idx == last_idx: + x_inter = self.norm_pre(x) # applying final norm to last intermediate + else: + x_inter = x + intermediates.append(x_inter) if intermediates_only: return intermediates - x = self.norm_pre(x) + if feat_idx == last_idx: + x = self.norm_pre(x) return x, intermediates @@ -336,8 +340,10 @@ class RDNet(nn.Module): ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices) - self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0 + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm_pre = nn.Identity() if prune_head: @@ -355,6 +361,7 @@ class RDNet(nn.Module): def forward_features(self, x): x = self.stem(x) x = self.dense_stages(x) + x = self.norm_pre(x) return x def forward_head(self, x, pre_logits: bool = False): diff --git a/timm/models/repghost.py b/timm/models/repghost.py index 4b802d79..77fc35d5 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -6,7 +6,7 @@ Original implementation: https://github.com/ChengpengChen/RepGhost """ import copy from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d, Linear, make_divisible from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -294,6 +295,72 @@ class RepGhostNet(nn.Module): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i]+1 for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + feat_idx = 0 + x = self.conv_stem(x) + if feat_idx in take_indices: + intermediates.append(x) + x = self.bn1(x) + x = self.act1(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.blocks + else: + stages = self.blocks[:max_index + 1] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index 7dcb2cd9..ddcfed55 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -14,9 +14,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective` Adapted from official impl at https://github.com/jameslahm/RepViT """ - -__all__ = ['RepVit'] -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,9 +22,12 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs +__all__ = ['RepVit'] + class ConvNorm(nn.Sequential): def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): @@ -333,6 +334,62 @@ class RepVit(nn.Module): def set_distilled_training(self, enable=True): self.head.distilled_training = enable + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index d7d2905b..5cc164ae 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -31,7 +31,7 @@ Original copyright of Google code below, modifications by Ross Wightman, Copyrig from collections import OrderedDict # pylint: disable=g-importing-member from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,6 +40,7 @@ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dS0, FilterResponseNormTlu2d, ClassifierHead, \ DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer, make_divisible from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -543,6 +544,79 @@ class ResNetV2(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + H, W = x.shape[-2:] + for stem in self.stem: + x = stem(x) + if x.shape[-2:] == (H //2, W //2): + x_down = x + if feat_idx in take_indices: + intermediates.append(x_down) + last_idx = len(self.stages) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 9971728c..dd3cb4f3 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,7 +12,7 @@ Copyright 2020 Ross Wightman from functools import partial from math import ceil -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule from ._builder import build_model_with_cfg from ._efficientnet_builder import efficientnet_init_weights +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -234,6 +235,67 @@ class RexNet(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.features + else: + stages = self.features[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 12a5ef2f..d238fa5b 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -10,7 +10,7 @@ __all__ = ['TinyVit'] import itertools from functools import partial -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -536,6 +537,62 @@ class TinyVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 37c37e2f..0fb76fa4 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -7,13 +7,14 @@ Original model: https://github.com/mrT23/TResNet """ from collections import OrderedDict from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -228,6 +229,65 @@ class TResNet(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + stage_ends = [1, 2, 3, 4, 5] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + take_indices = [stage_ends[i] for i in take_indices] + max_index = stage_ends[max_index] + # forward pass + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.body + else: + stages = self.body[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + stage_ends = [1, 2, 3, 4, 5] + take_indices, max_index = feature_take_indices(len(stage_ends), indices) + max_index = stage_ends[max_index] + self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): if self.grad_checkpointing and not torch.jit.is_scripting(): x = self.body.s2d(x) diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 08e6d0b6..0b48a34c 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -11,7 +11,7 @@ for some reference, rewrote most of the code. Hacked together by / Copyright 2020 Ross Wightman """ -from typing import List, Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \ create_attn, create_norm_act_layer from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -264,6 +265,67 @@ class VovNet(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(5, indices) + + # forward pass + feat_idx = 0 + x = self.stem[:-1](x) + if feat_idx in take_indices: + intermediates.append(x) + + x = self.stem[-1](x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index] + + for feat_idx, stage in enumerate(stages, start=1): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(5, indices) + self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) return self.stages(x) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index e6cf87b7..250749f1 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -494,7 +494,8 @@ class Xcit(nn.Module): # NOTE not supporting return of class tokens x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) for blk in self.cls_attn_blocks: - x = blk(x) + x = blk(x) + x = self.norm(x) return x, intermediates