From 4d135421a3c28e4de03388efceeace58239f1961 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Apr 2023 14:43:15 -0700 Subject: [PATCH] Implement patch dropout for eva / vision_transformer, refactor / improve consistency of dropout args across all vit based models --- timm/layers/__init__.py | 4 +- timm/layers/patch_dropout.py | 53 ++++ timm/layers/pos_embed_sincos.py | 10 + timm/models/beit.py | 8 +- timm/models/cait.py | 176 +++++++---- timm/models/coat.py | 189 +++++++++--- timm/models/convit.py | 107 +++++-- timm/models/convmixer.py | 16 +- timm/models/crossvit.py | 178 ++++++++--- timm/models/davit.py | 1 - timm/models/edgenext.py | 77 +++-- timm/models/efficientformer.py | 29 +- timm/models/efficientformer_v2.py | 16 +- timm/models/eva.py | 44 ++- timm/models/focalnet.py | 6 - timm/models/mlp_mixer.py | 5 +- timm/models/mobilevit.py | 2 +- timm/models/mvitv2.py | 218 +++++++------- timm/models/nest.py | 112 +++++-- timm/models/pit.py | 115 +++++--- timm/models/poolformer.py | 82 ++++-- timm/models/pvt_v2.py | 46 ++- timm/models/swin_transformer.py | 17 +- timm/models/swin_transformer_v2.py | 21 +- timm/models/swin_transformer_v2_cr.py | 52 ++-- timm/models/tnt.py | 133 ++++++--- timm/models/twins.py | 87 ++++-- timm/models/visformer.py | 228 ++++++++++---- timm/models/vision_transformer.py | 133 +++++---- timm/models/vision_transformer_relpos.py | 18 +- timm/models/volo.py | 210 +++++++++---- timm/models/xcit.py | 359 ++++++++++++++--------- 32 files changed, 1899 insertions(+), 853 deletions(-) create mode 100644 timm/layers/patch_dropout.py diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index afae6415..576af8d1 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -33,12 +33,14 @@ from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d from .padding import get_padding, get_same_padding, pad_same +from .patch_dropout import PatchDropout from .patch_embed import PatchEmbed, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ - build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat + build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \ + FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvNormAct diff --git a/timm/layers/patch_dropout.py b/timm/layers/patch_dropout.py new file mode 100644 index 00000000..32dd1519 --- /dev/null +++ b/timm/layers/patch_dropout.py @@ -0,0 +1,53 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + return_indices: torch.jit.Final[bool] + + def __init__( + self, + prob: float = 0.5, + num_prefix_tokens: int = 1, + ordered: bool = False, + return_indices: bool = False, + ): + 
super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) + self.ordered = ordered + self.return_indices = return_indices + + def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + if not self.training or self.prob == 0.: + if self.return_indices: + return x, None + return x + + if self.num_prefix_tokens: + prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:] + else: + prefix_tokens = None + + B = x.shape[0] + L = x.shape[1] + num_keep = max(1, int(L * (1. - self.prob))) + keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] + if self.ordered: + # NOTE does not need to maintain patch order in typical transformer use, + # but possibly useful for debug / visualization + keep_indices = keep_indices.sort(dim=-1)[0] + x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) + + if prefix_tokens is not None: + x = torch.cat((prefix_tokens, x), dim=1) + + if self.return_indices: + return x, keep_indices + return x diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index a305aa8a..7f340021 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -194,6 +194,8 @@ def rot(x): def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + if sin_emb.ndim == 3: + return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) return x * cos_emb + rot(x) * sin_emb @@ -205,9 +207,17 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): def apply_rot_embed_cat(x: torch.Tensor, emb): sin_emb, cos_emb = emb.tensor_split(2, -1) + if sin_emb.ndim == 3: + return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) return x * cos_emb + rot(x) * sin_emb +def apply_keep_indices_nlc(x, pos_embed, keep_indices): + pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1) + pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])) + return pos_embed + + def build_rotary_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, diff --git a/timm/models/beit.py b/timm/models/beit.py index 0424d93e..457d86eb 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -279,6 +279,8 @@ class Beit(nn.Module): swiglu_mlp: bool = False, scale_mlp: bool = False, drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., norm_layer: Callable = LayerNorm, @@ -306,7 +308,7 @@ class Beit(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias( @@ -325,7 +327,7 @@ class Beit(nn.Module): mlp_ratio=mlp_ratio, scale_mlp=scale_mlp, swiglu_mlp=swiglu_mlp, - proj_drop=drop_rate, + proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -337,6 +339,7 @@ class Beit(nn.Module): use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) self.head 
= nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @@ -417,6 +420,7 @@ class Beit(nn.Module): if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/cait.py b/timm/models/cait.py index 15dcd956..98c58397 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -110,17 +110,38 @@ class LayerScaleBlockClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add CA and LayerScale def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn, - mlp_block=Mlp, init_values=1e-4): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + attn_block=ClassAttn, + mlp_block=Mlp, + init_values=1e-4, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_block( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = mlp_block( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=proj_drop, + ) self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) @@ -177,17 +198,38 @@ class LayerScaleBlock(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add layerScale def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn, - mlp_block=Mlp, init_values=1e-4): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + attn_block=TalkingHeadAttn, + mlp_block=Mlp, + init_values=1e-4, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_block( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = mlp_block( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=proj_drop, + ) self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) @@ -201,9 +243,22 @@ class Cait(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to adapt to our cait models def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., block_layers=LayerScaleBlock, block_layers_token=LayerScaleBlockClassAttn, patch_layer=PatchEmbed, @@ -226,33 +281,50 @@ class Cait(nn.Module): self.grad_checkpointing = False self.patch_embed = patch_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) dpr = [drop_path_rate for i in range(depth)] - self.blocks = nn.Sequential(*[ - block_layers( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_values) - for i in range(depth)]) + self.blocks = nn.Sequential(*[block_layers( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + attn_block=attn_block, + mlp_block=mlp_block, + init_values=init_values, + ) for i in range(depth)]) - self.blocks_token_only = nn.ModuleList([ - block_layers_token( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_token_only, qkv_bias=qkv_bias, - drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, - act_layer=act_layer, attn_block=attn_block_token_only, - mlp_block=mlp_block_token_only, init_values=init_values) - for i in range(depth_token_only)]) + self.blocks_token_only = nn.ModuleList([block_layers_token( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio_token_only, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + attn_block=attn_block_token_only, + mlp_block=mlp_block_token_only, + init_values=init_values, + ) for _ in range(depth_token_only)]) self.norm = norm_layer(embed_dim) self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) @@ -322,6 +394,7 @@ class Cait(nn.Module): def 
forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -344,77 +417,80 @@ def _create_cait(variant, pretrained=False, **kwargs): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( - Cait, variant, pretrained, + Cait, + variant, + pretrained, pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + **kwargs, + ) return model @register_model def cait_xxs24_224(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs) - model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5) + model = _create_cait('cait_xxs24_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_xxs24_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs) - model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5) + model = _create_cait('cait_xxs24_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_xxs36_224(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs) - model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5) + model = _create_cait('cait_xxs36_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_xxs36_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs) - model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5) + model = _create_cait('cait_xxs36_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_xs24_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5, **kwargs) - model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5) + model = _create_cait('cait_xs24_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_s24_224(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs) - model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5) + model = _create_cait('cait_s24_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_s24_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs) - model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5) + model = 
_create_cait('cait_s24_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_s36_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6, **kwargs) - model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6) + model = _create_cait('cait_s36_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_m36_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6, **kwargs) - model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6) + model = _create_cait('cait_m36_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def cait_m48_448(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6, **kwargs) - model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args) + model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6) + model = _create_cait('cait_m48_448', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/coat.py b/timm/models/coat.py index 4ed6d8e8..f58d57a7 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -54,7 +54,7 @@ default_cfgs = { class ConvRelPosEnc(nn.Module): """ Convolutional relative position encoding. """ - def __init__(self, Ch, h, window): + def __init__(self, head_chs, num_heads, window): """ Initialization. Ch: Channels per head. @@ -70,7 +70,7 @@ class ConvRelPosEnc(nn.Module): if isinstance(window, int): # Set the same window size for all attention heads. - window = {window: h} + window = {window: num_heads} self.window = window elif isinstance(window, dict): self.window = window @@ -84,18 +84,20 @@ class ConvRelPosEnc(nn.Module): # Determine padding size. 
# Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 - cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch, + cur_conv = nn.Conv2d( + cur_head_split * head_chs, + cur_head_split * head_chs, kernel_size=(cur_window, cur_window), padding=(padding_size, padding_size), dilation=(dilation, dilation), - groups=cur_head_split*Ch, + groups=cur_head_split * head_chs, ) self.conv_list.append(cur_conv) self.head_splits.append(cur_head_split) - self.channel_splits = [x*Ch for x in self.head_splits] + self.channel_splits = [x * head_chs for x in self.head_splits] def forward(self, q, v, size: Tuple[int, int]): - B, h, N, Ch = q.shape + B, num_heads, N, C = q.shape H, W = size _assert(N == 1 + H * W, '') @@ -103,13 +105,13 @@ class ConvRelPosEnc(nn.Module): q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] v_img = v[:, :, 1:, :] # [B, h, H*W, Ch] - v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W) + v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W) v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels conv_v_img_list = [] for i, conv in enumerate(self.conv_list): conv_v_img_list.append(conv(v_img_list[i])) conv_v_img = torch.cat(conv_v_img_list, dim=1) - conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2) + conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2) EV_hat = q_img * conv_v_img EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch]. @@ -118,7 +120,15 @@ class ConvRelPosEnc(nn.Module): class FactorAttnConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. """ - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + shared_crpe=None, + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -188,8 +198,20 @@ class ConvPosEnc(nn.Module): class SerialBlock(nn.Module): """ Serial block class. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + shared_cpe=None, + shared_crpe=None, + ): super().__init__() # Conv-Attention. @@ -197,13 +219,24 @@ class SerialBlock(nn.Module): self.norm1 = norm_layer(dim) self.factoratt_crpe = FactorAttnConvRelPosEnc( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + shared_crpe=shared_crpe, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=proj_drop, + ) def forward(self, x, size: Tuple[int, int]): # Conv-Attention. 
@@ -222,8 +255,19 @@ class SerialBlock(nn.Module): class ParallelBlock(nn.Module): """ Parallel block class. """ - def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None): + def __init__( + self, + dims, + num_heads, + mlp_ratios=[], + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + shared_crpes=None, + ): super().__init__() # Conv-Attention. @@ -231,16 +275,28 @@ class ParallelBlock(nn.Module): self.norm13 = norm_layer(dims[2]) self.norm14 = norm_layer(dims[3]) self.factoratt_crpe2 = FactorAttnConvRelPosEnc( - dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpes[1] + dims[1], + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + shared_crpe=shared_crpes[1], ) self.factoratt_crpe3 = FactorAttnConvRelPosEnc( - dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpes[2] + dims[2], + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + shared_crpe=shared_crpes[2], ) self.factoratt_crpe4 = FactorAttnConvRelPosEnc( - dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpes[3] + dims[3], + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + shared_crpe=shared_crpes[3], ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -253,7 +309,11 @@ class ParallelBlock(nn.Module): assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3] mlp_hidden_dim = int(dims[1] * mlp_ratios[1]) self.mlp2 = self.mlp3 = self.mlp4 = Mlp( - in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + in_features=dims[1], + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=proj_drop, + ) def upsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map up-sampling. """ @@ -319,10 +379,27 @@ class ParallelBlock(nn.Module): class CoaT(nn.Module): """ CoaT class. """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), - serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features=None, crpe_window=None, global_pool='token'): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=(0, 0, 0, 0), + serial_depths=(0, 0, 0, 0), + parallel_depth=0, + num_heads=0, + mlp_ratios=(0, 0, 0, 0), + qkv_bias=True, + drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, + out_features=None, + crpe_window=None, + global_pool='token', + ): super().__init__() assert global_pool in ('token', 'avg') crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} @@ -361,21 +438,31 @@ class CoaT(nn.Module): self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3) # Convolutional relative position encodings. 
- self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window) - self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window) - self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window) - self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window) + self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window) + self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window) + self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window) + self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window) # Disable stochastic depth. dpr = drop_path_rate assert dpr == 0.0 - + skwargs = dict( + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr, + norm_layer=norm_layer, + ) + # Serial blocks 1. self.serial_blocks1 = nn.ModuleList([ SerialBlock( - dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpe=self.cpe1, shared_crpe=self.crpe1 + dim=embed_dims[0], + mlp_ratio=mlp_ratios[0], + shared_cpe=self.cpe1, + shared_crpe=self.crpe1, + **skwargs, ) for _ in range(serial_depths[0])] ) @@ -383,9 +470,11 @@ class CoaT(nn.Module): # Serial blocks 2. self.serial_blocks2 = nn.ModuleList([ SerialBlock( - dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpe=self.cpe2, shared_crpe=self.crpe2 + dim=embed_dims[1], + mlp_ratio=mlp_ratios[1], + shared_cpe=self.cpe2, + shared_crpe=self.crpe2, + **skwargs, ) for _ in range(serial_depths[1])] ) @@ -393,9 +482,11 @@ class CoaT(nn.Module): # Serial blocks 3. self.serial_blocks3 = nn.ModuleList([ SerialBlock( - dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpe=self.cpe3, shared_crpe=self.crpe3 + dim=embed_dims[2], + mlp_ratio=mlp_ratios[2], + shared_cpe=self.cpe3, + shared_crpe=self.crpe3, + **skwargs, ) for _ in range(serial_depths[2])] ) @@ -403,9 +494,11 @@ class CoaT(nn.Module): # Serial blocks 4. self.serial_blocks4 = nn.ModuleList([ SerialBlock( - dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpe=self.cpe4, shared_crpe=self.crpe4 + dim=embed_dims[3], + mlp_ratio=mlp_ratios[3], + shared_cpe=self.cpe4, + shared_crpe=self.crpe4, + **skwargs, ) for _ in range(serial_depths[3])] ) @@ -415,9 +508,10 @@ class CoaT(nn.Module): if self.parallel_depth > 0: self.parallel_blocks = nn.ModuleList([ ParallelBlock( - dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4) + dims=embed_dims, + mlp_ratios=mlp_ratios, + shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4), + **skwargs, ) for _ in range(parallel_depth)] ) @@ -437,10 +531,12 @@ class CoaT(nn.Module): # CoaT series: Aggregate features of last three scales for classification. 
assert embed_dims[1] == embed_dims[2] == embed_dims[3] self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() else: # CoaT-Lite series: Use feature of last scale for classification. self.aggregate = None + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # Initialize weights. @@ -587,6 +683,7 @@ class CoaT(nn.Module): x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C] else: x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0] + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x) -> torch.Tensor: diff --git a/timm/models/convit.py b/timm/models/convit.py index d117ccdc..19217418 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -62,7 +62,14 @@ default_cfgs = { @register_notrace_module # reason: FX can't symbolically trace control flow in forward method class GPSA(nn.Module): def __init__( - self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + locality_strength=1., + ): super().__init__() self.num_heads = num_heads self.dim = dim @@ -145,7 +152,14 @@ class GPSA(nn.Module): class MHSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -195,20 +209,48 @@ class MHSA(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_gpsa=True, + locality_strength=1., + ): super().__init__() self.norm1 = norm_layer(dim) self.use_gpsa = use_gpsa if self.use_gpsa: self.attn = GPSA( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + locality_strength=locality_strength, + ) else: - self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.attn = MHSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=proj_drop, + ) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) @@ -221,10 +263,28 @@ class ConViT(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, - local_up_to_layer=3, locality_strength=1., use_pos_embed=True): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + hybrid_backbone=None, + norm_layer=nn.LayerNorm, + local_up_to_layer=3, + locality_strength=1., + use_pos_embed=True, + ): super().__init__() assert global_pool in ('', 'avg', 'token') embed_dim *= num_heads @@ -245,7 +305,7 @@ class ConViT(nn.Module): self.num_patches = num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) if self.use_pos_embed: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) @@ -254,20 +314,22 @@ class ConViT(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - use_gpsa=True, - locality_strength=locality_strength) - if i < local_up_to_layer else - Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - use_gpsa=False) - for i in range(depth)]) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_gpsa=i < local_up_to_layer, + locality_strength=locality_strength, + ) for i in range(depth)]) self.norm = norm_layer(embed_dim) # Classifier head self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.cls_token, std=.02) @@ -327,6 +389,7 @@ class ConViT(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index 3a8c6cf5..ff9f2143 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -42,8 +42,18 @@ class Residual(nn.Module): class ConvMixer(nn.Module): def __init__( - self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg', - act_layer=nn.GELU, **kwargs): + self, + dim, + depth, + kernel_size=9, + patch_size=7, + in_chans=3, + num_classes=1000, + global_pool='avg', + 
drop_rate=0., + act_layer=nn.GELU, + **kwargs, + ): super().__init__() self.num_classes = num_classes self.num_features = dim @@ -67,6 +77,7 @@ class ConvMixer(nn.Module): ) for i in range(depth)] ) self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() @torch.jit.ignore @@ -98,6 +109,7 @@ class ConvMixer(nn.Module): def forward_head(self, x, pre_logits: bool = False): x = self.pooling(x) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 908fcf6d..54995291 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -130,12 +130,19 @@ class PatchEmbed(nn.Module): class CrossAttention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.wq = nn.Linear(dim, dim, bias=qkv_bias) self.wk = nn.Linear(dim, dim, bias=qkv_bias) @@ -166,12 +173,26 @@ class CrossAttention(nn.Module): class CrossAttentionBlock(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = CrossAttention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -182,8 +203,20 @@ class CrossAttentionBlock(nn.Module): class MultiScaleBlock(nn.Module): - def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__( + self, + dim, + patches, + depth, + num_heads, + mlp_ratio, + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() num_branches = len(dim) @@ -194,8 +227,15 @@ class MultiScaleBlock(nn.Module): tmp = [] for i in range(depth[d]): tmp.append(Block( - dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) + dim=dim[d], + num_heads=num_heads[d], + mlp_ratio=mlp_ratio[d], + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i], + norm_layer=norm_layer, + )) if len(tmp) != 0: self.blocks.append(nn.Sequential(*tmp)) @@ -217,14 +257,28 @@ class MultiScaleBlock(nn.Module): if depth[-1] == 0: # backward capability: self.fusion.append( CrossAttentionBlock( - dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) + dim=dim[d_], + num_heads=nh, + mlp_ratio=mlp_ratio[d], + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[-1], + norm_layer=norm_layer, + )) else: tmp = [] for _ in range(depth[-1]): tmp.append(CrossAttentionBlock( - dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) + dim=dim[d_], + num_heads=nh, + mlp_ratio=mlp_ratio[d], + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[-1], + norm_layer=norm_layer, + )) self.fusion.append(nn.Sequential(*tmp)) self.revert_projs = nn.ModuleList() @@ -288,10 +342,26 @@ class CrossViT(nn.Module): """ def __init__( - self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000, - embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.), - multi_conv=False, crop_scale=False, qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool='token', + self, + img_size=224, + img_scale=(1.0, 1.0), + patch_size=(8, 16), + in_chans=3, + num_classes=1000, + embed_dim=(192, 384), + depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), + num_heads=(6, 12), + mlp_ratio=(2., 2., 4.), + multi_conv=False, + crop_scale=False, + qkv_bias=True, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + global_pool='token', ): super().__init__() assert global_pool in ('token', 'avg') @@ -315,9 +385,15 @@ class CrossViT(nn.Module): for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim): self.patch_embed.append( - PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv)) + PatchEmbed( + img_size=im_s, + patch_size=p, + in_chans=in_chans, + embed_dim=d, + multi_conv=multi_conv, + )) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) total_depth = sum([sum(x[-2:]) for x in depth]) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule @@ -327,12 +403,22 @@ 
class CrossViT(nn.Module): curr_depth = max(block_cfg[:-1]) + block_cfg[-1] dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth] blk = MultiScaleBlock( - embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer) + embed_dim, + num_patches, + block_cfg, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr_, + norm_layer=norm_layer, + ) dpr_ptr += curr_depth self.blocks.append(blk) self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.ModuleList([ nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) @@ -411,6 +497,7 @@ class CrossViT(nn.Module): def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor: xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs] + xs = [self.head_drop(x) for x in xs] if pre_logits or isinstance(self.head[0], nn.Identity): return torch.cat([x for x in xs], dim=1) return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0) @@ -436,17 +523,20 @@ def _create_crossvit(variant, pretrained=False, **kwargs): return new_state_dict return build_model_with_cfg( - CrossViT, variant, pretrained, + CrossViT, + variant, + pretrained, pretrained_filter_fn=pretrained_filter_fn, - **kwargs) + **kwargs, + ) @register_model def crossvit_tiny_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], - num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs) - model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args) + num_heads=[3, 3], mlp_ratio=[4, 4, 1]) + model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -454,8 +544,8 @@ def crossvit_tiny_240(pretrained=False, **kwargs): def crossvit_small_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], - num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs) - model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args) + num_heads=[6, 6], mlp_ratio=[4, 4, 1]) + model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -463,8 +553,8 @@ def crossvit_small_240(pretrained=False, **kwargs): def crossvit_base_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], - num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs) - model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args) + num_heads=[12, 12], mlp_ratio=[4, 4, 1]) + model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -472,8 +562,8 @@ def crossvit_base_240(pretrained=False, **kwargs): def crossvit_9_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], - num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs) - model = 
_create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args) + num_heads=[4, 4], mlp_ratio=[3, 3, 1]) + model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -481,8 +571,8 @@ def crossvit_9_240(pretrained=False, **kwargs): def crossvit_15_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], - num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs) - model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args) + num_heads=[6, 6], mlp_ratio=[3, 3, 1]) + model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -491,7 +581,7 @@ def crossvit_18_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs) - model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args) + model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -499,8 +589,8 @@ def crossvit_18_240(pretrained=False, **kwargs): def crossvit_9_dagger_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], - num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) - model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args) + num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True) + model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -508,8 +598,8 @@ def crossvit_9_dagger_240(pretrained=False, **kwargs): def crossvit_15_dagger_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], - num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) - model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args) + num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True) + model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -517,8 +607,8 @@ def crossvit_15_dagger_240(pretrained=False, **kwargs): def crossvit_15_dagger_408(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], - num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) - model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args) + num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True) + model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -526,8 +616,8 @@ def crossvit_15_dagger_408(pretrained=False, **kwargs): def crossvit_18_dagger_240(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], - num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) - model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args) + num_heads=[7, 7], mlp_ratio=[3, 3, 1], 
multi_conv=True) + model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -535,6 +625,6 @@ def crossvit_18_dagger_240(pretrained=False, **kwargs): def crossvit_18_dagger_408(pretrained=False, **kwargs): model_args = dict( img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], - num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) - model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args) + num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True) + model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/davit.py b/timm/models/davit.py index 4b3ef1ab..bbc5e421 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -468,7 +468,6 @@ class DaViT(nn.Module): ffn=True, cpe_act=False, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., num_classes=1000, global_pool='avg', diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index d90471fb..e770a5be 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -21,49 +21,11 @@ from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdap from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._manipulate import named_apply, checkpoint_seq -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs __all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), - 'crop_pct': 0.9, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.0', 'classifier': 'head.fc', - **kwargs - } - - -default_cfgs = dict( - edgenext_xx_small=_cfg( - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - edgenext_x_small=_cfg( - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - # edgenext_small=_cfg( - # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), - edgenext_small=_cfg( # USI weights - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", - crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, - ), - # edgenext_base=_cfg( - # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"), - edgenext_base=_cfg( # USI weights - url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth", - crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, - ), - - edgenext_small_rw=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', - test_input_size=(3, 320, 320), test_crop_pct=1.0, - ), -) - - @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): def __init__(self, hidden_dim=32, dim=768, temperature=10000): @@ -519,6 +481,43 @@ def _create_edgenext(variant, pretrained=False, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 
'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'edgenext_xx_small.in1k': _cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'edgenext_x_small.in1k': _cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'edgenext_small.usi_in1k': _cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + 'edgenext_base.usi_in1k': _cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth", + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + 'edgenext_base.in21k_ft_in1k': _cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth", + crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), + 'edgenext_small_rw.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth', + test_input_size=(3, 320, 320), test_crop_pct=1.0, + ), +}) + + @register_model def edgenext_xx_small(pretrained=False, **kwargs): # 1.33M & 260.58M @ 256 resolution diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 3fd1cc7f..62ff36de 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -211,7 +211,7 @@ class MetaBlock1d(nn.Module): mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - drop=0., + proj_drop=0., drop_path=0., layer_scale_init_value=1e-5 ): @@ -219,7 +219,12 @@ class MetaBlock1d(nn.Module): self.norm1 = norm_layer(dim) self.token_mixer = Attention(dim) self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.ls1 = LayerScale(dim, layer_scale_init_value) @@ -251,7 +256,7 @@ class MetaBlock2d(nn.Module): mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, - drop=0., + proj_drop=0., drop_path=0., layer_scale_init_value=1e-5 ): @@ -261,7 +266,12 @@ class MetaBlock2d(nn.Module): self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = ConvMlpWithNorm( - dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop) + dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + norm_layer=norm_layer, + drop=proj_drop, + ) self.ls2 = LayerScale2d(dim, layer_scale_init_value) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -285,7 +295,7 @@ class EfficientFormerStage(nn.Module): act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, norm_layer_cl=nn.LayerNorm, - drop=.0, + proj_drop=.0, drop_path=0., layer_scale_init_value=1e-5, ): @@ -312,7 +322,7 @@ class EfficientFormerStage(nn.Module): mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer_cl, - drop=drop, + proj_drop=proj_drop, drop_path=drop_path[block_idx], layer_scale_init_value=layer_scale_init_value, )) @@ -324,7 +334,7 @@ class EfficientFormerStage(nn.Module): mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, - drop=drop, + proj_drop=proj_drop, drop_path=drop_path[block_idx], layer_scale_init_value=layer_scale_init_value, )) @@ -360,6 +370,7 @@ class EfficientFormer(nn.Module): norm_layer=nn.BatchNorm2d, norm_layer_cl=nn.LayerNorm, drop_rate=0., + proj_drop_rate=0., drop_path_rate=0., **kwargs ): @@ -386,7 +397,7 @@ class EfficientFormer(nn.Module): act_layer=act_layer, norm_layer_cl=norm_layer_cl, norm_layer=norm_layer, - drop=drop_rate, + proj_drop=proj_drop_rate, drop_path=dpr[i], layer_scale_init_value=layer_scale_init_value, ) @@ -398,6 +409,7 @@ class EfficientFormer(nn.Module): # Classifier head self.num_features = embed_dims[-1] self.norm = norm_layer_cl(self.num_features) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # assuming model is always distilled (valid for current checkpoints, will split def if that changes) self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() @@ -453,6 +465,7 @@ class EfficientFormer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean(dim=1) + x = self.head_drop(x) if pre_logits: return x x, x_dist = self.head(x), self.head_dist(x) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index 444fa73a..166aaaae 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -342,7 +342,8 @@ class ConvMlpWithNorm(nn.Module): out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = ConvNormAct( - in_features, hidden_features, 1, bias=True, norm_layer=norm_layer, act_layer=act_layer) + in_features, hidden_features, 1, + bias=True, norm_layer=norm_layer, act_layer=act_layer) if mid_conv: self.mid = ConvNormAct( hidden_features, hidden_features, 3, @@ -380,7 +381,7 @@ class EfficientFormerV2Block(nn.Module): mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, - drop=0., + proj_drop=0., drop_path=0., layer_scale_init_value=1e-5, resolution=7, @@ -409,7 +410,7 @@ class EfficientFormerV2Block(nn.Module): hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, - drop=drop, + drop=proj_drop, mid_conv=True, ) self.ls2 = LayerScale2d( @@ -451,7 +452,7 @@ class EfficientFormerV2Stage(nn.Module): block_use_attn=False, num_vit=1, mlp_ratio=4., - drop=.0, + proj_drop=.0, drop_path=0., layer_scale_init_value=1e-5, act_layer=nn.GELU, @@ -487,7 +488,7 @@ class EfficientFormerV2Stage(nn.Module): stride=block_stride, mlp_ratio=mlp_ratio[block_idx], use_attn=block_use_attn and block_idx > remain_idx, - drop=drop, + proj_drop=proj_drop, drop_path=drop_path[block_idx], layer_scale_init_value=layer_scale_init_value, act_layer=act_layer, @@ -520,6 +521,7 @@ class EfficientFormerV2(nn.Module): act_layer='gelu', num_classes=1000, drop_rate=0., + proj_drop_rate=0., drop_path_rate=0., 
layer_scale_init_value=1e-5, num_vit=0, @@ -556,7 +558,7 @@ class EfficientFormerV2(nn.Module): block_use_attn=i >= 2, num_vit=num_vit, mlp_ratio=mlp_ratios[i], - drop=drop_rate, + proj_drop=proj_drop_rate, drop_path=dpr[i], layer_scale_init_value=layer_scale_init_value, act_layer=act_layer, @@ -572,6 +574,7 @@ class EfficientFormerV2(nn.Module): # Classifier head self.num_features = embed_dims[-1] self.norm = norm_layer(embed_dims[-1]) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.dist = distillation if self.dist: @@ -630,6 +633,7 @@ class EfficientFormerV2(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean(dim=(2, 3)) + x = self.head_drop(x) if pre_logits: return x x, x_dist = self.head(x), self.head_dist(x) diff --git a/timm/models/eva.py b/timm/models/eva.py index 21c20335..b5cd1fec 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -34,8 +34,8 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \ - apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple +from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \ + apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -355,10 +355,14 @@ class Eva(nn.Module): scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., norm_layer: Callable = LayerNorm, init_values: Optional[float] = None, + class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, @@ -383,10 +387,13 @@ class Eva(nn.Module): scale_mlp: scale_attn_inner: drop_rate: + pos_drop_rate: + proj_drop_rate: attn_drop_rate: drop_path_rate: norm_layer: init_values: + class_token: use_abs_pos_emb: use_rot_pos_emb: use_post_norm: @@ -397,7 +404,7 @@ class Eva(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_prefix_tokens = 1 + self.num_prefix_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = PatchEmbed( @@ -408,9 +415,19 @@ class Eva(nn.Module): ) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None - self.pos_drop = nn.Dropout(p=drop_rate) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + return_indices=True, + ) + else: + self.patch_drop = None if use_rot_pos_emb: ref_feat_shape = 
to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None @@ -435,7 +452,7 @@ class Eva(nn.Module): swiglu_mlp=swiglu_mlp, scale_mlp=scale_mlp, scale_attn_inner=scale_attn_inner, - proj_drop=drop_rate, + proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -446,6 +463,7 @@ class Eva(nn.Module): use_fc_norm = self.global_pool == 'avg' self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @@ -505,12 +523,21 @@ class Eva(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + # apply abs position embedding if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) + # obtain shared rotary position embedding and apply patch dropout rot_pos_embed = self.rope.get_embed() if self.rope is not None else None + if self.patch_drop is not None: + x, keep_indices = self.patch_drop(x) + if rot_pos_embed is not None and keep_indices is not None: + rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -525,6 +552,7 @@ class Eva(nn.Module): if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 57e2352f..104b09fc 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -75,12 +75,6 @@ class FocalModulation(nn.Module): self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity() def forward(self, x): - """ - Args: - x: input features with shape of (B, H, W, C) - """ - C = x.shape[1] - # pre linear projection x = self.f(x) q, ctx, gates = torch.split(x, self.input_split, 1) diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index bdac18e1..2f8d55fa 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -192,6 +192,7 @@ class MlpMixer(nn.Module): norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop_rate=0., + proj_drop_rate=0., drop_path_rate=0., nlhb=False, stem_norm=False, @@ -219,11 +220,12 @@ class MlpMixer(nn.Module): mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, - drop=drop_rate, + drop=proj_drop_rate, drop_path=drop_path_rate, ) for _ in range(num_blocks)]) self.norm = norm_layer(embed_dim) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() self.init_weights(nlhb=nlhb) @@ -267,6 +269,7 @@ class MlpMixer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean(dim=1) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 6466f127..6d51c263 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -271,7 +271,7 @@ class MobileVitBlock(nn.Module): num_heads=num_heads, qkv_bias=True, attn_drop=attn_drop, - drop=drop, + proj_drop=drop, drop_path=drop_path_rate, 
act_layer=layers.act, norm_layer=transformer_norm_layer, diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 5c0a6650..23d7a947 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -27,43 +27,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function -from ._registry import register_model +from ._registry import register_model, register_model_deprecations, generate_default_cfgs __all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', - 'fixed_input_size': True, - **kwargs - } - - -default_cfgs = dict( - mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'), - mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'), - mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'), - mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'), - - mvitv2_base_in21k=_cfg( - url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth', - num_classes=19168), - mvitv2_large_in21k=_cfg( - url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth', - num_classes=19168), - mvitv2_huge_in21k=_cfg( - url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth', - num_classes=19168), - - mvitv2_small_cls=_cfg(url=''), -) - - @dataclass class MultiScaleVitCfg: depths: Tuple[int, ...] = (2, 3, 16, 3) @@ -113,40 +81,6 @@ class MultiScaleVitCfg: self.stride_kv = tuple(pool_kv_stride) -model_cfgs = dict( - mvitv2_tiny=MultiScaleVitCfg( - depths=(1, 2, 5, 2), - ), - mvitv2_small=MultiScaleVitCfg( - depths=(1, 2, 11, 2), - ), - mvitv2_base=MultiScaleVitCfg( - depths=(2, 3, 16, 3), - ), - mvitv2_large=MultiScaleVitCfg( - depths=(2, 6, 36, 4), - embed_dim=144, - num_heads=2, - expand_attn=False, - ), - - mvitv2_base_in21k=MultiScaleVitCfg( - depths=(2, 3, 16, 3), - ), - mvitv2_large_in21k=MultiScaleVitCfg( - depths=(2, 6, 36, 4), - embed_dim=144, - num_heads=2, - expand_attn=False, - ), - - mvitv2_small_cls=MultiScaleVitCfg( - depths=(1, 2, 11, 2), - use_cls_token=True, - ), -) - - def prod(iterable): return reduce(operator.mul, iterable, 1) @@ -229,26 +163,32 @@ def cal_rel_pos_type( # Scale up rel pos if shapes for q and k are different. 
q_h_ratio = max(k_h / q_h, 1.0) k_h_ratio = max(q_h / k_h, 1.0) - dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio + dist_h = ( + torch.arange(q_h, device=q.device).unsqueeze(-1) * q_h_ratio - + torch.arange(k_h, device=q.device).unsqueeze(0) * k_h_ratio + ) dist_h += (k_h - 1) * k_h_ratio q_w_ratio = max(k_w / q_w, 1.0) k_w_ratio = max(q_w / k_w, 1.0) - dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio + dist_w = ( + torch.arange(q_w, device=q.device).unsqueeze(-1) * q_w_ratio - + torch.arange(k_w, device=q.device).unsqueeze(0) * k_w_ratio + ) dist_w += (k_w - 1) * k_w_ratio - Rh = rel_pos_h[dist_h.long()] - Rw = rel_pos_w[dist_w.long()] + rel_h = rel_pos_h[dist_h.long()] + rel_w = rel_pos_w[dist_w.long()] B, n_head, q_N, dim = q.shape r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) - rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh) - rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw) + rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h) + rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w) attn[:, :, sp_idx:, sp_idx:] = ( attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) - + rel_h[:, :, :, :, :, None] - + rel_w[:, :, :, :, None, :] + + rel_h.unsqueeze(-1) + + rel_w.unsqueeze(-2) ).view(B, -1, q_h * q_w, k_h * k_w) return attn @@ -390,18 +330,18 @@ class MultiScaleAttentionPoolFirst(nn.Module): v = self.norm_v(v) q_N = q_size[0] * q_size[1] + int(self.has_cls_token) - q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1) - q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3) + q = q.transpose(1, 2).reshape(B, q_N, -1) + q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2) k_N = k_size[0] * k_size[1] + int(self.has_cls_token) - k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1) - k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3) + k = k.transpose(1, 2).reshape(B, k_N, -1) + k = self.k(k).reshape(B, k_N, self.num_heads, -1) v_N = v_size[0] * v_size[1] + int(self.has_cls_token) - v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1) - v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.transpose(1, 2).reshape(B, v_N, -1) + v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2) - attn = (q * self.scale) @ k.transpose(-2, -1) + attn = (q * self.scale) @ k if self.rel_pos_type == 'spatial': attn = cal_rel_pos_type( attn, @@ -764,7 +704,7 @@ class MultiScaleVit(nn.Module): cfg: MultiScaleVitCfg, img_size: Tuple[int, int] = (224, 224), in_chans: int = 3, - global_pool: str = 'avg', + global_pool: Optional[str] = None, num_classes: int = 1000, drop_path_rate: float = 0., drop_rate: float = 0., @@ -774,6 +714,8 @@ class MultiScaleVit(nn.Module): norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) self.num_classes = num_classes self.drop_rate = drop_rate + if global_pool is None: + global_pool = 'token' if cfg.use_cls_token else 'avg' self.global_pool = global_pool self.depths = tuple(cfg.depths) self.expand_attn = cfg.expand_attn @@ -963,13 +905,89 @@ def checkpoint_filter_fn(state_dict, model): return out_dict +model_cfgs = dict( + mvitv2_tiny=MultiScaleVitCfg( + depths=(1, 2, 5, 2), + ), + mvitv2_small=MultiScaleVitCfg( + depths=(1, 2, 11, 2), + ), + mvitv2_base=MultiScaleVitCfg( + depths=(2, 3, 16, 3), + ), + mvitv2_large=MultiScaleVitCfg( + depths=(2, 6, 36, 4), + embed_dim=144, + num_heads=2, + expand_attn=False, + ), + + mvitv2_small_cls=MultiScaleVitCfg( + depths=(1, 2, 11, 2), + use_cls_token=True, + ), + 
mvitv2_base_cls=MultiScaleVitCfg( + depths=(2, 3, 16, 3), + use_cls_token=True, + ), + mvitv2_large_cls=MultiScaleVitCfg( + depths=(2, 6, 36, 4), + embed_dim=144, + num_heads=2, + use_cls_token=True, + expand_attn=True, + ), + mvitv2_huge_cls=MultiScaleVitCfg( + depths=(4, 8, 60, 8), + embed_dim=192, + num_heads=3, + use_cls_token=True, + expand_attn=True, + ), +) + + def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs): return build_model_with_cfg( - MultiScaleVit, variant, pretrained, + MultiScaleVit, + variant, + pretrained, model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), - **kwargs) + **kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'mvitv2_tiny.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'), + 'mvitv2_small.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'), + 'mvitv2_base.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'), + 'mvitv2_large.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'), + + 'mvitv2_small_cls': _cfg(url=''), + 'mvitv2_base_cls.fb_inw21k': _cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth', + num_classes=19168), + 'mvitv2_large_cls.fb_inw21k': _cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth', + num_classes=19168), + 'mvitv2_huge_cls.fb_inw21k': _cfg( + url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth', + num_classes=19168), +}) @register_model @@ -992,21 +1010,21 @@ def mvitv2_large(pretrained=False, **kwargs): return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs) -# @register_model -# def mvitv2_base_in21k(pretrained=False, **kwargs): -# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs) -# -# -# @register_model -# def mvitv2_large_in21k(pretrained=False, **kwargs): -# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs) -# -# -# @register_model -# def mvitv2_huge_in21k(pretrained=False, **kwargs): -# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs) - - @register_model def mvitv2_small_cls(pretrained=False, **kwargs): return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_base_cls(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_base_cls', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_large_cls(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_large_cls', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_huge_cls(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_huge_cls', pretrained=pretrained, **kwargs) diff --git a/timm/models/nest.py b/timm/models/nest.py index c9c6258c..681593df 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -104,15 +104,25 @@ class TransformerLayer(nn.Module): - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks") - Uses modified 
Attention layer that handles the "block" dimension """ - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.norm1 = norm_layer(dim) - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) def forward(self, x): y = self.norm1(x) @@ -176,9 +186,23 @@ class NestLevel(nn.Module): """ Single hierarchical level of a Nested Transformer """ def __init__( - self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None, - mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[], - norm_layer=None, act_layer=None, pad_type=''): + self, + num_blocks, + block_size, + seq_length, + num_heads, + depth, + embed_dim, + prev_embed_dim=None, + mlp_ratio=4., + qkv_bias=True, + proj_drop=0., + attn_drop=0., + drop_path=[], + norm_layer=None, + act_layer=None, + pad_type='', + ): super().__init__() self.block_size = block_size self.grad_checkpointing = False @@ -191,13 +215,20 @@ class NestLevel(nn.Module): self.pool = nn.Identity() # Transformer encoder - if len(drop_path_rates): - assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers' + if len(drop_path): + assert len(drop_path) == depth, 'Must provide as many drop path rates as there are transformer layers' self.transformer_encoder = nn.Sequential(*[ TransformerLayer( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i], - norm_layer=norm_layer, act_layer=act_layer) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i], + norm_layer=norm_layer, + act_layer=act_layer, + ) for i in range(depth)]) def forward(self, x): @@ -225,10 +256,26 @@ class Nest(nn.Module): """ def __init__( - self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), - num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, - pad_type='', weight_init='', global_pool='avg' + self, + img_size=224, + in_chans=3, + patch_size=4, + num_levels=3, + embed_dims=(128, 256, 512), + num_heads=(4, 8, 16), + depths=(2, 2, 20), + num_classes=1000, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.5, + norm_layer=None, + act_layer=None, + pad_type='', + weight_init='', + global_pool='avg', ): """ Args: @@ -292,7 +339,12 @@ class Nest(nn.Module): # Patch embedding self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False) + 
img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dims[0], + flatten=False, + ) self.num_patches = self.patch_embed.num_patches self.seq_length = self.num_patches // self.num_blocks[0] @@ -304,8 +356,22 @@ class Nest(nn.Module): for i in range(len(self.num_blocks)): dim = embed_dims[i] levels.append(NestLevel( - self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim, - mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type)) + self.num_blocks[i], + self.block_size, + self.seq_length, + num_heads[i], + depths[i], + dim, + prev_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dp_rates[i], + norm_layer=norm_layer, + act_layer=act_layer, + pad_type=pad_type, + )) self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')] prev_dim = dim curr_stride *= 2 @@ -315,7 +381,10 @@ class Nest(nn.Module): self.norm = norm_layer(embed_dims[-1]) # Classifier - self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool = global_pool + self.head_drop = nn.Dropout(drop_rate) + self.head = head self.init_weights(weight_init) @@ -366,8 +435,7 @@ class Nest(nn.Module): def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/pit.py b/timm/models/pit.py index 4f40e5e0..17dc679c 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -78,7 +78,16 @@ class SequentialTuple(nn.Sequential): class Transformer(nn.Module): def __init__( - self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None): + self, + base_dim, + depth, + heads, + mlp_ratio, + pool=None, + proj_drop=.0, + attn_drop=.0, + drop_path_prob=None, + ): super(Transformer, self).__init__() self.layers = nn.ModuleList([]) embed_dim = base_dim * heads @@ -89,8 +98,8 @@ class Transformer(nn.Module): num_heads=heads, mlp_ratio=mlp_ratio, qkv_bias=True, - drop=drop_rate, - attn_drop=attn_drop_rate, + proj_drop=proj_drop, + attn_drop=attn_drop, drop_path=drop_path_prob[i], norm_layer=partial(nn.LayerNorm, eps=1e-6) ) @@ -122,8 +131,14 @@ class ConvHeadPooling(nn.Module): super(ConvHeadPooling, self).__init__() self.conv = nn.Conv2d( - in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride, - padding_mode=padding_mode, groups=in_feature) + in_feature, + out_feature, + kernel_size=stride + 1, + padding=stride // 2, + stride=stride, + padding_mode=padding_mode, + groups=in_feature, + ) self.fc = nn.Linear(in_feature, out_feature) def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: @@ -150,9 +165,24 @@ class PoolingVisionTransformer(nn.Module): - https://arxiv.org/abs/2103.16302 """ def __init__( - self, img_size, patch_size, stride, base_dims, depth, heads, - mlp_ratio, num_classes=1000, in_chans=3, global_pool='token', - distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0): + self, + img_size, + patch_size, + stride, + base_dims, + depth, + heads, + mlp_ratio, + num_classes=1000, + in_chans=3, + global_pool='token', + distilled=False, + drop_rate=0., + 
pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + ): super(PoolingVisionTransformer, self).__init__() assert global_pool in ('token',) @@ -173,7 +203,7 @@ class PoolingVisionTransformer(nn.Module): self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding) self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0])) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) transformers = [] # stochastic depth decay rule @@ -182,16 +212,27 @@ class PoolingVisionTransformer(nn.Module): pool = None if stage < len(heads) - 1: pool = ConvHeadPooling( - base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2) + base_dims[stage] * heads[stage], + base_dims[stage + 1] * heads[stage + 1], + stride=2, + ) transformers += [Transformer( - base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool, - drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage]) + base_dims[stage], + depth[stage], + heads[stage], + mlp_ratio, + pool=pool, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path_prob=dpr[stage], + ) ] self.transformers = SequentialTuple(*transformers) self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) self.num_features = self.embed_dim = base_dims[-1] * heads[-1] # Classifier head + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = None if distilled: @@ -243,6 +284,8 @@ class PoolingVisionTransformer(nn.Module): if self.head_dist is not None: assert self.global_pool == 'token' x, x_dist = x[:, 0], x[:, 1] + x = self.head_drop(x) + x_dist = self.head_drop(x_dist) if not pre_logits: x = self.head(x) x_dist = self.head_dist(x_dist) @@ -255,6 +298,7 @@ class PoolingVisionTransformer(nn.Module): else: if self.global_pool == 'token': x = x[:, 0] + x = self.head_drop(x) if not pre_logits: x = self.head(x) return x @@ -284,71 +328,70 @@ def _create_pit(variant, pretrained=False, **kwargs): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( - PoolingVisionTransformer, variant, pretrained, + PoolingVisionTransformer, + variant, + pretrained, pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + **kwargs, + ) return model @register_model def pit_b_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=14, stride=7, base_dims=[64, 64, 64], depth=[3, 6, 4], heads=[4, 8, 16], mlp_ratio=4, - **kwargs ) - return _create_pit('pit_b_224', pretrained, **model_kwargs) + return _create_pit('pit_b_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_s_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[3, 6, 12], mlp_ratio=4, - **kwargs ) - return _create_pit('pit_s_224', pretrained, **model_kwargs) + return _create_pit('pit_s_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_xs_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4, - **kwargs ) - return _create_pit('pit_xs_224', pretrained, **model_kwargs) + return _create_pit('pit_xs_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_ti_224(pretrained, **kwargs): - model_kwargs = dict(
patch_size=16, stride=8, base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4, - **kwargs ) - return _create_pit('pit_ti_224', pretrained, **model_kwargs) + return _create_pit('pit_ti_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_b_distilled_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=14, stride=7, base_dims=[64, 64, 64], @@ -356,14 +399,13 @@ def pit_b_distilled_224(pretrained, **kwargs): heads=[4, 8, 16], mlp_ratio=4, distilled=True, - **kwargs ) - return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs) + return _create_pit('pit_b_distilled_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_s_distilled_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=16, stride=8, base_dims=[48, 48, 48], @@ -371,14 +413,13 @@ def pit_s_distilled_224(pretrained, **kwargs): heads=[3, 6, 12], mlp_ratio=4, distilled=True, - **kwargs ) - return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs) + return _create_pit('pit_s_distilled_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_xs_distilled_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=16, stride=8, base_dims=[48, 48, 48], @@ -386,14 +427,13 @@ def pit_xs_distilled_224(pretrained, **kwargs): heads=[2, 4, 8], mlp_ratio=4, distilled=True, - **kwargs ) - return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs) + return _create_pit('pit_xs_distilled_224', pretrained, **dict(model_args, **kwargs)) @register_model def pit_ti_distilled_224(pretrained, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=16, stride=8, base_dims=[32, 32, 32], @@ -401,6 +441,5 @@ def pit_ti_distilled_224(pretrained, **kwargs): heads=[2, 4, 8], mlp_ratio=4, distilled=True, - **kwargs ) - return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs) \ No newline at end of file + return _create_pit('pit_ti_distilled_224', pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index b4d2d18f..7e101bc4 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -103,9 +103,16 @@ class PoolFormerBlock(nn.Module): """ def __init__( - self, dim, pool_size=3, mlp_ratio=4., - act_layer=nn.GELU, norm_layer=GroupNorm1, - drop=0., drop_path=0., layer_scale_init_value=1e-5): + self, + dim, + pool_size=3, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=GroupNorm1, + drop=0., + drop_path=0., + layer_scale_init_value=1e-5, + ): super().__init__() @@ -116,7 +123,7 @@ class PoolFormerBlock(nn.Module): self.mlp = ConvMlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - if layer_scale_init_value: + if layer_scale_init_value is not None: self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim)) else: @@ -134,10 +141,15 @@ class PoolFormerBlock(nn.Module): def basic_blocks( - dim, index, layers, - pool_size=3, mlp_ratio=4., - act_layer=nn.GELU, norm_layer=GroupNorm1, - drop_rate=.0, drop_path_rate=0., + dim, + index, + layers, + pool_size=3, + mlp_ratio=4., + act_layer=nn.GELU, + norm_layer=GroupNorm1, + drop_rate=.0, + drop_path_rate=0., layer_scale_init_value=1e-5, ): """ generate PoolFormer blocks for a stage """ @@ -145,9 +157,13 @@ def basic_blocks( for block_idx in range(layers[index]): block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) blocks.append(PoolFormerBlock( - dim, pool_size=pool_size, mlp_ratio=mlp_ratio, - act_layer=act_layer, norm_layer=norm_layer, - drop=drop_rate, drop_path=block_dpr, + dim, + pool_size=pool_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=block_dpr, layer_scale_init_value=layer_scale_init_value, )) blocks = nn.Sequential(*blocks) @@ -176,9 +192,12 @@ class PoolFormer(nn.Module): down_patch_size=3, down_stride=2, down_pad=1, - drop_rate=0., drop_path_rate=0., + drop_rate=0., + proj_drop_rate=0., + drop_path_rate=0., layer_scale_init_value=1e-5, - **kwargs): + **kwargs, + ): super().__init__() self.num_classes = num_classes @@ -187,28 +206,42 @@ class PoolFormer(nn.Module): self.grad_checkpointing = False self.patch_embed = PatchEmbed( - patch_size=in_patch_size, stride=in_stride, padding=in_pad, - in_chs=in_chans, embed_dim=embed_dims[0]) + patch_size=in_patch_size, + stride=in_stride, + padding=in_pad, + in_chs=in_chans, + embed_dim=embed_dims[0], + ) # set the main block in network network = [] for i in range(len(layers)): network.append(basic_blocks( - embed_dims[i], i, layers, - pool_size=pool_size, mlp_ratio=mlp_ratios[i], - act_layer=act_layer, norm_layer=norm_layer, - drop_rate=drop_rate, drop_path_rate=drop_path_rate, - layer_scale_init_value=layer_scale_init_value) - ) + embed_dims[i], + i, + layers, + pool_size=pool_size, + mlp_ratio=mlp_ratios[i], + act_layer=act_layer, + norm_layer=norm_layer, + drop_rate=proj_drop_rate, + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value + )) if i < len(layers) - 1 and (downsamples[i] or embed_dims[i] != embed_dims[i + 1]): # downsampling between stages network.append(PatchEmbed( - in_chs=embed_dims[i], embed_dim=embed_dims[i + 1], - patch_size=down_patch_size, stride=down_stride, padding=down_pad) - ) + in_chs=embed_dims[i], + embed_dim=embed_dims[i + 1], + patch_size=down_patch_size, + stride=down_stride, + padding=down_pad, + )) self.network = nn.Sequential(*network) self.norm = norm_layer(self.num_features) + + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @@ -254,6 +287,7 @@ class PoolFormer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean([-2, -1]) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 696a2506..d230e788 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -54,8 +54,14 @@ default_cfgs = { class MlpWithDepthwiseConv(nn.Module): def __init__( - 
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, - drop=0., extra_relu=False): + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0., + extra_relu=False, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -154,8 +160,19 @@ class Attention(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + sr_ratio=1, + linear_attn=False, + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( @@ -165,7 +182,7 @@ class Block(nn.Module): linear_attn=linear_attn, qkv_bias=qkv_bias, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) @@ -173,8 +190,8 @@ class Block(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, - extra_relu=linear_attn + drop=proj_drop, + extra_relu=linear_attn, ) def forward(self, x, feat_size: List[int]): @@ -193,8 +210,8 @@ class OverlapPatchEmbed(nn.Module): assert max(patch_size) > stride, "Set larger patch_size than stride" self.patch_size = patch_size self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) + in_chans, embed_dim, patch_size, + stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)) self.norm = nn.LayerNorm(embed_dim) def forward(self, x): @@ -217,7 +234,7 @@ class PyramidVisionTransformerStage(nn.Module): linear_attn: bool = False, mlp_ratio: float = 4.0, qkv_bias: bool = True, - drop: float = 0., + proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0.0, norm_layer: Callable = nn.LayerNorm, @@ -242,7 +259,7 @@ class PyramidVisionTransformerStage(nn.Module): linear_attn=linear_attn, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop, + proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, @@ -278,6 +295,7 @@ class PyramidVisionTransformerV2(nn.Module): qkv_bias=True, linear=False, drop_rate=0., + proj_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, @@ -314,16 +332,17 @@ class PyramidVisionTransformerV2(nn.Module): mlp_ratio=mlp_ratios[i], linear_attn=linear, qkv_bias=qkv_bias, - drop=drop_rate, + proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], - norm_layer=norm_layer + norm_layer=norm_layer, )) prev_dim = embed_dims[i] cur += depths[i] # classification head self.num_features = embed_dims[-1] + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @@ -379,6 +398,7 @@ class PyramidVisionTransformerV2(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x.mean(dim=(-1, -2)) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f4621614..f4df31c6 100644 --- a/timm/models/swin_transformer.py +++ 
b/timm/models/swin_transformer.py @@ -181,7 +181,7 @@ class SwinTransformerBlock(nn.Module): shift_size: int = 0, mlp_ratio: float = 4., qkv_bias: bool = True, - drop: float = 0., + proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU, @@ -197,7 +197,7 @@ class SwinTransformerBlock(nn.Module): shift_size: Shift size for SW-MSA. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. - drop: Dropout rate. + proj_drop: Dropout rate. attn_drop: Attention dropout rate. drop_path: Stochastic depth rate. act_layer: Activation layer. @@ -223,7 +223,7 @@ class SwinTransformerBlock(nn.Module): window_size=to_2tuple(self.window_size), qkv_bias=qkv_bias, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -232,7 +232,7 @@ class SwinTransformerBlock(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) if self.shift_size > 0: @@ -346,7 +346,7 @@ class SwinTransformerStage(nn.Module): window_size: _int_or_tuple_2_t = 7, mlp_ratio: float = 4., qkv_bias: bool = True, - drop: float = 0., + proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0., norm_layer: Callable = nn.LayerNorm, @@ -362,7 +362,7 @@ class SwinTransformerStage(nn.Module): window_size: Local window size. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. - drop: Dropout rate. + proj_drop: Projection dropout rate. attn_drop: Attention dropout rate. drop_path: Stochastic depth rate. norm_layer: Normalization layer. @@ -396,7 +396,7 @@ class SwinTransformerStage(nn.Module): shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop, + proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, @@ -435,6 +435,7 @@ class SwinTransformer(nn.Module): mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., + proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0.1, norm_layer: Union[str, Callable] = nn.LayerNorm, @@ -508,7 +509,7 @@ class SwinTransformer(nn.Module): window_size=window_size[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, - drop=drop_rate, + proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index cddb68a0..f9aef98e 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -203,7 +203,7 @@ class SwinTransformerV2Block(nn.Module): shift_size=0, mlp_ratio=4., qkv_bias=True, - drop=0., + proj_drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, @@ -219,7 +219,7 @@ class SwinTransformerV2Block(nn.Module): shift_size: Shift size for SW-MSA. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. - drop: Dropout rate. + proj_drop: Dropout rate. attn_drop: Attention dropout rate. drop_path: Stochastic depth rate. act_layer: Activation layer. 
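The renames above follow one convention across these models: the block-level `drop` argument becomes `proj_drop` and feeds both the dropout after the attention output projection and the dropout inside the MLP, while `attn_drop` stays on the attention weights. A minimal, self-contained sketch of that convention in plain PyTorch follows; `TinyBlock` and its sizes are illustrative only, not part of this patch or of timm's API.

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Toy transformer block showing which dropout each argument controls."""

    def __init__(self, dim=32, num_heads=4, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # attn_drop -> dropout applied to the attention weight matrix
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
        # proj_drop -> dropout after the attention output projection ...
        self.proj_drop = nn.Dropout(proj_drop)
        self.norm2 = nn.LayerNorm(dim)
        # ... and inside the MLP, mirroring Mlp(..., drop=proj_drop) in the diffs above
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(proj_drop),
            nn.Linear(dim * 4, dim), nn.Dropout(proj_drop),
        )

    def forward(self, x):
        y = self.norm1(x)
        y, _ = self.attn(y, y, y, need_weights=False)
        x = x + self.proj_drop(y)
        return x + self.mlp(self.norm2(x))


x = torch.randn(2, 16, 32)
print(TinyBlock(attn_drop=0.1, proj_drop=0.1)(x).shape)  # torch.Size([2, 16, 32])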
@@ -242,7 +242,7 @@ class SwinTransformerV2Block(nn.Module): num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, pretrained_window_size=to_2tuple(pretrained_window_size), ) self.norm1 = norm_layer(dim) @@ -252,7 +252,7 @@ class SwinTransformerV2Block(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -367,7 +367,7 @@ class SwinTransformerV2Stage(nn.Module): downsample=False, mlp_ratio=4., qkv_bias=True, - drop=0., + proj_drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, @@ -384,7 +384,7 @@ class SwinTransformerV2Stage(nn.Module): downsample: Use downsample layer at start of the block. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. - drop: Dropout rate + proj_drop: Projection dropout rate attn_drop: Attention dropout rate. drop_path: Stochastic depth rate. norm_layer: Normalization layer. @@ -416,7 +416,7 @@ class SwinTransformerV2Stage(nn.Module): shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop, + proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, @@ -463,6 +463,7 @@ class SwinTransformerV2(nn.Module): mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., + proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0.1, norm_layer: Callable = nn.LayerNorm, @@ -481,7 +482,8 @@ class SwinTransformerV2(nn.Module): window_size: Window size. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. - drop_rate: Dropout rate. + drop_rate: Head dropout rate. + proj_drop_rate: Projection dropout rate. attn_drop_rate: Attention dropout rate. drop_path_rate: Stochastic depth rate. norm_layer: Normalization layer. 
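At the model level, the practical effect of this refactor is that `drop_rate` now only controls the classifier-head dropout, while `proj_drop_rate` and `attn_drop_rate` (plus `pos_drop_rate` / `patch_drop_rate` on the eva / vit-style models) cover the other dropout sites. A rough usage sketch, assuming a timm checkout with this patch applied; the model name is just an example of a variant touched by the refactor.

import timm

model = timm.create_model(
    'swinv2_tiny_window8_256',
    pretrained=False,
    drop_rate=0.1,        # dropout before the classifier head only
    proj_drop_rate=0.05,  # dropout on attention/MLP projections inside each block
    attn_drop_rate=0.0,   # dropout on the attention weights
    drop_path_rate=0.2,   # stochastic depth, unchanged by this patch
)
model.eval()
# eva / vision_transformer variants from this patch additionally accept
# pos_drop_rate= and patch_drop_rate= for position-embedding dropout and patch dropout.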
@@ -531,7 +533,8 @@ class SwinTransformerV2(nn.Module): window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, pretrained_window_size=pretrained_window_sizes[i], diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 178617d3..2b7dbf16 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -221,7 +221,7 @@ class SwinTransformerV2CrBlock(nn.Module): window_size (Tuple[int, int]): Window size to be utilized shift_size (int): Shifting size to be used mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels - drop (float): Dropout in input mapping + proj_drop (float): Dropout in input mapping drop_attn (float): Dropout rate of attention map drop_path (float): Dropout in main path extra_norm (bool): Insert extra norm on 'main' branch if True @@ -238,7 +238,7 @@ class SwinTransformerV2CrBlock(nn.Module): shift_size: Tuple[int, int] = (0, 0), mlp_ratio: float = 4.0, init_values: Optional[float] = 0, - drop: float = 0.0, + proj_drop: float = 0.0, drop_attn: float = 0.0, drop_path: float = 0.0, extra_norm: bool = False, @@ -259,7 +259,7 @@ class SwinTransformerV2CrBlock(nn.Module): num_heads=num_heads, window_size=self.window_size, drop_attn=drop_attn, - drop_proj=drop, + drop_proj=proj_drop, sequential_attn=sequential_attn, ) self.norm1 = norm_layer(dim) @@ -269,7 +269,7 @@ class SwinTransformerV2CrBlock(nn.Module): self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), - drop=drop, + drop=proj_drop, out_features=dim, ) self.norm2 = norm_layer(dim) @@ -445,7 +445,7 @@ class SwinTransformerV2CrStage(nn.Module): num_heads (int): Number of attention heads to be utilized window_size (int): Window size to be utilized mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels - drop (float): Dropout in input mapping + proj_drop (float): Dropout in input mapping drop_attn (float): Dropout rate of attention map drop_path (float): Dropout in main path norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm @@ -464,7 +464,7 @@ class SwinTransformerV2CrStage(nn.Module): window_size: Tuple[int, int], mlp_ratio: float = 4.0, init_values: Optional[float] = 0.0, - drop: float = 0.0, + proj_drop: float = 0.0, drop_attn: float = 0.0, drop_path: Union[List[float], float] = 0.0, norm_layer: Type[nn.Module] = nn.LayerNorm, @@ -498,7 +498,7 @@ class SwinTransformerV2CrStage(nn.Module): shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), mlp_ratio=mlp_ratio, init_values=init_values, - drop=drop, + proj_drop=proj_drop, drop_attn=drop_attn, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, extra_norm=_extra_norm(index), @@ -546,23 +546,24 @@ class SwinTransformerV2Cr(nn.Module): https://arxiv.org/pdf/2111.09883 Args: - img_size (Tuple[int, int]): Input resolution. - window_size (Optional[int]): Window size. If None, img_size // window_div. Default: None - img_window_ratio (int): Window size to image size ratio. Default: 32 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input channels. - depths (int): Depth of the stage (number of layers). - num_heads (int): Number of attention heads to be utilized. - embed_dim (int): Patch embedding dimension. Default: 96 - num_classes (int): Number of output classes. 
Default: 1000 - mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4 - drop_rate (float): Dropout rate. Default: 0.0 - attn_drop_rate (float): Dropout rate of attention map. Default: 0.0 - drop_path_rate (float): Stochastic depth rate. Default: 0.0 - norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm - extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage - extra_norm_stage (bool): End each stage with an extra norm layer in main branch - sequential_attn (bool): If true sequential self-attention is performed. Default: False + img_size: Input resolution. + window_size: Window size. If None, img_size // window_div + img_window_ratio: Window size to image size ratio. + patch_size: Patch size. + in_chans: Number of input channels. + depths: Depth of the stage (number of layers). + num_heads: Number of attention heads to be utilized. + embed_dim: Patch embedding dimension. + num_classes: Number of output classes. + mlp_ratio: Ratio of the hidden dimension in the FFN to the input channels. + drop_rate: Dropout rate. + proj_drop_rate: Projection dropout rate. + attn_drop_rate: Dropout rate of attention map. + drop_path_rate: Stochastic depth rate. + norm_layer: Type of normalization layer to be utilized. + extra_norm_period: Insert extra norm layer on main branch every N (period) blocks in stage + extra_norm_stage: End each stage with an extra norm layer in main branch + sequential_attn: If true sequential self-attention is performed. """ def __init__( @@ -579,6 +580,7 @@ class SwinTransformerV2Cr(nn.Module): mlp_ratio: float = 4.0, init_values: Optional[float] = 0., drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Type[nn.Module] = nn.LayerNorm, @@ -627,7 +629,7 @@ class SwinTransformerV2Cr(nn.Module): window_size=window_size, mlp_ratio=mlp_ratio, init_values=init_values, - drop=drop_rate, + proj_drop=proj_drop_rate, drop_attn=attn_drop_rate, drop_path=dpr[stage_idx], extra_norm_period=extra_norm_period, diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 50088baf..c99d2775 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -80,31 +80,64 @@ class Block(nn.Module): """ TNT Block """ def __init__( - self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., - qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + dim_out, + num_pixel, + num_heads_in=4, + num_heads_out=12, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() # Inner transformer - self.norm_in = norm_layer(in_dim) + self.norm_in = norm_layer(dim) self.attn_in = Attention( - in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, - attn_drop=attn_drop, proj_drop=drop) + dim, + dim, + num_heads=num_heads_in, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) - self.norm_mlp_in = norm_layer(in_dim) - self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), - out_features=in_dim, act_layer=act_layer, drop=drop) + self.norm_mlp_in = norm_layer(dim) + self.mlp_in = Mlp( + in_features=dim, + hidden_features=int(dim * 4), + out_features=dim, + act_layer=act_layer, + drop=proj_drop, + ) - self.norm1_proj = norm_layer(in_dim) - self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) + self.norm1_proj = norm_layer(dim) + 
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True) + # Outer transformer - self.norm_out = norm_layer(dim) + self.norm_out = norm_layer(dim_out) self.attn_out = Attention( - dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, - attn_drop=attn_drop, proj_drop=drop) + dim_out, + dim_out, + num_heads=num_heads_out, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm_mlp = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), - out_features=dim, act_layer=act_layer, drop=drop) + self.norm_mlp = norm_layer(dim_out) + self.mlp = Mlp( + in_features=dim_out, + hidden_features=int(dim_out * mlp_ratio), + out_features=dim_out, + act_layer=act_layer, + drop=proj_drop, + ) def forward(self, pixel_embed, patch_embed): # inner @@ -157,9 +190,27 @@ class TNT(nn.Module): """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, in_dim=48, depth=12, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + inner_dim=48, + depth=12, + num_heads_inner=4, + num_heads_outer=12, + mlp_ratio=4., + qkv_bias=False, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + first_stride=4, + ): super().__init__() assert global_pool in ('', 'token', 'avg') self.num_classes = num_classes @@ -168,31 +219,46 @@ class TNT(nn.Module): self.grad_checkpointing = False self.pixel_embed = PixelEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + in_dim=inner_dim, + stride=first_stride, + ) num_patches = self.pixel_embed.num_patches self.num_patches = num_patches new_patch_size = self.pixel_embed.new_patch_size num_pixel = new_patch_size[0] * new_patch_size[1] - self.norm1_proj = norm_layer(num_pixel * in_dim) - self.proj = nn.Linear(num_pixel * in_dim, embed_dim) + self.norm1_proj = norm_layer(num_pixel * inner_dim) + self.proj = nn.Linear(num_pixel * inner_dim, embed_dim) self.norm2_proj = norm_layer(embed_dim) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1])) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1])) + self.pos_drop = nn.Dropout(p=pos_drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule blocks = [] for i in range(depth): blocks.append(Block( - dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, - mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[i], norm_layer=norm_layer)) + dim=inner_dim, + dim_out=embed_dim, + num_pixel=num_pixel, + num_heads_in=num_heads_inner, + num_heads_out=num_heads_outer, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + 
)) self.blocks = nn.ModuleList(blocks) self.norm = norm_layer(embed_dim) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.cls_token, std=.02) @@ -260,6 +326,7 @@ class TNT(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -290,16 +357,16 @@ def _create_tnt(variant, pretrained=False, **kwargs): @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): model_cfg = dict( - patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, - qkv_bias=False, **kwargs) - model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg) + patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6, + qkv_bias=False) + model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs)) return model @register_model def tnt_b_patch16_224(pretrained=False, **kwargs): model_cfg = dict( - patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, - qkv_bias=False, **kwargs) - model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg) + patch_size=16, embed_dim=640, inner_dim=40, depth=12, num_heads_outer=10, + qkv_bias=False) + model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs)) return model diff --git a/timm/models/twins.py b/timm/models/twins.py index 41944c36..25fc95c7 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -200,20 +200,20 @@ class GlobalSubSampleAttn(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + self, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): super().__init__() self.norm1 = norm_layer(dim) if ws is None: - self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + self.attn = Attention(dim, num_heads, False, None, attn_drop, proj_drop) elif ws == 1: - self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio) else: - self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) def forward(self, x, size: Size_): x = x + self.drop_path(self.attn(self.norm1(x), size)) @@ -275,10 +275,26 @@ class Twins(nn.Module): Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git """ def __init__( - self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', - embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), depths=(3, 4, 6, 3), - sr_ratios=(8, 4, 2, 1), wss=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), block_cls=Block): + self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), + mlp_ratios=(4, 4, 4, 4), + depths=(3, 4, 6, 3), + sr_ratios=(8, 4, 2, 1), + wss=None, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + block_cls=Block, + ): super().__init__() self.num_classes = num_classes self.global_pool = global_pool @@ -293,7 +309,7 @@ class Twins(nn.Module): self.pos_drops = nn.ModuleList() for i in range(len(depths)): self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) - self.pos_drops.append(nn.Dropout(p=drop_rate)) + self.pos_drops.append(nn.Dropout(p=pos_drop_rate)) prev_chs = embed_dims[i] img_size = tuple(t // patch_size for t in img_size) patch_size = 2 @@ -303,9 +319,16 @@ class Twins(nn.Module): cur = 0 for k in range(len(depths)): _block = nn.ModuleList([block_cls( - dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], - ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) + dim=embed_dims[k], + num_heads=num_heads[k], + mlp_ratio=mlp_ratios[k], + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])], + ) self.blocks.append(_block) cur += depths[k] @@ -314,6 +337,7 @@ class Twins(nn.Module): self.norm = norm_layer(self.num_features) # classification head + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # init weights @@ -386,6 +410,7 @@ class Twins(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean(dim=1) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -404,47 +429,47 @@ def _create_twins(variant, pretrained=False, **kwargs): @register_model def twins_pcpvt_small(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_pcpvt_small', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def 
twins_pcpvt_base(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_pcpvt_base', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def twins_pcpvt_large(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_pcpvt_large', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def twins_svt_small(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], - depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_svt_small', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def twins_svt_base(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], - depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_svt_base', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def twins_svt_large(pretrained=False, **kwargs): - model_kwargs = dict( + model_args = dict( patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], - depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1]) + return _create_twins('twins_svt_large', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/visformer.py b/timm/models/visformer.py index e15ae4a5..d44d56fd 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -40,8 +40,15 @@ default_cfgs = dict( class SpatialMlp(nn.Module): def __init__( - self, in_features, hidden_features=None, out_features=None, - act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0., + group=8, + spatial_conv=False, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -83,6 +90,8 @@ class SpatialMlp(nn.Module): class Attention(nn.Module): + fast_attn: torch.jit.Final[bool] + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim @@ -90,6 +99,8 @@ class Attention(nn.Module): head_dim = round(dim // num_heads * head_dim_ratio) self.head_dim = head_dim self.scale = head_dim ** -0.5 + self.fast_attn = 
hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) @@ -100,10 +111,16 @@ class Attention(nn.Module): x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) q, k, v = x.unbind(0) - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v + if self.fast_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) x = self.proj(x) @@ -113,9 +130,20 @@ class Attention(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, - group=8, attn_disabled=False, spatial_conv=False): + self, + dim, + num_heads, + head_dim_ratio=1., + mlp_ratio=4., + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=LayerNorm2d, + group=8, + attn_disabled=False, + spatial_conv=False, + ): super().__init__() self.spatial_conv = spatial_conv self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -125,12 +153,22 @@ class Block(nn.Module): else: self.norm1 = norm_layer(dim) self.attn = Attention( - dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads=num_heads, + head_dim_ratio=head_dim_ratio, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) self.norm2 = norm_layer(dim) self.mlp = SpatialMlp( - in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, - group=group, spatial_conv=spatial_conv) # new setting + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + group=group, + spatial_conv=spatial_conv, + ) def forward(self, x): if self.attn is not None: @@ -141,10 +179,31 @@ class Block(nn.Module): class Visformer(nn.Module): def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, - depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', - vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + init_channels=32, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4., + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=LayerNorm2d, + attn_stage='111', + use_pos_embed=True, + spatial_conv='111', + vit_stem=False, + group=8, + global_pool='avg', + conv_init=False, + embed_norm=None, + ): super().__init__() img_size = to_2tuple(img_size) self.num_classes = num_classes @@ -159,7 +218,7 @@ class Visformer(nn.Module): else: self.stage_num1 = self.stage_num3 = depth // 3 self.stage_num2 = depth - self.stage_num1 - self.stage_num3 - self.pos_embed = pos_embed + self.use_pos_embed = use_pos_embed self.grad_checkpointing = False dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] @@ -167,15 +226,25 @@ class 
Visformer(nn.Module): if self.vit_stem: self.stem = None self.patch_embed1 = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, - embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=embed_norm, + flatten=False, + ) img_size = [x // patch_size for x in img_size] else: if self.init_channels is None: self.stem = None self.patch_embed1 = PatchEmbed( - img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, - embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size=img_size, + patch_size=patch_size // 2, + in_chans=in_chans, + embed_dim=embed_dim // 2, + norm_layer=embed_norm, + flatten=False, + ) img_size = [x // (patch_size // 2) for x in img_size] else: self.stem = nn.Sequential( @@ -185,21 +254,37 @@ class Visformer(nn.Module): ) img_size = [x // 2 for x in img_size] self.patch_embed1 = PatchEmbed( - img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, - embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size=img_size, + patch_size=patch_size // 4, + in_chans=self.init_channels, + embed_dim=embed_dim // 2, + norm_layer=embed_norm, + flatten=False, + ) img_size = [x // (patch_size // 4) for x in img_size] - if self.pos_embed: + if self.use_pos_embed: if self.vit_stem: self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) else: self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + else: + self.pos_embed1 = None + self.stage1 = nn.Sequential(*[ Block( - dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + dim=embed_dim//2, + num_heads=num_heads, + head_dim_ratio=0.5, + mlp_ratio=mlp_ratio, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + group=group, + attn_disabled=(attn_stage[0] == '0'), + spatial_conv=(spatial_conv[0] == '1'), ) for i in range(self.stage_num1) ]) @@ -207,16 +292,33 @@ class Visformer(nn.Module): # stage2 if not self.vit_stem: self.patch_embed2 = PatchEmbed( - img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, - embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size=img_size, + patch_size=patch_size // 8, + in_chans=embed_dim // 2, + embed_dim=embed_dim, + norm_layer=embed_norm, + flatten=False, + ) img_size = [x // (patch_size // 8) for x in img_size] - if self.pos_embed: + if self.use_pos_embed: self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + else: + self.pos_embed2 = None + else: + self.patch_embed2 = None self.stage2 = nn.Sequential(*[ Block( - dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + dim=embed_dim, + num_heads=num_heads, + head_dim_ratio=1.0, + mlp_ratio=mlp_ratio, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + group=group, + attn_disabled=(attn_stage[1] == '0'), + spatial_conv=(spatial_conv[1] == '1'), ) for i in range(self.stage_num1, self.stage_num1+self.stage_num2) ]) 
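A minimal standalone sketch (not part of the patch) of the fused-attention fallback that the visformer Attention change above follows: use torch.nn.functional.scaled_dot_product_attention when the running PyTorch build provides it, otherwise fall back to the explicit softmax(QK^T)V path. The helper name and the training-flag handling below are illustrative assumptions, not timm API.

import torch
import torch.nn.functional as F

def sdpa_or_manual(q, k, v, attn_drop_p: float = 0., training: bool = False):
    # q, k, v: (B, num_heads, N, head_dim)
    if hasattr(F, 'scaled_dot_product_attention'):
        # fused kernel (PyTorch >= 2.0); only apply dropout while training
        return F.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop_p if training else 0.)
    # manual path: scale, softmax over keys, optional dropout, weighted sum of values
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=attn_drop_p, training=training)
    return attn @ v

The hasattr guard keeps older PyTorch versions working while newer builds get the fused kernel.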
@@ -224,27 +326,48 @@ class Visformer(nn.Module): # stage 3 if not self.vit_stem: self.patch_embed3 = PatchEmbed( - img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, - embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size=img_size, + patch_size=patch_size // 8, + in_chans=embed_dim, + embed_dim=embed_dim * 2, + norm_layer=embed_norm, + flatten=False, + ) img_size = [x // (patch_size // 8) for x in img_size] - if self.pos_embed: + if self.use_pos_embed: self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) + else: + self.pos_embed3 = None + else: + self.patch_embed3 = None self.stage3 = nn.Sequential(*[ Block( - dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + dim=embed_dim * 2, + num_heads=num_heads, + head_dim_ratio=1.0, + mlp_ratio=mlp_ratio, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + group=group, + attn_disabled=(attn_stage[2] == '0'), + spatial_conv=(spatial_conv[2] == '1'), ) for i in range(self.stage_num1+self.stage_num2, depth) ]) - # head self.num_features = embed_dim if self.vit_stem else embed_dim * 2 self.norm = norm_layer(self.num_features) - self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + # head + global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool = global_pool + self.head_drop = nn.Dropout(drop_rate) + self.head = head # weights init - if self.pos_embed: + if self.use_pos_embed: trunc_normal_(self.pos_embed1, std=0.02) if not self.vit_stem: trunc_normal_(self.pos_embed2, std=0.02) @@ -293,7 +416,7 @@ class Visformer(nn.Module): # stage 1 x = self.patch_embed1(x) - if self.pos_embed: + if self.pos_embed1 is not None: x = self.pos_drop(x + self.pos_embed1) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stage1, x) @@ -301,9 +424,9 @@ class Visformer(nn.Module): x = self.stage1(x) # stage 2 - if not self.vit_stem: + if self.patch_embed2 is not None: x = self.patch_embed2(x) - if self.pos_embed: + if self.pos_embed2 is not None: x = self.pos_drop(x + self.pos_embed2) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stage2, x) @@ -311,9 +434,9 @@ class Visformer(nn.Module): x = self.stage2(x) # stage3 - if not self.vit_stem: + if self.patch_embed3 is not None: x = self.patch_embed3(x) - if self.pos_embed: + if self.pos_embed3 is not None: x = self.pos_drop(x + self.pos_embed3) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stage3, x) @@ -325,6 +448,7 @@ class Visformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -345,8 +469,8 @@ def visformer_tiny(pretrained=False, **kwargs): model_cfg = dict( init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, - embed_norm=nn.BatchNorm2d, **kwargs) - model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) + embed_norm=nn.BatchNorm2d) + model = _create_visformer('visformer_tiny', pretrained=pretrained, **dict(model_cfg, 
**kwargs)) return model @@ -355,8 +479,8 @@ def visformer_small(pretrained=False, **kwargs): model_cfg = dict( init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, - embed_norm=nn.BatchNorm2d, **kwargs) - model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) + embed_norm=nn.BatchNorm2d) + model = _create_visformer('visformer_small', pretrained=pretrained, **dict(model_cfg, **kwargs)) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 2fefb3ab..9034bc5d 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import logging import math from collections import OrderedDict from functools import partial -from typing import Optional, List +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -38,7 +38,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ - resample_abs_pos_embed, RmsNorm + resample_abs_pos_embed, RmsNorm, PatchDropout from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -119,7 +119,7 @@ class Block(nn.Module): mlp_ratio=4., qkv_bias=False, qk_norm=False, - drop=0., + proj_drop=0., attn_drop=0., init_values=None, drop_path=0., @@ -134,7 +134,7 @@ class Block(nn.Module): qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, norm_layer=norm_layer, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() @@ -145,7 +145,7 @@ class Block(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -165,7 +165,7 @@ class ResPostBlock(nn.Module): mlp_ratio=4., qkv_bias=False, qk_norm=False, - drop=0., + proj_drop=0., attn_drop=0., init_values=None, drop_path=0., @@ -181,7 +181,7 @@ class ResPostBlock(nn.Module): qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, norm_layer=norm_layer, ) self.norm1 = norm_layer(dim) @@ -191,7 +191,7 @@ class ResPostBlock(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -224,7 +224,7 @@ class ParallelScalingBlock(nn.Module): mlp_ratio=4., qkv_bias=False, qk_norm=False, - drop=0., + proj_drop=0., attn_drop=0., init_values=None, drop_path=0., @@ -255,7 +255,7 @@ class ParallelScalingBlock(nn.Module): self.attn_drop = nn.Dropout(attn_drop) self.attn_out_proj = nn.Linear(dim, dim) - self.mlp_drop = nn.Dropout(drop) + self.mlp_drop = nn.Dropout(proj_drop) self.mlp_act = act_layer() self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim) @@ -318,7 +318,7 @@ class ParallelThingsBlock(nn.Module): qkv_bias=False, qk_norm=False, init_values=None, - drop=0., + proj_drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, @@ -337,7 +337,7 @@ class ParallelThingsBlock(nn.Module): qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, norm_layer=norm_layer, )), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), @@ -349,7 +349,7 @@ class ParallelThingsBlock(nn.Module): dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, )), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) @@ -382,53 +382,58 @@ class VisionTransformer(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='token', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=True, - qk_norm=False, - init_values=None, - class_token=True, - no_embed_class=False, - pre_norm=False, - fc_norm=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - weight_init='', - embed_layer=PatchEmbed, - norm_layer=None, - act_layer=None, - block_fn=Block, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: str = '', + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[Callable] = None, + act_layer: Optional[Callable] = None, + block_fn: Callable = Block, ): """ Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'token') - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - init_values: (float): layer-scale init values - class_token (bool): use class token - fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - weight_init (str): weight init scheme - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer 
- act_layer: (nn.Module): MLP activation layer + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layey. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. """ super().__init__() assert global_pool in ('', 'avg', 'token') @@ -456,7 +461,14 @@ class VisionTransformer(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -468,7 +480,7 @@ class VisionTransformer(nn.Module): qkv_bias=qkv_bias, qk_norm=qk_norm, init_values=init_values, - drop=drop_rate, + proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -479,6 +491,7 @@ class VisionTransformer(nn.Module): # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': @@ -544,6 +557,7 @@ class VisionTransformer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) x = self._pos_embed(x) + x = self.patch_drop(x) x = self.norm_pre(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) @@ -556,6 +570,7 @@ class VisionTransformer(nn.Module): if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index ca01e9fb..a511f66b 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -110,7 +110,7 @@ class RelPosBlock(nn.Module): qk_norm=False, rel_pos_cls=None, init_values=None, - drop=0., + proj_drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, @@ -125,7 +125,7 @@ class RelPosBlock(nn.Module): qk_norm=qk_norm, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, ) self.ls1 = LayerScale(dim, init_values=init_values) if 
init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here @@ -136,7 +136,7 @@ class RelPosBlock(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -158,7 +158,7 @@ class ResPostRelPosBlock(nn.Module): qk_norm=False, rel_pos_cls=None, init_values=None, - drop=0., + proj_drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, @@ -174,7 +174,7 @@ class ResPostRelPosBlock(nn.Module): qk_norm=qk_norm, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=proj_drop, ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -183,7 +183,7 @@ class ResPostRelPosBlock(nn.Module): in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, - drop=drop, + drop=proj_drop, ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -232,6 +232,7 @@ class VisionTransformerRelPos(nn.Module): rel_pos_dim=None, shared_rel_pos=False, drop_rate=0., + proj_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', @@ -259,6 +260,7 @@ class VisionTransformerRelPos(nn.Module): rel_pos_ty pe (str): type of relative position shared_rel_pos (bool): share relative pos across all blocks drop_rate (float): dropout rate + proj_drop_rate (float): projection dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme @@ -313,7 +315,7 @@ class VisionTransformerRelPos(nn.Module): qk_norm=qk_norm, rel_pos_cls=rel_pos_cls, init_values=init_values, - drop=drop_rate, + proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, @@ -324,6 +326,7 @@ class VisionTransformerRelPos(nn.Module): # Classifier Head self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': @@ -380,6 +383,7 @@ class VisionTransformerRelPos(nn.Module): if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): diff --git a/timm/models/volo.py b/timm/models/volo.py index 1117995a..94a0efc3 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -85,7 +85,17 @@ default_cfgs = { class OutlookAttention(nn.Module): - def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim, + num_heads, + kernel_size=3, + padding=1, + stride=1, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): super().__init__() head_dim = dim // num_heads self.num_heads = num_heads @@ -133,21 +143,40 @@ class OutlookAttention(nn.Module): class Outlooker(nn.Module): def __init__( - self, dim, kernel_size, padding, stride=1, num_heads=1, mlp_ratio=3., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qkv_bias=False + self, + dim, + kernel_size, + padding, + stride=1, + num_heads=1, + mlp_ratio=3., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + qkv_bias=False, ): 
super().__init__() self.norm1 = norm_layer(dim) self.attn = OutlookAttention( - dim, num_heads, kernel_size=kernel_size, - padding=padding, stride=stride, - qkv_bias=qkv_bias, attn_drop=attn_drop) + dim, + num_heads, + kernel_size=kernel_size, + padding=padding, + stride=stride, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) @@ -158,7 +187,13 @@ class Outlooker(nn.Module): class Attention(nn.Module): def __init__( - self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -189,8 +224,16 @@ class Attention(nn.Module): class Transformer(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, - attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop) @@ -211,7 +254,14 @@ class Transformer(nn.Module): class ClassAttention(nn.Module): def __init__( - self, dim, num_heads=8, head_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0.): + self, + dim, + num_heads=8, + head_dim=None, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + ): super().__init__() self.num_heads = num_heads if head_dim is not None: @@ -246,17 +296,38 @@ class ClassAttention(nn.Module): class ClassBlock(nn.Module): def __init__( - self, dim, num_heads, head_dim=None, mlp_ratio=4., qkv_bias=False, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim, + num_heads, + head_dim=None, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = ClassAttention( - dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads=num_heads, + head_dim=head_dim, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) # NOTE: drop path for stochastic depth self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) def forward(self, x): cls_embed = x[:, :1] @@ -354,18 +425,33 @@ def outlooker_blocks( blocks = [] for block_idx in range(layers[index]): block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) - blocks.append( - block_fn( - dim, kernel_size=kernel_size, padding=padding, - stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr)) + blocks.append(block_fn( + dim, + kernel_size=kernel_size, + padding=padding, + stride=stride, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + drop_path=block_dpr, + )) blocks = nn.Sequential(*blocks) return blocks def transformer_blocks( - block_fn, index, dim, layers, num_heads, mlp_ratio=3., - qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs): + block_fn, + index, + dim, + layers, + num_heads, + mlp_ratio=3., + qkv_bias=False, + attn_drop=0, + drop_path_rate=0., + **kwargs, +): """ generate transformer layers in stage2 return: transformer layers @@ -373,13 +459,14 @@ def transformer_blocks( blocks = [] for block_idx in range(layers[index]): block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) - blocks.append( - block_fn( - dim, num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - attn_drop=attn_drop, - drop_path=block_dpr)) + blocks.append(block_fn( + dim, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + drop_path=block_dpr, + )) blocks = nn.Sequential(*blocks) return blocks @@ -405,6 +492,7 @@ class VOLO(nn.Module): mlp_ratio=3.0, qkv_bias=False, drop_rate=0., + pos_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, @@ -429,14 +517,18 @@ class VOLO(nn.Module): self.grad_checkpointing = False self.patch_embed = PatchEmbed( - stem_conv=True, stem_stride=2, patch_size=patch_size, - in_chans=in_chans, hidden_dim=stem_hidden_dim, - embed_dim=embed_dims[0]) + stem_conv=True, + stem_stride=2, + patch_size=patch_size, + in_chans=in_chans, + hidden_dim=stem_hidden_dim, + embed_dim=embed_dims[0], + ) # inital positional encoding, we add positional encoding after outlooker blocks patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale) self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1])) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_drop = nn.Dropout(p=pos_drop_rate) # set the main block in network network = [] @@ -444,14 +536,31 @@ class VOLO(nn.Module): if outlook_attention[i]: # stage 1 stage = outlooker_blocks( - Outlooker, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], - qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer) + Outlooker, + i, + embed_dims[i], + layers, + num_heads[i], + mlp_ratio=mlp_ratio[i], + qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + ) network.append(stage) else: # stage 2 stage = transformer_blocks( - Transformer, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, - drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) + Transformer, + i, + embed_dims[i], + layers, + num_heads[i], + mlp_ratio=mlp_ratio[i], + 
qkv_bias=qkv_bias, + drop_path_rate=drop_path_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + ) network.append(stage) if downsamples[i]: @@ -463,19 +572,18 @@ class VOLO(nn.Module): # set post block, for example, class attention layers self.post_network = None if post_layers is not None: - self.post_network = nn.ModuleList( - [ - get_block( - post_layers[i], - dim=embed_dims[-1], - num_heads=num_heads[-1], - mlp_ratio=mlp_ratio[-1], - qkv_bias=qkv_bias, - attn_drop=attn_drop_rate, - drop_path=0., - norm_layer=norm_layer) - for i in range(len(post_layers)) - ]) + self.post_network = nn.ModuleList([ + get_block( + post_layers[i], + dim=embed_dims[-1], + num_heads=num_heads[-1], + mlp_ratio=mlp_ratio[-1], + qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, + drop_path=0., + norm_layer=norm_layer) + for i in range(len(post_layers)) + ]) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) trunc_normal_(self.cls_token, std=.02) @@ -487,6 +595,7 @@ class VOLO(nn.Module): self.norm = norm_layer(self.num_features) # Classifier head + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) @@ -630,6 +739,7 @@ class VOLO(nn.Module): out = x[:, 0] else: out = x + x = self.head_drop(x) if pre_logits: return out out = self.head(out) diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 57c11183..3fd35c7d 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -219,17 +219,28 @@ class ClassAttentionBlock(nn.Module): """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239""" def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + eta=1., + tokens_norm=False, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = ClassAttn( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) if eta is not None: # LayerScale Initialization (no layerscale when None) self.gamma1 = nn.Parameter(eta * torch.ones(dim)) @@ -297,18 +308,28 @@ class XCA(nn.Module): class XCABlock(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.): + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + eta=1., + ): super().__init__() self.norm1 = norm_layer(dim) - self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm3 = norm_layer(dim) self.local_mp = LPI(in_features=dim, act_layer=act_layer) self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) self.gamma1 = nn.Parameter(eta * torch.ones(dim)) self.gamma3 = nn.Parameter(eta * torch.ones(dim)) @@ -331,9 +352,29 @@ class XCiT(nn.Module): """ def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, - depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False): + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + pos_drop_rate=0., + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_layer=None, + norm_layer=None, + cls_attn_layers=2, + use_pos_embed=True, + eta=1., + tokens_norm=False, + ): """ Args: img_size (int, tuple): input image size @@ -346,6 +387,8 @@ class XCiT(nn.Module): mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP + pos_drop_rate: position embedding dropout rate + proj_drop_rate (float): projection dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate (constant across all layers) norm_layer: (nn.Module): normalization layer @@ -372,28 +415,53 @@ class XCiT(nn.Module): self.grad_checkpointing = False self.patch_embed = ConvPatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer) + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + act_layer=act_layer, + ) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.use_pos_embed = use_pos_embed if use_pos_embed: self.pos_embed = PositionalEncodingFourier(dim=embed_dim) - self.pos_drop = nn.Dropout(p=drop_rate) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=pos_drop_rate) self.blocks = nn.ModuleList([ XCABlock( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + act_layer=act_layer, + norm_layer=norm_layer, + eta=eta, + ) for _ in range(depth)]) self.cls_attn_blocks = nn.ModuleList([ ClassAttentionBlock( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm) + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=drop_rate, + attn_drop=attn_drop_rate, + act_layer=act_layer, + norm_layer=norm_layer, + eta=eta, + tokens_norm=tokens_norm, + ) for _ in range(cls_attn_layers)]) # Classifier head self.norm = norm_layer(embed_dim) + self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) 
if num_classes > 0 else nn.Identity() # Init weights @@ -438,7 +506,7 @@ class XCiT(nn.Module): # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches) x, (Hp, Wp) = self.patch_embed(x) - if self.use_pos_embed: + if self.pos_embed is not None: # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C) pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) x = x + pos_encoding @@ -464,6 +532,7 @@ class XCiT(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -509,336 +578,336 @@ def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs): @register_model def xcit_nano_12_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) - model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False) + model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) - model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False) + model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384, **kwargs) - model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384) + model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_tiny_12_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True) + model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True) + model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) - model = 
_create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True) + model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_small_12_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True) + model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_small_12_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True) + model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_small_12_p16_384_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True) + model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_tiny_24_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_small_24_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_small_24_p16_224', 
pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_small_24_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_small_24_p16_384_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_medium_24_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_large_24_p16_224(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **model_kwargs) + model_args = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True) + model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def xcit_large_24_p16_224_dist(pretrained=False, **kwargs): - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) - model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, 
**model_kwargs)
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_large_24_p16_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 # Patch size 8x8 models
 @register_model
 def xcit_nano_12_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
-    model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
+    model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
-    model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
+    model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
-    model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
+    model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_tiny_12_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
+    model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
+    model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
+    model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_small_12_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
+    model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_small_12_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
+    model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_small_12_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
+    model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_tiny_24_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_small_24_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_small_24_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_small_24_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_medium_24_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_large_24_p8_224(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_large_24_p8_224_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def xcit_large_24_p8_384_dist(pretrained=False, **kwargs):
-    model_kwargs = dict(
-        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
-    model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+    model_args = dict(
+        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
+    model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
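
Note on the pattern above: every hunk replaces `model_kwargs = dict(..., **kwargs)` with a plain defaults dict that is merged at the call site via `dict(model_args, **kwargs)`. The rationale is not stated in the patch; a plausible reading is that the merge lets caller-supplied kwargs override the registered defaults instead of raising a duplicate-keyword TypeError. The snippet below is a minimal illustrative sketch of that merge behaviour in plain Python, with hypothetical values, not code from the patch.

    # Sketch (not part of the patch): dict(model_args, **kwargs) merges two sources
    # of keyword arguments, with the user-supplied kwargs taking precedence.
    defaults = dict(patch_size=8, embed_dim=384, depth=12, num_heads=8)  # hypothetical defaults
    user_kwargs = dict(num_heads=4, drop_rate=0.1)                       # hypothetical overrides

    merged = dict(defaults, **user_kwargs)
    assert merged['num_heads'] == 4    # override wins
    assert merged['patch_size'] == 8   # untouched default preserved

    # By contrast, the old form dict(patch_size=8, ..., **user_kwargs) raises a
    # TypeError ("got multiple values for keyword argument ...") as soon as an
    # override repeats one of the default keys.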