diff --git a/timm/models/_features.py b/timm/models/_features.py index 565f1dd8..9dbac1cd 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -165,6 +165,7 @@ class FeatureHooks: ): # setup feature hooks self._feature_outputs = defaultdict(OrderedDict) + self._handles = [] modules = {k: v for k, v in named_modules} for i, h in enumerate(hooks): hook_name = h['module'] @@ -173,11 +174,12 @@ class FeatureHooks: hook_fn = partial(self._collect_output_hook, hook_id) hook_type = h.get('hook_type', default_hook_type) if hook_type == 'forward_pre': - m.register_forward_pre_hook(hook_fn) + handle = m.register_forward_pre_hook(hook_fn) elif hook_type == 'forward': - m.register_forward_hook(hook_fn) + handle = m.register_forward_hook(hook_fn) else: assert False, "Unsupported hook type" + self._handles.append(handle) def _collect_output_hook(self, hook_id, *args): x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre diff --git a/timm/models/cait.py b/timm/models/cait.py index 40d56061..50148405 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -9,6 +9,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +17,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -246,8 +248,8 @@ class Cait(nn.Module): in_chans=in_chans, embed_dim=embed_dim, ) - num_patches = self.patch_embed.num_patches + r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) @@ -268,6 +270,7 @@ class Cait(nn.Module): mlp_block=mlp_block, init_values=init_values, ) for i in range(depth)]) + self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)] self.blocks_token_only = nn.ModuleList([block_layers_token( dim=embed_dim, @@ -283,7 +286,6 @@ class Cait(nn.Module): self.norm = norm_layer(embed_dim) - self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() @@ -336,6 +338,80 @@ class Cait(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if reshape: + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + if intermediates_only: + return intermediates + + # NOTE not supporting return of class tokens + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + for i, blk in enumerate(self.blocks_token_only): + cls_tokens = blk(x, cls_tokens) + x = torch.cat((cls_tokens, x), dim=1) + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), n) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.blocks_token_only = nn.ModuleList() # prune token blocks with head + self.head = nn.Identity() + return take_indices + def forward_features(self, x): x = self.patch_embed(x) x = x + self.pos_embed @@ -373,14 +449,13 @@ def checkpoint_filter_fn(state_dict, model=None): def _create_cait(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Cait, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model diff --git a/timm/models/volo.py b/timm/models/volo.py index 260cd20d..d997f909 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,6 +20,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman # See the License for the specific language governing permissions and # limitations under the License. import math +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -28,8 +29,9 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_ +from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._registry import register_model, generate_default_cfgs __all__ = ['VOLO'] # model_registry will add each entrypoint fn to this @@ -119,24 +121,24 @@ class Outlooker(nn.Module): qkv_bias=qkv_bias, attn_drop=attn_drop, ) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, - hidden_features=mlp_hidden_dim, + hidden_features=int(dim * mlp_ratio), act_layer=act_layer, ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.attn(self.norm1(x))) + x = x + self.drop_path2(self.mlp(self.norm2(x))) return x class Attention(nn.Module): + fused_attn: torch.jit.Final[bool] def __init__( self, @@ -150,6 +152,7 @@ class Attention(nn.Module): self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -162,11 +165,19 @@ class Attention(nn.Module): qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v - x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) + x = x.transpose(1, 2).reshape(B, H, W, C) x = self.proj(x) x = self.proj_drop(x) @@ -189,17 +200,15 @@ class Transformer(nn.Module): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop) - - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.attn(self.norm1(x))) + x = x + self.drop_path2(self.mlp(self.norm2(x))) return x @@ -234,8 +243,9 @@ class ClassAttention(nn.Module): kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) k, v = kv.unbind(0) - q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) - attn = ((q * self.scale) @ k.transpose(-2, -1)) + q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) * self.scale + + attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) @@ -270,21 +280,21 @@ class ClassBlock(nn.Module): attn_drop=attn_drop, proj_drop=drop, ) - # NOTE: drop path for stochastic depth - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, - hidden_features=mlp_hidden_dim, + hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): cls_embed = x[:, :1] - cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x))) - cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed))) + cls_embed = cls_embed + self.drop_path1(self.attn(self.norm1(x))) + cls_embed = cls_embed + self.drop_path2(self.mlp(self.norm2(cls_embed))) return torch.cat([cls_embed, x[:, 1:]], dim=1) @@ -495,6 +505,7 @@ class VOLO(nn.Module): hidden_dim=stem_hidden_dim, embed_dim=embed_dims[0], ) + r = patch_size # inital positional encoding, we add positional encoding after outlooker blocks patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale) @@ -502,7 +513,10 @@ class VOLO(nn.Module): self.pos_drop = nn.Dropout(p=pos_drop_rate) # set the main block in network + self.stage_ends = [] + self.feature_info = [] network = [] + block_idx = 0 for i in range(len(layers)): if outlook_attention[i]: # stage 1 @@ -517,7 +531,6 @@ class VOLO(nn.Module): attn_drop=attn_drop_rate, norm_layer=norm_layer, ) - network.append(stage) else: # stage 2 stage = transformer_blocks( @@ -532,11 +545,15 @@ class VOLO(nn.Module): attn_drop=attn_drop_rate, norm_layer=norm_layer, ) - network.append(stage) - + network.append(stage) + self.stage_ends.append(block_idx) + self.feature_info.append(dict(num_chs=embed_dims[i], reduction=r, module=f'network.{block_idx}')) + block_idx += 1 if downsamples[i]: # downsampling between two stages network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2)) + r *= 2 + block_idx += 1 self.network = nn.ModuleList(network) @@ -691,6 +708,83 @@ class VOLO(nn.Module): # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box return x_cls, x_aux, (bbx1, bby1, bbx2, bby2) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output format must be NCHW.' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C + + # step2: tokens learning in the two stages + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + network = self.network + else: + network = self.network[:max_index + 1] + for idx, block in enumerate(network): + if idx == 2: + # add positional encoding after outlooker blocks + x = x + self.pos_embed + x = self.pos_drop(x) + x = block(x) + if idx in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(x.permute(0, 3, 1, 2)) + + if intermediates_only: + return intermediates + + # NOTE not supporting return of class tokens + # step3: post network, apply class attention or not + B, H, W, C = x.shape + x = x.reshape(B, -1, C) + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stage_ends), n) + max_index = self.stage_ends[max_index] + self.network = self.network[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.post_network = nn.ModuleList() # prune token blocks with head + self.head = nn.Identity() + return take_indices + def forward_features(self, x): x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C @@ -728,12 +822,12 @@ class VOLO(nn.Module): def _create_volo(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') + out_indices = kwargs.pop('out_indices', 3) return build_model_with_cfg( VOLO, variant, pretrained, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index ffcf07ec..941f7bf2 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -13,14 +13,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W import math from functools import partial +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, trunc_normal_, to_2tuple +from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._features_fx import register_notrace_module from ._registry import register_model, generate_default_cfgs, register_model_deprecations from .cait import ClassAttn @@ -195,6 +197,7 @@ class ClassAttentionBlock(nn.Module): class XCA(nn.Module): + fused_attn: torch.jit.Final[bool] """ Cross-Covariance Attention (XCA) Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h) @@ -203,6 +206,7 @@ class XCA(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads + self.fused_attn = use_fused_attn(experimental=True) self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -214,16 +218,21 @@ class XCA(nn.Module): # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N) qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - # Paper section 3.2 l2-Normalization and temperature scaling - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) - attn = (q @ k.transpose(-2, -1)) * self.temperature - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - # (B, H, C', N), permute -> (B, N, H, C') - x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + if self.fused_attn: + q = torch.nn.functional.normalize(q, dim=-1) * self.temperature + k = torch.nn.functional.normalize(k, dim=-1) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) + else: + # Paper section 3.2 l2-Normalization and temperature scaling + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 3, 1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -348,6 +357,7 @@ class Xcit(nn.Module): embed_dim=embed_dim, act_layer=act_layer, ) + r = patch_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_pos_embed: @@ -370,6 +380,7 @@ class Xcit(nn.Module): eta=eta, ) for _ in range(depth)]) + self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)] self.cls_attn_blocks = nn.ModuleList([ ClassAttentionBlock( @@ -428,6 +439,85 @@ class Xcit(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = True, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C) + pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = self.pos_drop(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x, Hp, Wp) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if reshape: + # reshape to BCHW output format + intermediates = [y.reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + if intermediates_only: + return intermediates + + # NOTE not supporting return of class tokens + x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) + for blk in self.cls_attn_blocks: + x = blk(x) + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + n: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), n) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.cls_attn_blocks = nn.ModuleList() # prune token blocks with head + self.head = nn.Identity() + return take_indices + def forward_features(self, x): B = x.shape[0] # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches) @@ -498,14 +588,13 @@ def checkpoint_filter_fn(state_dict, model): def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Cross-Covariance Image Transformers models.') - + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Xcit, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model