diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index 135272e4..7e86103d 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -22,13 +22,13 @@
 from ._features import feature_take_indices
 from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
-
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
 
 
 class Attention(nn.Module):
     """ Multi-Head Attention
     """
+
     def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.hidden_dim = hidden_dim
@@ -46,7 +46,7 @@ class Attention(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
         qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k = qk.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -62,6 +62,7 @@ class Attention(nn.Module):
 class Block(nn.Module):
     """ TNT Block
     """
+
     def __init__(
             self,
             dim,
@@ -89,7 +90,7 @@ class Block(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
         )
-        
+
         self.norm_mlp_in = norm_layer(dim)
         self.mlp_in = Mlp(
             in_features=dim,
@@ -118,7 +119,7 @@ class Block(nn.Module):
             proj_drop=proj_drop,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        
+
         self.norm_mlp = norm_layer(dim_out)
         self.mlp = Mlp(
             in_features=dim_out,
@@ -136,13 +137,13 @@
         B, N, C = patch_embed.size()
         if self.legacy:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                    self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
             ], dim=1)
         else:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                    self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
             ], dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
@@ -152,7 +153,16 @@
 class PixelEmbed(nn.Module):
     """ Image to Pixel Embedding
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
+
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            in_dim=48,
+            stride=4,
+            legacy=False,
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -184,14 +194,17 @@ class PixelEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
-        _assert(H == self.img_size[0],
+        _assert(
+            H == self.img_size[0],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        _assert(W == self.img_size[1],
+        _assert(
+            W == self.img_size[1],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         if self.legacy:
             x = self.proj(x)
             x = self.unfold(x)
-            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+            x = x.transpose(1, 2).reshape(
+                B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         else:
             x = self.unfold(x)
             x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
@@ -204,6 +217,7 @@ class PixelEmbed(nn.Module):
 class TNT(nn.Module):
     """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
     """
+
     def __init__(
             self,
             img_size=224,
@@ -248,7 +262,7 @@ class TNT(nn.Module):
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
         num_pixel = new_patch_size[0] * new_patch_size[1]
-        
+
         self.norm1_proj = norm_layer(num_pixel * inner_dim)
         self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
         self.norm2_proj = norm_layer(embed_dim)
@@ -278,7 +292,7 @@ class TNT(nn.Module):
         self.blocks = nn.ModuleList(blocks)
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
-        
+
         self.norm = norm_layer(embed_dim)
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -359,7 +373,7 @@ class TNT(nn.Module):
         # forward pass
         B, _, height, width = x.shape
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
-        
+
         patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
         patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
         patch_embed = patch_embed + self.patch_pos
@@ -381,7 +395,7 @@ class TNT(nn.Module):
         # split prefix (e.g. class, distill) and spatial feature tokens
         prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
         intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
-        
+
         if reshape:
             # reshape to BCHW output format
             H, W = self.pixel_embed.dynamic_feat_size((height, width))
@@ -416,7 +430,7 @@ class TNT(nn.Module):
     def forward_features(self, x):
         B = x.shape[0]
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
-        
+
         patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
         patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
         patch_embed = patch_embed + self.patch_pos
@@ -458,42 +472,47 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
+    'tnt_s_legacy_patch16_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+    ),
     'tnt_s_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
     ),
     'tnt_b_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
     ),
 })
 
 
 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)
-
-    out_dict = {}
-    for k, v in state_dict.items():
-        k = k.replace('outer_pos', 'patch_pos')
-        k = k.replace('inner_pos', 'pixel_pos')
-        k = k.replace('patch_embed', 'pixel_embed')
-        k = k.replace('proj_norm1', 'norm1_proj')
-        k = k.replace('proj_norm2', 'norm2_proj')
-        k = k.replace('inner_norm1', 'norm_in')
-        k = k.replace('inner_attn', 'attn_in')
-        k = k.replace('inner_norm2', 'norm_mlp_in')
-        k = k.replace('inner_mlp', 'mlp_in')
-        k = k.replace('outer_norm1', 'norm_out')
-        k = k.replace('outer_attn', 'attn_out')
-        k = k.replace('outer_norm2', 'norm_mlp')
-        k = k.replace('outer_mlp', 'mlp')
-        if k == 'pixel_pos' and model.pixel_embed.legacy == False:
-            B, N, C = v.shape
-            H = W = int(N ** 0.5)
-            assert H * W == N
-            v = v.permute(0, 2, 1).reshape(B, C, H, W)
-        out_dict[k] = v
+    if 'patch_pos' in state_dict:
+        out_dict = state_dict
+    else:
+        out_dict = {}
+        for k, v in state_dict.items():
+            k = k.replace('outer_pos', 'patch_pos')
+            k = k.replace('inner_pos', 'pixel_pos')
+            k = k.replace('patch_embed', 'pixel_embed')
+            k = k.replace('proj_norm1', 'norm1_proj')
+            k = k.replace('proj_norm2', 'norm2_proj')
+            k = k.replace('inner_norm1', 'norm_in')
+            k = k.replace('inner_attn', 'attn_in')
+            k = k.replace('inner_norm2', 'norm_mlp_in')
+            k = k.replace('inner_mlp', 'mlp_in')
+            k = k.replace('outer_norm1', 'norm_out')
+            k = k.replace('outer_attn', 'attn_out')
+            k = k.replace('outer_norm2', 'norm_mlp')
+            k = k.replace('outer_mlp', 'mlp')
+            if k == 'pixel_pos' and model.pixel_embed.legacy == False:
+                B, N, C = v.shape
+                H = W = int(N ** 0.5)
+                assert H * W == N
+                v = v.permute(0, 2, 1).reshape(B, C, H, W)
+            out_dict[k] = v
 
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if out_dict['patch_pos'].shape != model.patch_pos.shape:
@@ -515,6 +534,15 @@ def _create_tnt(variant, pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
+    model_cfg = dict(
+        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
+        qkv_bias=False, legacy=True)
+    model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(