diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index a48fa5ac..7decfa9a 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -10,7 +10,7 @@
 The official pytorch code is released and available at https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
 """
 import math
-from typing import Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -18,6 +18,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 
@@ -172,7 +173,16 @@ class PixelEmbed(nn.Module):
         else:
             self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
 
-    def forward(self, x, pixel_pos):
+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        """ Return `self.patch_size`, reduced to its larger side when `as_scalar` is set. """
+        if as_scalar:
+            return max(self.patch_size)
+        return self.patch_size
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
+
+    def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
         _assert(H == self.img_size[0],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
@@ -222,6 +232,7 @@ class TNT(nn.Module):
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
+        self.num_prefix_tokens = 1
         self.grad_checkpointing = False
 
         self.pixel_embed = PixelEmbed(
@@ -233,6 +244,7 @@ class TNT(nn.Module):
             legacy=legacy,
         )
         num_patches = self.pixel_embed.num_patches
+        r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else patch_size
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
         num_pixel = new_patch_size[0] * new_patch_size[1]
@@ -264,8 +276,10 @@ class TNT(nn.Module):
                 legacy=legacy,
             ))
         self.blocks = nn.ModuleList(blocks)
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
+
         self.norm = norm_layer(embed_dim)
-
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
@@ -313,6 +327,92 @@ class TNT(nn.Module):
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if an int, if is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediates if `intermediates_only`, else tuple of (normed output of last block run, intermediates)
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+
+        pixel_embed = self.pixel_embed(x, self.pixel_pos)
+
+        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
+        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
+        patch_embed = patch_embed + self.patch_pos
+        patch_embed = self.pos_drop(patch_embed)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+
+        for i, blk in enumerate(blocks):
+            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(patch_embed) if norm else patch_embed)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.pixel_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not support in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        patch_embed = self.norm(patch_embed)
+
+        return patch_embed, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         B = x.shape[0]
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
@@ -322,11 +422,10 @@ class TNT(nn.Module):
         patch_embed = patch_embed + self.patch_pos
         patch_embed = self.pos_drop(patch_embed)
 
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            for blk in self.blocks:
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
                 pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
-        else:
-            for blk in self.blocks:
+            else:
                 pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
 
         patch_embed = self.norm(patch_embed)
@@ -334,7 +433,7 @@ class TNT(nn.Module):
 
     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
 
@@ -344,6 +443,30 @@ class TNT(nn.Module):
         return x
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'tnt_s_patch16_224.in1k': _cfg(
+        # hf_hub_id='timm/',
+        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+    ),
+    'tnt_b_patch16_224.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+    ),
+})
+
+
 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)
 
@@ -380,40 +503,15 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_tnt(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         TNT, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs)
     return model
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'tnt_s_patch16_224': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
-    ),
-    'tnt_b_patch16_224': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
-    ),
-}
-
-
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(