diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py
index f53a9cdf..8eac6e7b 100644
--- a/timm/models/mambaout.py
+++ b/timm/models/mambaout.py
@@ -6,7 +6,7 @@ MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
 from collections import OrderedDict
-from typing import Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -14,6 +14,7 @@ from torch import nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
 
@@ -417,6 +418,68 @@ class MambaOut(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW or NHWC.'
+        channel_first = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if channel_first:
+            # reshape to BCHW output format
+            intermediates = [y.permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + + def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 23ef3724..490852cf 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -28,7 +28,7 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below from collections import OrderedDict from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,6 +40,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \ use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -597,6 +598,62 @@ class MetaFormer(nn.Module): final = nn.Identity() self.head.fc = final + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_head(self, x: Tensor, pre_logits: bool = False): # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( x = self.head.global_pool(x) diff --git a/timm/models/nest.py b/timm/models/nest.py index 1d9c7521..9ee50463 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -19,6 +19,7 @@ import collections.abc import logging import math from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -28,6 +29,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -420,6 +422,67 @@ class Nest(nn.Module): self.global_pool, self.head = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.levels), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.levels + else: + stages = self.levels[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + # Layer norm done over channel dim only (to NHWC and back) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.levels), indices) + self.levels = self.levels[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.levels(x) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 01e63fce..5d6ec972 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -6,7 +6,7 @@ Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-V """ # Copyright (c) ByteDance Inc. All rights reserved. from functools import partial -from typing import Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn from timm.layers import ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model @@ -560,6 +561,66 @@ class NextViT(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 6977350a..8cd42fe8 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -16,7 +16,7 @@ Modifications and timm support by / Copyright 2022, Ross Wightman """ import math -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,6 +25,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint from ._registry import register_model, generate_default_cfgs @@ -386,6 +387,62 @@ class PyramidVisionTransformerV2(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = self.stages(x)