Merge pull request #2162 from huggingface/more_fwd_intermediates

Add forward_intermediates support for xcit, cait, and volo.
Ross Wightman, 2024-04-29 21:09:57 -07:00 (committed via GitHub)
commit 6de529bb3c
5 changed files with 314 additions and 53 deletions
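
For readers skimming the changelog: the new method is called directly on the model and can return intermediate block outputs alongside the final features. A minimal usage sketch of the API added below (model name, input size, and printed shapes are illustrative assumptions, not part of this diff):

import torch
import timm

# 'xcit_nano_12_p16_224' and the 224x224 input are illustrative; any cait_*,
# xcit_*, or volo_* variant covered by this PR behaves the same way.
model = timm.create_model('xcit_nano_12_p16_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # indices=3 -> outputs of the last 3 blocks, reshaped to NCHW feature maps
    final, intermediates = model.forward_intermediates(x, indices=3, output_fmt='NCHW')

print(final.shape)                 # (1, 1 + num_patches, embed_dim), class token included
for t in intermediates:
    print(t.shape)                 # (1, embed_dim, H // 16, W // 16) per selected block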

tests/test_models.py

@@ -49,10 +49,11 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'
+    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
+    'cait_*', 'xcit_*', 'volo_*',
 ]
-# transformer models don't support many of the spatial / feature based model functionalities
+# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
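
With these families added to FEAT_INTER_FILTERS, the features_only wrapper path is exercised for them as well. A sketch of what that enables (variant name and index choice are assumptions, not part of this diff):

import torch
import timm

# 'volo_d1_224' is an assumed variant name; out_indices selects which stages/blocks
# to expose, mirroring the out_indices default added to the _create_* helpers below.
model = timm.create_model('volo_d1_224', pretrained=False, features_only=True, out_indices=(0, 1, 2, 3))
model.eval()

print(model.feature_info.channels())    # per-feature channel counts from feature_info
print(model.feature_info.reduction())   # per-feature reductions (stride w.r.t. input)

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)                      # NCHW feature maps, one per requested index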

timm/models/_features.py

@@ -165,6 +165,7 @@ class FeatureHooks:
     ):
         # setup feature hooks
         self._feature_outputs = defaultdict(OrderedDict)
+        self._handles = []
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
             hook_name = h['module']
@@ -173,11 +174,12 @@
             hook_fn = partial(self._collect_output_hook, hook_id)
             hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
-                m.register_forward_pre_hook(hook_fn)
+                handle = m.register_forward_pre_hook(hook_fn)
             elif hook_type == 'forward':
-                m.register_forward_hook(hook_fn)
+                handle = m.register_forward_hook(hook_fn)
             else:
                 assert False, "Unsupported hook type"
+            self._handles.append(handle)

     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
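
The handles are only stored here; nothing in this diff removes them. As a sketch of what retaining them enables (the helper below is hypothetical, not part of timm):

# Hypothetical helper, not part of this PR: each stored handle is a
# torch.utils.hooks.RemovableHandle, so the registered hooks can be detached again.
def remove_feature_hooks(feature_hooks) -> None:
    for handle in feature_hooks._handles:
        handle.remove()        # unregisters the forward / forward_pre hook
    feature_hooks._handles = []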

timm/models/cait.py

@@ -9,6 +9,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -16,6 +17,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -246,8 +248,8 @@ class Cait(nn.Module):
             in_chans=in_chans,
             embed_dim=embed_dim,
         )
         num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
@@ -268,6 +270,7 @@
             mlp_block=mlp_block,
             init_values=init_values,
         ) for i in range(depth)])
+        self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

         self.blocks_token_only = nn.ModuleList([block_layers_token(
             dim=embed_dim,
@@ -283,7 +286,6 @@
         self.norm = norm_layer(embed_dim)

-        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -336,6 +338,80 @@
         self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        # NOTE not supporting return of class tokens
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+        for i, blk in enumerate(self.blocks_token_only):
+            cls_tokens = blk(x, cls_tokens)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.blocks_token_only = nn.ModuleList()  # prune token blocks with head
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = x + self.pos_embed
@@ -373,14 +449,13 @@ def checkpoint_filter_fn(state_dict, model=None):

 def _create_cait(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Cait,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
     return model
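
Taken together, the Cait changes allow grabbing block outputs directly and then trimming the model down to only what those outputs need. A sketch (variant name is an assumption):

import torch
import timm

# 'cait_xxs24_224' is an assumed variant; any cait_* model applies.
model = timm.create_model('cait_xxs24_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # Collect only the last two self-attention block outputs as NCHW maps.
    feats = model.forward_intermediates(x, indices=2, intermediates_only=True)

# Drop the class-attention blocks and head (and any unused trailing blocks).
kept = model.prune_intermediate_layers(2, prune_head=True)
print(kept, [f.shape for f in feats])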

timm/models/volo.py

@@ -20,6 +20,7 @@ Modifications and additions for timm by / Copyright 2022, Ross Wightman
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -28,8 +29,9 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_
+from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._registry import register_model, generate_default_cfgs

 __all__ = ['VOLO']  # model_registry will add each entrypoint fn to this
@@ -119,24 +121,24 @@ class Outlooker(nn.Module):
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
         )
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(
             in_features=dim,
-            hidden_features=mlp_hidden_dim,
+            hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
         )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
-        x = x + self.drop_path(self.attn(self.norm1(x)))
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.attn(self.norm1(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
         return x


 class Attention(nn.Module):
+    fused_attn: torch.jit.Final[bool]

     def __init__(
             self,
@@ -150,6 +152,7 @@ class Attention(nn.Module):
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -162,11 +165,19 @@
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)

-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v

-        x = (attn @ v).transpose(1, 2).reshape(B, H, W, C)
+        x = x.transpose(1, 2).reshape(B, H, W, C)
         x = self.proj(x)
         x = self.proj_drop(x)
@@ -189,17 +200,15 @@ class Transformer(nn.Module):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop)
-        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
-        x = x + self.drop_path(self.attn(self.norm1(x)))
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.attn(self.norm1(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
         return x
@@ -234,8 +243,9 @@ class ClassAttention(nn.Module):
         kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         k, v = kv.unbind(0)
-        q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)
-        attn = ((q * self.scale) @ k.transpose(-2, -1))
+        q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) * self.scale
+
+        attn = q @ k.transpose(-2, -1)
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
@@ -270,21 +280,21 @@ class ClassBlock(nn.Module):
             attn_drop=attn_drop,
             proj_drop=drop,
         )
-        # NOTE: drop path for stochastic depth
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(
             in_features=dim,
-            hidden_features=mlp_hidden_dim,
+            hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
             drop=drop,
         )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
         cls_embed = x[:, :1]
-        cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))
-        cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))
+        cls_embed = cls_embed + self.drop_path1(self.attn(self.norm1(x)))
+        cls_embed = cls_embed + self.drop_path2(self.mlp(self.norm2(cls_embed)))
         return torch.cat([cls_embed, x[:, 1:]], dim=1)
@@ -495,6 +505,7 @@ class VOLO(nn.Module):
             hidden_dim=stem_hidden_dim,
             embed_dim=embed_dims[0],
         )
+        r = patch_size

         # inital positional encoding, we add positional encoding after outlooker blocks
         patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
@@ -502,7 +513,10 @@
         self.pos_drop = nn.Dropout(p=pos_drop_rate)

         # set the main block in network
+        self.stage_ends = []
+        self.feature_info = []
         network = []
+        block_idx = 0
         for i in range(len(layers)):
             if outlook_attention[i]:
                 # stage 1
@@ -517,7 +531,6 @@
                     attn_drop=attn_drop_rate,
                     norm_layer=norm_layer,
                 )
-                network.append(stage)
             else:
                 # stage 2
                 stage = transformer_blocks(
@@ -532,11 +545,15 @@
                     attn_drop=attn_drop_rate,
                     norm_layer=norm_layer,
                 )
-                network.append(stage)
+            network.append(stage)
+            self.stage_ends.append(block_idx)
+            self.feature_info.append(dict(num_chs=embed_dims[i], reduction=r, module=f'network.{block_idx}'))
+            block_idx += 1

             if downsamples[i]:
                 # downsampling between two stages
                 network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2))
+                r *= 2
+                block_idx += 1

         self.network = nn.ModuleList(network)
@@ -691,6 +708,83 @@
         # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
         return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+        """
+        assert output_fmt in ('NCHW',), 'Output format must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x).permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C
+
+        # step2: tokens learning in the two stages
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            network = self.network
+        else:
+            network = self.network[:max_index + 1]
+        for idx, block in enumerate(network):
+            if idx == 2:
+                # add positional encoding after outlooker blocks
+                x = x + self.pos_embed
+                x = self.pos_drop(x)
+            x = block(x)
+            if idx in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(x.permute(0, 3, 1, 2))
+
+        if intermediates_only:
+            return intermediates
+
+        # NOTE not supporting return of class tokens
+        # step3: post network, apply class attention or not
+        B, H, W, C = x.shape
+        x = x.reshape(B, -1, C)
+        if self.post_network is not None:
+            x = self.forward_cls(x)
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        max_index = self.stage_ends[max_index]
+        self.network = self.network[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.post_network = nn.ModuleList()  # prune token blocks with head
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x).permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C
@@ -728,12 +822,12 @@ class VOLO(nn.Module):

 def _create_volo(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    out_indices = kwargs.pop('out_indices', 3)
     return build_model_with_cfg(
         VOLO,
         variant,
         pretrained,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
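
Alongside forward_intermediates, the VOLO Attention module above now routes through F.scaled_dot_product_attention when use_fused_attn() allows it. A standalone sketch checking that the fused and manual paths agree (shapes are arbitrary; this is not timm code):

import torch
import torch.nn.functional as F

B, num_heads, N, head_dim = 2, 4, 49, 32
q, k, v = (torch.randn(B, num_heads, N, head_dim) for _ in range(3))
scale = head_dim ** -0.5

fused = F.scaled_dot_product_attention(q, k, v)   # default scale is head_dim ** -0.5

attn = (q * scale) @ k.transpose(-2, -1)          # manual path with dropout disabled
manual = attn.softmax(dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))   # expected: True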

timm/models/xcit.py

@@ -13,14 +13,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 import math
 from functools import partial
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, to_2tuple
+from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 from .cait import ClassAttn
@@ -195,6 +197,7 @@ class ClassAttentionBlock(nn.Module):

 class XCA(nn.Module):
+    fused_attn: torch.jit.Final[bool]
     """ Cross-Covariance Attention (XCA)
     Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax
     normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h)
@@ -203,6 +206,7 @@ class XCA(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
+        self.fused_attn = use_fused_attn(experimental=True)
         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -214,16 +218,21 @@ class XCA(nn.Module):
         # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

-        # Paper section 3.2 l2-Normalization and temperature scaling
-        q = torch.nn.functional.normalize(q, dim=-1)
-        k = torch.nn.functional.normalize(k, dim=-1)
-        attn = (q @ k.transpose(-2, -1)) * self.temperature
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        # (B, H, C', N), permute -> (B, N, H, C')
-        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
+        if self.fused_attn:
+            q = torch.nn.functional.normalize(q, dim=-1) * self.temperature
+            k = torch.nn.functional.normalize(k, dim=-1)
+            x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0)
+        else:
+            # Paper section 3.2 l2-Normalization and temperature scaling
+            q = torch.nn.functional.normalize(q, dim=-1)
+            k = torch.nn.functional.normalize(k, dim=-1)
+            attn = (q @ k.transpose(-2, -1)) * self.temperature
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.permute(0, 3, 1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -348,6 +357,7 @@ class Xcit(nn.Module):
             embed_dim=embed_dim,
             act_layer=act_layer,
         )
+        r = patch_size

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         if use_pos_embed:
@@ -370,6 +380,7 @@
                 eta=eta,
             )
             for _ in range(depth)])
+        self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

         self.cls_attn_blocks = nn.ModuleList([
             ClassAttentionBlock(
@@ -428,6 +439,85 @@ class Xcit(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x, (Hp, Wp) = self.patch_embed(x)
+        if self.pos_embed is not None:
+            # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
+            pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
+            x = x + pos_encoding
+        x = self.pos_drop(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x, Hp, Wp)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            intermediates = [y.reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        # NOTE not supporting return of class tokens
+        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
+        for blk in self.cls_attn_blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.cls_attn_blocks = nn.ModuleList()  # prune token blocks with head
+            self.head = nn.Identity()
+        return take_indices
+
     def forward_features(self, x):
         B = x.shape[0]
         # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches)
@@ -498,14 +588,13 @@ def checkpoint_filter_fn(state_dict, model):

 def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Cross-Covariance Image Transformers models.')
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Xcit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
     )
     return model
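
For XCiT the same API also supports an un-reshaped token output and optional normalization of the intermediates. A sketch (variant name and block indices are assumptions):

import torch
import timm

# 'xcit_tiny_12_p16_224' is an assumed 12-block variant.
model = timm.create_model('xcit_tiny_12_p16_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # norm=True runs each selected block output through the final norm layer;
    # output_fmt='NLC' keeps them as token sequences instead of NCHW maps.
    final, inter = model.forward_intermediates(x, indices=(3, 7, 11), norm=True, output_fmt='NLC')

for t in inter:
    print(t.shape)    # (1, num_patches, embed_dim) from blocks 3, 7, and 11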