diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 12a5ef2f..d238fa5b 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -10,7 +10,7 @@ __all__ = ['TinyVit'] import itertools from functools import partial -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,6 +20,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -536,6 +537,62 @@ class TinyVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.patch_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.patch_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 59f2507e..28e90285 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -253,7 +253,6 @@ class TResNet(nn.Module): assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' intermediates = [] take_indices, max_index = feature_take_indices(len(self.body) - 1, indices) - print(take_indices, max_index) # forward pass x = self.body[0](x) # s2d @@ -261,7 +260,6 @@ class TResNet(nn.Module): stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]] else: stages = self.body[1:max_index + 2] - print(len(stages)) for feat_idx, stage in enumerate(stages): x = stage(x)