diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index 2eda24cb..45f2e541 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
-from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp
+from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py
index c4edf1b1..2c307330 100644
--- a/timm/layers/mlp.py
+++ b/timm/layers/mlp.py
@@ -97,6 +97,9 @@ class GluMlp(nn.Module):
         return x
 
 
+SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
+
+
 class SwiGLU(nn.Module):
     """ SwiGLU
     NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
@@ -108,7 +111,7 @@ class SwiGLU(nn.Module):
             hidden_features=None,
             out_features=None,
             act_layer=nn.SiLU,
-            norm_layer=nn.LayerNorm,
+            norm_layer=None,
             bias=True,
             drop=0.,
     ):
@@ -130,8 +133,8 @@ class SwiGLU(nn.Module):
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        nn.init.ones_(self.fc1a.bias)
-        nn.init.normal_(self.fc1a.weight, std=1e-6)
+        nn.init.ones_(self.fc1_g.bias)
+        nn.init.normal_(self.fc1_g.weight, std=1e-6)
 
     def forward(self, x):
         x_gate = self.fc1_g(x)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 40e103a2..b38a168d 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -38,7 +38,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn
+    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -124,7 +124,8 @@ class Block(nn.Module):
             init_values=None,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=Mlp,
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -141,7 +142,7 @@ class Block(nn.Module):
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(
+        self.mlp = ffn_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -170,7 +171,8 @@ class ResPostBlock(nn.Module):
             init_values=None,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=Mlp,
     ):
         super().__init__()
         self.init_values = init_values
@@ -187,7 +189,7 @@ class ResPostBlock(nn.Module):
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = Mlp(
+        self.mlp = ffn_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -229,7 +231,8 @@ class ParallelScalingBlock(nn.Module):
             init_values=None,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=None,  # NOTE: not used
     ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -322,7 +325,8 @@ class ParallelThingsBlock(nn.Module):
             attn_drop=0.,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm
+            norm_layer=nn.LayerNorm,
+            ffn_layer=Mlp,
     ):
         super().__init__()
         self.num_parallel = num_parallel
@@ -345,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', Mlp(
+                ('mlp', ffn_layer(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
@@ -409,6 +413,7 @@ class VisionTransformer(nn.Module):
             norm_layer: Optional[Callable] = None,
             act_layer: Optional[Callable] = None,
             block_fn: Callable = Block,
+            ffn_layer: Callable = Mlp,
     ):
         """
         Args:
@@ -484,7 +489,8 @@ class VisionTransformer(nn.Module):
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
-                act_layer=act_layer
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
             )
             for i in range(depth)])
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -808,6 +814,25 @@
     return out_dict
 
 
+def _convert_dinov2(state_dict, model):
+    import re
+
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k == "mask_token":
+            continue
+        elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+            out_dict[k.replace("w12", "fc1")] = v
+            continue
+        elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
+            out_dict[k.replace("w3", "fc2")] = v
+            continue
+
+        out_dict[k] = v
+
+    return out_dict
+
 def checkpoint_filter_fn(
     state_dict,
     model,
@@ -824,6 +849,9 @@
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
 
+    if "mask_token" in state_dict:
+        return _convert_dinov2(state_dict, model)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
@@ -1043,6 +1071,20 @@
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
+    'vit_small_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+    'vit_base_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+    'vit_large_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+    'vit_giant_patch14_dinov2': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
 
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
@@ -1857,6 +1899,66 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-S/14 for DINOv2
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=384, depth=12, num_heads=6,
+        init_values=1.0, img_size=518,
+    )
+
+    model = _create_vision_transformer(
+        'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-B/14 for DINOv2
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=768, depth=12, num_heads=12,
+        init_values=1.0, img_size=518,
+    )
+
+    model = _create_vision_transformer(
+        'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-L/14 for DINOv2
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
+        init_values=1.0, img_size=518,
+    )
+
+    model = _create_vision_transformer(
+        'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+@register_model
+def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
+    """ ViT-G/14 for DINOv2
+    """
+
+    # The original DINOv2 SwiGLU FFN derives its hidden width from 4 * embed_dim as:
+    # hidden_features = (int(4 * embed_dim * 2 / 3) + 7) // 8 * 8
+    # For embed_dim=1536 this gives hidden_features=4096.
+    # SwiGLUPacked packs gate and value into a single fc1, so hidden_features = 2 * 4096 = 8192.
+
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
+        mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
+    )
+
+    model = _create_vision_transformer(
+        'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
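
For reviewers unfamiliar with the packed layout: `SwiGLUPacked` is just `GluMlp` with `act_layer=nn.SiLU` and `gate_last=False`, i.e. a single `fc1` that projects to the concatenated gate/value width and is chunked in the forward pass, with the gate taken from the first chunk. A minimal, self-contained sketch of that computation (the `PackedSwiGLU` class and its argument names are illustrative only, not timm API, and it omits the dropout/norm/bias options the real `GluMlp` has):

```python
import torch
import torch.nn as nn

class PackedSwiGLU(nn.Module):
    def __init__(self, dim, hidden):  # `hidden` is the packed width, 2x the per-branch width
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)        # packs gate and value projections (DINOv2 `w12`)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(hidden // 2, dim)   # maps the gated half back to dim (DINOv2 `w3`)

    def forward(self, x):
        x_gate, x_val = self.fc1(x).chunk(2, dim=-1)   # gate_last=False: gate is the first chunk
        return self.fc2(self.act(x_gate) * x_val)

x = torch.randn(2, 16, 1536)
ffn = PackedSwiGLU(1536, 8192)   # per-branch width 4096, as in the ViT-g/14 config above
print(ffn(x).shape)              # torch.Size([2, 16, 1536])
```

Because the packed `fc1`/`fc2` line up with DINOv2's `w12`/`w3`, `_convert_dinov2` above only has to rename those keys rather than re-split any weights.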
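The `mlp_ratio=2.66667 * 2` in `vit_giant_patch14_dinov2` encodes the arithmetic in the comment: `Block` computes `hidden_features=int(dim * mlp_ratio)`, so the packed width comes out to 8192, twice DINOv2's per-branch width of 4096. A quick sanity check of those numbers (values taken from the diff above):

```python
embed_dim = 1536
dinov2_hidden = (int(embed_dim * 4 * 2 / 3) + 7) // 8 * 8   # DINOv2's SwiGLU hidden width -> 4096
packed_hidden = int(embed_dim * 2.66667 * 2)                # what Block passes as hidden_features
assert dinov2_hidden == 4096
assert packed_hidden == 2 * dinov2_hidden == 8192
```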
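With the patch applied, the new variants should be usable like any other timm ViT. A hedged usage sketch, assuming the pretrained weights download cleanly from the fbaipublicfiles URLs in the default cfgs:

```python
import torch
import timm

# Assumes this patch is applied. pretrained=True fetches the DINOv2 checkpoint and
# remaps w12 -> fc1 / w3 -> fc2 via checkpoint_filter_fn / _convert_dinov2.
model = timm.create_model('vit_small_patch14_dinov2', pretrained=True)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 518, 518))

# num_classes=0 in the default cfg, so the head is an identity and the model returns
# pooled features; expected shape (1, 384) for the small variant.
print(feats.shape)
```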