diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index 0c9ac989..52b3ce45 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -8,6 +8,7 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
+from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import torch
@@ -16,7 +17,7 @@
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, named_apply
+from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
 from .registry import register_model
@@ -47,9 +48,16 @@ default_cfgs = {
     'vit_relpos_base_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
 
+    'vit_srelpos_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
+    'vit_srelpos_medium_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
+
+    'vit_relpos_medium_patch16_cls_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
     'vit_relpos_base_patch16_cls_224': _cfg(
         url=''),
-    'vit_relpos_base_patch16_gapcls_224': _cfg(
+    'vit_relpos_base_patch16_clsgap_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
 
     'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
@@ -59,35 +67,43 @@
 }
 
 
-def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
-    # cut and paste w/ modifications from swin / beit codebase
-    # cls to token & token 2 cls & cls to cls
+def gen_relative_position_index(
+        q_size: Tuple[int, int],
+        k_size: Tuple[int, int] = None,
+        class_token: bool = False) -> torch.Tensor:
+    # Adapted with significant modifications from Swin / BeiT codebases
     # get pair-wise relative position index for each token inside the window
-    window_area = win_size[0] * win_size[1]
-    coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1)  # 2, Wh, Ww
-    relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
-    relative_coords[:, :, 0] += win_size[0] - 1  # shift to start from 0
-    relative_coords[:, :, 1] += win_size[1] - 1
-    relative_coords[:, :, 0] *= 2 * win_size[1] - 1
+    q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, Wh, Ww
+    if k_size is None:
+        k_coords = q_coords
+        k_size = q_size
+    else:
+        # different q vs k sizes is a WIP
+        k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
+    relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    relative_coords = relative_coords.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
+    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
+
     if class_token:
-        num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3
-        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
-        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
+        # NOTE not intended or tested with MLP log-coords
+        max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
+        num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
+        relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
         relative_position_index[0, 0:] = num_relative_distance - 3
         relative_position_index[0:, 0] = num_relative_distance - 2
         relative_position_index[0, 0] = num_relative_distance - 1
-    else:
-        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-    return relative_position_index
+
+    return relative_position_index.contiguous()
 
 
 def gen_relative_log_coords(
         win_size: Tuple[int, int],
         pretrained_win_size: Tuple[int, int] = (0, 0),
-        mode='swin'
+        mode='swin',
 ):
-    # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
+    assert mode in ('swin', 'cr', 'rw')
+    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
     relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@@ -100,12 +116,22 @@ def gen_relative_log_coords(
             relative_coords_table[:, :, 0] /= (win_size[0] - 1)
             relative_coords_table[:, :, 1] /= (win_size[1] - 1)
         relative_coords_table *= 8  # normalize to -8, 8
-        scale = math.log2(8)
+        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+            1.0 + relative_coords_table.abs()) / math.log2(8)
     else:
-        # FIXME we should support a form of normalization (to -1/1) for this mode?
-        scale = math.log2(math.e)
-    relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-        1.0 + relative_coords_table.abs()) / scale
+        if mode == 'rw':
+            # cr w/ window size normalization -> [-1,1] log coords
+            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
+            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
+            relative_coords_table *= 8  # scale to -8, 8
+            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+                1.0 + relative_coords_table.abs())
+            relative_coords_table /= math.log2(9)  # -> [-1, 1]
+        else:
+            # mode == 'cr'
+            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
+                1.0 + relative_coords_table.abs())
+
     return relative_coords_table
 
 
@@ -115,19 +141,29 @@ class RelPosMlp(nn.Module):
             window_size,
             num_heads=8,
             hidden_dim=128,
-            class_token=False,
+            prefix_tokens=0,
             mode='cr',
             pretrained_window_size=(0, 0)
     ):
         super().__init__()
         self.window_size = window_size
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.class_token = 1 if class_token else 0
+        self.prefix_tokens = prefix_tokens
         self.num_heads = num_heads
         self.bias_shape = (self.window_area,) * 2 + (num_heads,)
-        self.apply_sigmoid = mode == 'swin'
+        if mode == 'swin':
+            self.bias_act = nn.Sigmoid()
+            self.bias_gain = 16
+            mlp_bias = (True, False)
+        elif mode == 'rw':
+            self.bias_act = nn.Tanh()
+            self.bias_gain = 4
+            mlp_bias = True
+        else:
+            self.bias_act = nn.Identity()
+            self.bias_gain = None
+            mlp_bias = True
 
-        mlp_bias = (True, False) if mode == 'swin' else True
         self.mlp = Mlp(
             2,  # x, y
             hidden_features=hidden_dim,
@@ -155,10 +191,11 @@
             self.relative_position_index.view(-1)]  # Wh*Ww,Wh*Ww,nH
         relative_position_bias = relative_position_bias.view(self.bias_shape)
         relative_position_bias = relative_position_bias.permute(2, 0, 1)
-        if self.apply_sigmoid:
-            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
-        if self.class_token:
-            relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
+        relative_position_bias = self.bias_act(relative_position_bias)
+        if self.bias_gain is not None:
+            relative_position_bias = self.bias_gain * relative_position_bias
+        if self.prefix_tokens:
+            relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
         return relative_position_bias.unsqueeze(0).contiguous()
 
     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
@@ -167,18 +204,18 @@
 
 
 class RelPosBias(nn.Module):
 
-    def __init__(self, window_size, num_heads, class_token=False):
+    def __init__(self, window_size, num_heads, prefix_tokens=0):
         super().__init__()
+        assert prefix_tokens <= 1
         self.window_size = window_size
         self.window_area = window_size[0] * window_size[1]
-        self.class_token = 1 if class_token else 0
-        self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
+        self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
 
-        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
         self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
         self.register_buffer(
             "relative_position_index",
-            gen_relative_position_index(self.window_size, class_token=self.class_token),
+            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
             persistent=False,
         )
@@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module):
     """
 
     def __init__(
-            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
-            class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.,
+            qkv_bias=True,
+            init_values=1e-6,
+            class_token=False,
+            fc_norm=False,
+            rel_pos_type='mlp',
+            rel_pos_dim=None,
+            shared_rel_pos=False,
+            drop_rate=0.,
+            attn_drop_rate=0.,
+            drop_path_rate=0.,
+            weight_init='skip',
+            embed_layer=PatchEmbed,
+            norm_layer=None,
+            act_layer=None,
+            block_fn=RelPosBlock
+    ):
         """
         Args:
             img_size (int, tuple): input image size
@@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module):
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 1 if class_token else 0
+        self.num_prefix_tokens = 1 if class_token else 0
         self.grad_checkpointing = False
 
         self.patch_embed = embed_layer(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         feat_size = self.patch_embed.grid_size
 
-        rel_pos_args = dict(window_size=feat_size, class_token=class_token)
+        rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
         if rel_pos_type.startswith('mlp'):
            if rel_pos_dim:
                 rel_pos_args['hidden_dim'] = rel_pos_dim
+            # FIXME experimenting with different relpos log coord configs
             if 'swin' in rel_pos_type:
                 rel_pos_args['mode'] = 'swin'
+            elif 'rw' in rel_pos_type:
+                rel_pos_args['mode'] = 'rw'
             rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
         else:
             rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module):
             # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
             rel_pos_cls = None
 
-        self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
@@ -434,7 +495,7 @@
 
     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.fc_norm(x)
         return x if pre_logits else self.head(x)
@@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_srelpos_small_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16) w/ shared relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=384, shared_rel_pos=True, **kwargs)
+    model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ shared relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=512, shared_rel_pos=True, **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, class token present
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=256, class_token=True, global_pool='token', **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@@ -514,14 +610,14 @@
 
 @register_model
-def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
+def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
     NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
     Leaving here for comparisons w/ a future re-train as it performs quite well.
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
-    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs)
     return model
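
A minimal smoke test for the variants this diff registers (a sketch assuming the change is applied to a local timm checkout; pretrained=False keeps it independent of the release weights listed in default_cfgs):

# Usage sketch for the newly registered models; names come from the @register_model
# functions added above, input size 224 matches their default cfgs.
import torch
import timm

for name in (
        'vit_srelpos_small_patch16_224',      # shared rel-pos, no class token, avg pool
        'vit_srelpos_medium_patch16_224',     # shared rel-pos, no class token, avg pool
        'vit_relpos_medium_patch16_cls_224',  # per-block rel-pos, class token + 'token' pooling
):
    model = timm.create_model(name, pretrained=False).eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))
    print(name, tuple(out.shape))  # expected (2, 1000) with the default classifier head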
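
And a sketch of the new 'rw' log-coordinate mode added to gen_relative_log_coords: it applies the swin-v2 style log2 transform but window-normalizes first, then divides by log2(9) so values land in roughly [-1, 1] (the 14x14 grid below is an illustrative choice, matching a 224/16 feature size):

from timm.models.vision_transformer_relpos import gen_relative_log_coords

# 'rw' mode: /= (win - 1), *= 8, sign * log2(1 + |x|), /= log2(9)
coords = gen_relative_log_coords((14, 14), mode='rw')
print(coords.shape)        # torch.Size([27, 27, 2]) -> (2*14-1, 2*14-1, 2)
print(coords.abs().max())  # ~1.0: +/-13 -> +/-8 -> log2(9) / log2(9)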