mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'main' into dot_nine_cleanup
This commit is contained in:
commit
59bea4c306
@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
|
||||
from .inplace_abn import InplaceAbn
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp
|
||||
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
|
||||
|
@ -97,6 +97,9 @@ class GluMlp(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
""" SwiGLU
|
||||
NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
|
||||
@ -108,7 +111,7 @@ class SwiGLU(nn.Module):
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.SiLU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.,
|
||||
):
|
||||
@ -130,8 +133,8 @@ class SwiGLU(nn.Module):
|
||||
|
||||
def init_weights(self):
|
||||
# override init of fc1 w/ gate portion set to weight near zero, bias=1
|
||||
nn.init.ones_(self.fc1a.bias)
|
||||
nn.init.normal_(self.fc1a.weight, std=1e-6)
|
||||
nn.init.ones_(self.fc1_g.bias)
|
||||
nn.init.normal_(self.fc1_g.weight, std=1e-6)
|
||||
|
||||
def forward(self, x):
|
||||
x_gate = self.fc1_g(x)
|
||||
|
@ -38,7 +38,7 @@ from torch.jit import Final
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
|
||||
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
|
||||
resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn
|
||||
resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
@ -124,7 +124,8 @@ class Block(nn.Module):
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
norm_layer=nn.LayerNorm,
|
||||
ffn_layer=Mlp,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
@ -141,7 +142,7 @@ class Block(nn.Module):
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(
|
||||
self.mlp = ffn_layer(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
@ -170,7 +171,8 @@ class ResPostBlock(nn.Module):
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
norm_layer=nn.LayerNorm,
|
||||
ffn_layer=Mlp,
|
||||
):
|
||||
super().__init__()
|
||||
self.init_values = init_values
|
||||
@ -187,7 +189,7 @@ class ResPostBlock(nn.Module):
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.mlp = Mlp(
|
||||
self.mlp = ffn_layer(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
@ -229,7 +231,8 @@ class ParallelScalingBlock(nn.Module):
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
norm_layer=nn.LayerNorm,
|
||||
ffn_layer=None, # NOTE: not used
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
@ -322,7 +325,8 @@ class ParallelThingsBlock(nn.Module):
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm
|
||||
norm_layer=nn.LayerNorm,
|
||||
ffn_layer=Mlp,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_parallel = num_parallel
|
||||
@ -345,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
|
||||
])))
|
||||
self.ffns.append(nn.Sequential(OrderedDict([
|
||||
('norm', norm_layer(dim)),
|
||||
('mlp', Mlp(
|
||||
('mlp', ffn_layer(
|
||||
dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
@ -409,6 +413,7 @@ class VisionTransformer(nn.Module):
|
||||
norm_layer: Optional[Callable] = None,
|
||||
act_layer: Optional[Callable] = None,
|
||||
block_fn: Callable = Block,
|
||||
ffn_layer: Callable = Mlp,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -484,7 +489,8 @@ class VisionTransformer(nn.Module):
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer
|
||||
act_layer=act_layer,
|
||||
ffn_layer=ffn_layer,
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
||||
@ -808,6 +814,25 @@ def _convert_openai_clip(state_dict, model):
|
||||
return out_dict
|
||||
|
||||
|
||||
def _convert_dinov2(state_dict, model):
|
||||
import re
|
||||
|
||||
out_dict = {}
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if k == "mask_token":
|
||||
continue
|
||||
elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
|
||||
out_dict[k.replace("w12", "fc1")] = v
|
||||
continue
|
||||
elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
|
||||
out_dict[k.replace("w3", "fc2")] = v
|
||||
continue
|
||||
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
def checkpoint_filter_fn(
|
||||
state_dict,
|
||||
model,
|
||||
@ -824,6 +849,9 @@ def checkpoint_filter_fn(
|
||||
if 'visual.class_embedding' in state_dict:
|
||||
return _convert_openai_clip(state_dict, model)
|
||||
|
||||
if "mask_token" in state_dict:
|
||||
return _convert_dinov2(state_dict, model)
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k:
|
||||
O, I, H, W = model.patch_embed.proj.weight.shape
|
||||
@ -1043,6 +1071,20 @@ default_cfgs = generate_default_cfgs({
|
||||
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
|
||||
|
||||
# DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
|
||||
'vit_small_patch14_dinov2': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
|
||||
'vit_base_patch14_dinov2': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
|
||||
'vit_large_patch14_dinov2': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
|
||||
'vit_giant_patch14_dinov2': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
|
||||
|
||||
# ViT ImageNet-21K-P pretraining by MILL
|
||||
'vit_base_patch16_224_miil.in21k': _cfg(
|
||||
@ -1857,6 +1899,66 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch14_dinov2(pretrained=False, **kwargs):
|
||||
""" ViT-S/14 for DINOv2
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=384, depth=12, num_heads=6,
|
||||
init_values=1.0, img_size=518,
|
||||
)
|
||||
|
||||
model = _create_vision_transformer(
|
||||
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch14_dinov2(pretrained=False, **kwargs):
|
||||
""" ViT-B/14 for DINOv2
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=768, depth=12, num_heads=12,
|
||||
init_values=1.0, img_size=518,
|
||||
)
|
||||
|
||||
model = _create_vision_transformer(
|
||||
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch14_dinov2(pretrained=False, **kwargs):
|
||||
""" ViT-L/14 for DINOv2
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=1024, depth=24, num_heads=16,
|
||||
init_values=1.0, img_size=518,
|
||||
)
|
||||
|
||||
model = _create_vision_transformer(
|
||||
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
|
||||
""" ViT-G/14 for DINOv2
|
||||
"""
|
||||
|
||||
# The hidden_features of SwiGLU is calculated by:
|
||||
# hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
# When embed_dim=1536, hidden_features=4096
|
||||
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
|
||||
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
|
||||
mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
|
||||
)
|
||||
|
||||
model = _create_vision_transformer(
|
||||
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
|
||||
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
|
||||
|
Loading…
x
Reference in New Issue
Block a user