From 679daef76a22415ac1cb666d970c1543699ecf51 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Apr 2024 21:29:16 -0700 Subject: [PATCH] More forward_intermediates() & features_only work * forward_intermediates() added to beit, deit, eva, mvitv2, twins, vit, vit_sam * add features_only to forward intermediates to allow just intermediate features * fix #2060 * fix #1374 * fix #657 --- timm/models/_features.py | 77 +++++++++++---- timm/models/beit.py | 102 ++++++++++++++++++-- timm/models/deit.py | 4 +- timm/models/eva.py | 93 +++++++++++++++++-- timm/models/mvitv2.py | 78 +++++++++++++--- timm/models/twins.py | 62 ++++++++++++- timm/models/vision_transformer.py | 129 ++++++++++++++++++-------- timm/models/vision_transformer_sam.py | 68 +++++++++++--- 8 files changed, 512 insertions(+), 101 deletions(-) diff --git a/timm/models/_features.py b/timm/models/_features.py index cc4068d4..fa108798 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman from collections import OrderedDict, defaultdict from copy import deepcopy from functools import partial -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union import torch import torch.nn as nn @@ -20,7 +20,39 @@ from torch.utils.checkpoint import checkpoint from timm.layers import Format -__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] +__all__ = [ + 'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet', + 'feature_take_indices' +] + + +def _take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[Set[int], int]: + if isinstance(n, int): + assert n >= 0 + take_indices = {x for x in range(num_blocks - n, num_blocks)} + else: + take_indices = {num_blocks + idx if idx < 0 else idx for idx in n} + return take_indices, max(take_indices) + + +def _take_indices_jit(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]: + if isinstance(n, int): + assert n >= 0 + take_indices = [num_blocks - n + i for i in range(n)] + elif isinstance(n, tuple): + # splitting this up is silly, but needed for torchscript type resolution of n + take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] + else: + take_indices = [num_blocks + idx if idx < 0 else idx for idx in n] + return take_indices, max(take_indices) + + +def feature_take_indices(n: Union[int, List[int], Tuple[int]], num_blocks: int) -> Tuple[List[int], int]: + if torch.jit.is_scripting(): + return _take_indices_jit(n, num_blocks) + else: + # NOTE non-jit returns Set[int] instead of List[int] but torchscript can't handle that anno + return _take_indices(n, num_blocks) def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: @@ -397,29 +429,38 @@ class FeatureGetterNet(nn.ModuleDict): out_map: Optional[Sequence[Union[int, str]]] = None, return_dict: bool = False, output_fmt: str = 'NCHW', + norm: bool = False, + prune: bool = True, ): + """ + + Args: + model: Model to wrap. + out_indices: Indices of features to extract. + out_map: Remap feature names for dict output (WIP, not supported). + return_dict: Return features as dictionary instead of list (WIP, not supported). + norm: Apply final model norm to all output features (if possible). 
+ """ super().__init__() - self.model = model + if prune and hasattr(model, 'prune_intermediate_layers'): + model.prune_intermediate_layers( + out_indices, + prune_norm=not norm, + ) self.feature_info = _get_feature_info(model, out_indices) + self.model = model self.out_indices = out_indices self.out_map = out_map self.return_dict = return_dict self.output_fmt = output_fmt + self.norm = norm - def forward(self, *args, **kwargs): - """ - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, - reshape: bool = False, - return_prefix_tokens: bool = False, - norm: bool = False, - """ - out = self.model.get_intermediate_layers( - *args, + def forward(self, x): + features = self.model.forward_intermediates( + x, n=self.out_indices, - reshape=True, - **kwargs, + norm=self.norm, + output_fmt=self.output_fmt, + features_only=True, ) - return out + return features diff --git a/timm/models/beit.py b/timm/models/beit.py index 0167099c..46d460d3 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -39,7 +39,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below # --------------------------------------------------------' import math -from typing import Callable, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -52,8 +52,8 @@ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import generate_default_cfgs, register_model -from .vision_transformer import checkpoint_filter_fn __all__ = ['Beit'] @@ -333,6 +333,8 @@ class Beit(nn.Module): window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) @@ -398,6 +400,93 @@ class Beit(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + n: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + n: Take last n blocks if n is an int, if in is a sequence, select by matching indices + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + features_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' 
+ reshape = output_fmt == 'NCHW' + intermediates = [] + num_blocks = len(self.blocks) + if n is None: + n = num_blocks + take_indices, max_index = feature_take_indices(n, num_blocks) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x, shared_rel_pos_bias=rel_pos_bias) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + if reshape: + # reshape == True => BCHW output format + patch_size = self.patch_embed.patch_size + H = int(math.ceil(height / patch_size[0])) + W = int(math.ceil(width / patch_size[1])) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(n, len(self.blocks)) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.head = nn.Identity() + def forward_features(self, x): x = self.patch_embed(x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) @@ -547,14 +636,13 @@ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antia def _create_beit(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for BEiT models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Beit, variant, pretrained, - # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes pretrained_filter_fn=_beit_checkpoint_filter_fn, - **kwargs) + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/deit.py b/timm/models/deit.py index f80087e8..9400549d 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -119,14 +119,14 @@ class VisionTransformerDistilled(VisionTransformer): def _create_deit(variant, pretrained=False, distilled=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') + out_indices = kwargs.pop('out_indices', 3) model_cls = VisionTransformerDistilled if distilled else VisionTransformer model = build_model_with_cfg( model_cls, variant, pretrained, pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True), + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model diff --git a/timm/models/eva.py b/timm/models/eva.py index 82fff28a..fe121b00 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -24,9 +24,8 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below """ # EVA models Copyright (c) 2022 BAAI-Vision # EVA02 models Copyright (c) 2023 BAAI-Vision - import math -from typing import Callable, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa to_2tuple, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import generate_default_cfgs, register_model __all__ = ['Eva'] @@ -469,6 +469,8 @@ class Eva(nn.Module): init_values=init_values, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) @@ -559,6 +561,85 @@ class Eva(nn.Module): rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices) return x, rot_pos_embed + def forward_intermediates( + self, + x: torch.Tensor, + n: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+ Args: + x: Input image tensor + n: Take last n blocks if n is an int, if in is a sequence, select by matching indices + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + features_only: Only return intermediate features + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + num_blocks = len(self.blocks) + if n is None: + n = num_blocks + take_indices, max_index = feature_take_indices(n, num_blocks) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + x, rot_pos_embed = self._pos_embed(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x, rope=rot_pos_embed) + if i in take_indices: + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + if reshape: + # reshape == True => BCHW output format + patch_size = self.patch_embed.patch_size + H = int(math.ceil(height / patch_size[0])) + W = int(math.ceil(width / patch_size[1])) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(n, len(self.blocks)) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.head = nn.Identity() + def forward_features(self, x): x = self.patch_embed(x) x, rot_pos_embed = self._pos_embed(x) @@ -663,13 +744,13 @@ def checkpoint_filter_fn( def _create_eva(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Eva models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Eva, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 9d035fd6..579aa87e 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -26,6 +26,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_function from ._registry import register_model, register_model_deprecations, generate_default_cfgs @@ -747,8 +748,10 @@ class MultiScaleVit(nn.Module): num_stages = len(cfg.embed_dim) feat_size = patch_dims + curr_stride = max(cfg.patch_stride) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] self.stages = nn.ModuleList() + self.feature_info = [] for i in range(num_stages): if cfg.expand_attn: dim_out = cfg.embed_dim[i] @@ -775,6 +778,8 @@ class MultiScaleVit(nn.Module): norm_layer=norm_layer, drop_path=dpr[i], ) + curr_stride *= max(cfg.stride_q[i]) + self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)] embed_dim = dim_out feat_size = stage.feat_size self.stages.append(stage) @@ -829,6 +834,51 @@ class MultiScaleVit(nn.Module): ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) ])) + def forward_intermediates( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.' 
+ reshape = output_fmt == 'NCHW' + intermediates = [] + num_stages = len(self.stages) # block list is two-tiered, first tier == stage + if n is None: + n = num_stages + take_indices, max_index = feature_take_indices(n, num_stages) + + # FIXME slice block/pos_block if < max + + # forward pass + x, feat_size = self.patch_embed(x) + B = x.shape[0] + if self.cls_token is not None: + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + for i, stage in enumerate(self.stages): + x, feat_size = stage(x, feat_size) + if i in take_indices: + if norm and i == (len(self.stages) - 1): + x_inter = self.norm(x) # applying final norm last intermediate + else: + x_inter = x + if reshape: + x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2) + intermediates.append(x_inter) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + def forward_features(self, x): x, feat_size = self.patch_embed(x) B, N, C = x.shape @@ -862,6 +912,18 @@ class MultiScaleVit(nn.Module): def checkpoint_filter_fn(state_dict, model): if 'stages.0.blocks.0.norm1.weight' in state_dict: + # native checkpoint, look for rel_pos interpolations + for k in state_dict.keys(): + if 'rel_pos' in k: + rel_pos = state_dict[k] + dest_rel_pos_shape = model.state_dict()[k].shape + if rel_pos.shape[0] != dest_rel_pos_shape[0]: + rel_pos_resized = torch.nn.functional.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=dest_rel_pos_shape[0], + mode="linear", + ) + state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0) return state_dict import re @@ -892,16 +954,6 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('head.projection', 'head.fc') out_dict[k] = v - # for k, v in state_dict.items(): - # if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: - # # To resize pos embedding when using model at different size from pretrained weights - # v = resize_pos_embed( - # v, - # model.pos_embed, - # 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), - # model.patch_embed.grid_size - # ) - return out_dict @@ -948,16 +1000,14 @@ model_cfgs = dict( def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.') - + out_indices = kwargs.pop('out_indices', 4) return build_model_with_cfg( MultiScaleVit, variant, pretrained, model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(flatten_sequential=True), + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/twins.py b/timm/models/twins.py index 3cd25fb4..feba8e37 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -13,7 +13,7 @@ Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/li # -------------------------------------------------------- import math from functools import partial -from typing import Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,6 +22,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn from ._builder import 
build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._registry import register_model, generate_default_cfgs from .vision_transformer import Attention @@ -324,6 +325,7 @@ class Twins(nn.Module): patch_size = 2 self.blocks = nn.ModuleList() + self.feature_info = [] dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule cur = 0 for k in range(len(depths)): @@ -339,6 +341,7 @@ class Twins(nn.Module): ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])], ) self.blocks.append(_block) + self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))] cur += depths[k] self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) @@ -401,6 +404,53 @@ class Twins(nn.Module): if m.bias is not None: m.bias.data.zero_() + def forward_intermediates( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt == 'NCHW', 'Output shape for Twins must be NCHW.' + intermediates = [] + num_stages = len(self.blocks) # block list is two-tiered, first tier == stage + if n is None: + n = num_stages + take_indices, max_index = feature_take_indices(n, num_stages) + + # FIXME slice block/pos_block if < max + + # forward pass + B, _, height, width = x.shape + for i, (embed, drop, blocks, pos_blk) in enumerate(zip( + self.patch_embeds, self.pos_drops, self.blocks, self.pos_block) + ): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + if i in take_indices: + intermediates.append(x) + else: + if i in take_indices: + # only last feature can be normed + x_feat = self.norm(x) if norm else x + intermediates.append(x_feat.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + def forward_features(self, x): B = x.shape[0] for i, (embed, drop, blocks, pos_blk) in enumerate( @@ -429,10 +479,12 @@ class Twins(nn.Module): def _create_twins(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - model = build_model_with_cfg(Twins, variant, pretrained, **kwargs) + out_indices = kwargs.pop('out_indices', 4) + model = build_model_with_cfg( + Twins, variant, pretrained, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b57104ac..24225206 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -45,6 +45,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ get_act_layer, get_norm_layer, LayerType from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -473,7 +474,6 
@@ class VisionTransformer(nn.Module): self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False - self.feature_info = [] embed_args = {} if dynamic_img_size: @@ -631,58 +631,111 @@ class VisionTransformer(nn.Module): return self.pos_drop(x) - def _intermediate_layers( + def forward_intermediates( self, x: torch.Tensor, - n: Union[int, Sequence] = 1, - ) -> List[torch.Tensor]: - outputs, num_blocks = [], len(self.blocks) - take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) - last_index_to_take = max(take_indices) + n: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + n: Take last n blocks if n is an int, if in is a sequence, select by matching indices + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + features_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + num_blocks = len(self.blocks) + if n is None: + n = num_blocks + take_indices, max_index = feature_take_indices(n, num_blocks) # forward pass + B, _, height, width = x.shape x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - for i, blk in enumerate(self.blocks[: last_index_to_take + 1]): + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): x = blk(x) if i in take_indices: - outputs.append(x) + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) - return outputs + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + if reshape: + # reshape == True => BCHW output format + patch_size = self.patch_embed.patch_size + H = int(math.ceil(height / patch_size[0])) + W = int(math.ceil(width / patch_size[1])) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if features_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(n, len(self.blocks)) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + if self.attn_pool is not None: + self.attn_pool = None + self.fc_norm = nn.Identity() + self.head = nn.Identity() def get_intermediate_layers( self, x: torch.Tensor, - n: Union[int, Sequence] = 1, + n: Union[int, List[int], Tuple[int]] = 1, reshape: bool = False, return_prefix_tokens: bool = False, norm: bool = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - """ Intermediate layer accessor (NOTE: This is a WIP experiment). - Inspired by DINO / DINOv2 interface + ) -> List[torch.Tensor]: + """ Intermediate layer accessor inspired by DINO / DINOv2 interface. + NOTE: This API is for backwards compat, favour using forward_intermediates() directly. """ - # take last n blocks if n is an int, if in is a sequence, select by matching indices - outputs = self._intermediate_layers(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] - outputs = [out[:, self.num_prefix_tokens:] for out in outputs] - - if reshape: - patch_size = self.patch_embed.patch_size - batch, _, height, width = x.size() - outputs = [ - out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1) - .permute(0, 3, 1, 2) - .contiguous() - for out in outputs - ] - - if return_prefix_tokens: - return tuple(zip(outputs, prefix_tokens)) - return tuple(outputs) + return self.forward_intermediates( + x, n, + return_prefix_tokens=return_prefix_tokens, + norm=norm, + output_fmt='NCHW' if reshape else 'NLC', + features_only=True, + ) def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) @@ -2485,7 +2538,7 @@ def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransfo def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-S/14 for DINOv2 """ - model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518) + model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5) model = _create_vision_transformer( 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2495,7 +2548,7 @@ def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransf def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/14 for DINOv2 """ - model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518) + model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5) model = _create_vision_transformer( 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2505,7 +2558,7 @@ def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransfo def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-L/14 for DINOv2 """ - model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518) + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5) model = _create_vision_transformer( 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2521,7 +2574,7 @@ def vit_giant_patch14_dinov2(pretrained: bool = False, 
**kwargs) -> VisionTransf # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 model_args = dict( patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, - mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU + mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU ) model = _create_vision_transformer( 'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index 59b354fb..1171c8b9 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -11,21 +11,22 @@ A PyTorch implement of Vision Transformers as described in: """ import logging from functools import partial -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \ + Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn from torch.jit import Final -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\ - Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model -from ._features_fx import register_notrace_function # model_registry will add each entrypoint fn to this __all__ = ['VisionTransformerSAM'] @@ -343,8 +344,7 @@ class VisionTransformerSAM(nn.Module): attn_drop_rate: float = 0., drop_path_rate: float = 0., weight_init: str = '', - embed_layer: Callable = partial( - PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False), + embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False), norm_layer: Optional[Callable] = nn.LayerNorm, act_layer: Optional[Callable] = nn.GELU, block_fn: Callable = Block, @@ -469,6 +469,8 @@ class VisionTransformerSAM(nn.Module): rope=self.rope_window if i not in global_attn_indexes else self.rope_global, ) for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)] if neck_chans: self.neck = nn.Sequential( @@ -536,6 +538,52 @@ class VisionTransformerSAM(nn.Module): def reset_classifier(self, num_classes=0, global_pool=None): self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + features_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + assert output_fmt == 'NCHW', 'Output shape for ViT-SAM must be NCHW.' 
+ intermediates = [] + num_blocks = len(self.blocks) + if n is None: + n = num_blocks + take_indices, max_index = feature_take_indices(n, num_blocks) + + # forward pass, collect intermediates + x = self.patch_embed(x) + if self.pos_embed is not None: + # dynamically resize abs pos embedding if needed + x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3]) + x = self.pos_drop(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # make output BCHW + if norm: + # norm is intertwined with neck convs so apply both, changes the dim + # FIXME only apply to final? Need experiments + intermediates.append(self.neck(x.permute(0, 3, 1, 2))) + else: + intermediates.append(x.permute(0, 3, 1, 2)) + + if features_only: + return intermediates + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x, intermediates + def forward_features(self, x): x = self.patch_embed(x) if self.pos_embed is not None: @@ -618,15 +666,13 @@ default_cfgs = generate_default_cfgs({ def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError( - 'features_only not implemented for Vision Transformer models.') - + out_indices = kwargs.pop('out_indices', 3) return build_model_with_cfg( VisionTransformerSAM, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, )
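A few usage sketches follow to make the new API concrete; model names, indices and shapes in them are illustrative assumptions, not taken from this patch. First, the index selection semantics behind the new feature_take_indices() helper in _features.py: an int selects the last n blocks, a list/tuple selects explicit block indices with negatives counted from the end, and the returned max index lets callers stop iterating early.

    from timm.models._features import feature_take_indices

    # int n -> the last n of num_blocks
    indices, max_index = feature_take_indices(3, 12)            # indices {9, 10, 11}, max_index 11
    # explicit indices, negatives resolved against num_blocks
    indices, max_index = feature_take_indices([0, 5, -1], 12)   # indices {0, 5, 11}, max_index 11
    # NOTE: the eager path returns a set, the torchscript path a list (see comment in the diff)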
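A minimal sketch of how the new forward_intermediates() is expected to be called, assuming a stock 16x16-patch ViT-Base at 224x224 (the model name and the shapes in the comments are indicative for that config only):

    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)

    # final (pre-pool) tokens plus the selected intermediates, spatial tokens reshaped to NCHW
    final, intermediates = model.forward_intermediates(x, n=[-3, -2, -1])
    # expect final ~ (1, 197, 768) and each intermediate ~ (1, 768, 14, 14) for this config

    # features_only=True returns just the intermediate list, skipping the final norm + output
    feats = model.forward_intermediates(x, n=2, features_only=True)

    # prefix (class/distill) tokens can be returned alongside the spatial maps
    pairs = model.forward_intermediates(x, n=1, return_prefix_tokens=True, features_only=True)
    spatial, prefix = pairs[0]    # ~ (1, 768, 14, 14) and (1, 1, 768)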
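get_intermediate_layers() is kept as the DINO / DINOv2-style accessor but now delegates to forward_intermediates(..., features_only=True); per the diff its return type changes from a tuple to a list. Continuing the sketch above:

    outs = model.get_intermediate_layers(x, n=4, reshape=True, norm=True)
    # previously a tuple, now a list of 4 NCHW tensors with the final norm applied to each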
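Where only intermediates are needed, prune_intermediate_layers() can also be called directly (the FeatureGetterNet wrapper does this by default); the index below is illustrative:

    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=False)
    # keep only the blocks needed for intermediate index 4, drop the final norm and classifier head
    model.prune_intermediate_layers(n=[4], prune_norm=True, prune_head=True)
    print(len(model.blocks))   # 5 - blocks past the deepest requested index are truncated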
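features_only=True now routes these model families through the new FeatureGetterNet ('getter' feature_cls). A sketch of the expected behaviour; the out_indices values and printed numbers are indicative for a ViT-B/16 at 224:

    import torch
    import timm

    fmodel = timm.create_model(
        'vit_base_patch16_224',
        pretrained=False,
        features_only=True,
        out_indices=[2, 5, 8, 11],   # block indices; an int means 'last n' (the ViT-style default is 3)
    )
    print(fmodel.feature_info.channels())    # e.g. [768, 768, 768, 768]
    print(fmodel.feature_info.reduction())   # e.g. [16, 16, 16, 16] - single stride, not a CNN-style pyramid

    feats = fmodel(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(f.shape)                       # each ~ torch.Size([1, 768, 14, 14])
    # unused trailing blocks, the final norm and the head are pruned by the wrapper by default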
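For the hierarchical models the indices select stages rather than blocks, so the extracted features form a multi-scale pyramid (default out_indices is 4 for mvitv2/twins). Model name and numbers below are illustrative:

    import torch
    import timm

    fmodel = timm.create_model('mvitv2_tiny', pretrained=False, features_only=True)
    print(fmodel.feature_info.channels())    # e.g. [96, 192, 384, 768]
    print(fmodel.feature_info.reduction())   # e.g. [4, 8, 16, 32]
    feats = fmodel(torch.randn(1, 3, 224, 224))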
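The SAM ViT runs its blocks on NHWC maps, so intermediates already come back as NCHW; with norm=True the neck is applied to each intermediate, which changes the channel dim (see the FIXME in the diff). A sketch assuming the base SAM variant at its 1024x1024 default, shapes indicative only:

    import torch
    import timm

    model = timm.create_model('samvit_base_patch16', pretrained=False).eval()
    final, intermediates = model.forward_intermediates(torch.randn(1, 3, 1024, 1024), n=1, norm=True)
    # expect final and the single intermediate ~ (1, 256, 64, 64) after the neck for this config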