Merge branch 'main' into dot_nine_cleanup

This commit is contained in:
Ross Wightman 2023-05-09 12:27:32 -07:00
commit 59bea4c306
3 changed files with 118 additions and 13 deletions

View File

@ -27,7 +27,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
from .inplace_abn import InplaceAbn from .inplace_abn import InplaceAbn
from .linear import Linear from .linear import Linear
from .mixed_conv2d import MixedConv2d from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, ConvMlp, GlobalResponseNormMlp from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\

View File

@ -97,6 +97,9 @@ class GluMlp(nn.Module):
return x return x
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
class SwiGLU(nn.Module): class SwiGLU(nn.Module):
""" SwiGLU """ SwiGLU
NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
@ -108,7 +111,7 @@ class SwiGLU(nn.Module):
hidden_features=None, hidden_features=None,
out_features=None, out_features=None,
act_layer=nn.SiLU, act_layer=nn.SiLU,
norm_layer=nn.LayerNorm, norm_layer=None,
bias=True, bias=True,
drop=0., drop=0.,
): ):
@ -130,8 +133,8 @@ class SwiGLU(nn.Module):
def init_weights(self): def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1 # override init of fc1 w/ gate portion set to weight near zero, bias=1
nn.init.ones_(self.fc1a.bias) nn.init.ones_(self.fc1_g.bias)
nn.init.normal_(self.fc1a.weight, std=1e-6) nn.init.normal_(self.fc1_g.weight, std=1e-6)
def forward(self, x): def forward(self, x):
x_gate = self.fc1_g(x) x_gate = self.fc1_g(x)

View File

@ -38,7 +38,7 @@ from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -124,7 +124,8 @@ class Block(nn.Module):
init_values=None, init_values=None,
drop_path=0., drop_path=0.,
act_layer=nn.GELU, act_layer=nn.GELU,
norm_layer=nn.LayerNorm norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
): ):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
@ -141,7 +142,7 @@ class Block(nn.Module):
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
self.mlp = Mlp( self.mlp = ffn_layer(
in_features=dim, in_features=dim,
hidden_features=int(dim * mlp_ratio), hidden_features=int(dim * mlp_ratio),
act_layer=act_layer, act_layer=act_layer,
@ -170,7 +171,8 @@ class ResPostBlock(nn.Module):
init_values=None, init_values=None,
drop_path=0., drop_path=0.,
act_layer=nn.GELU, act_layer=nn.GELU,
norm_layer=nn.LayerNorm norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
): ):
super().__init__() super().__init__()
self.init_values = init_values self.init_values = init_values
@ -187,7 +189,7 @@ class ResPostBlock(nn.Module):
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp( self.mlp = ffn_layer(
in_features=dim, in_features=dim,
hidden_features=int(dim * mlp_ratio), hidden_features=int(dim * mlp_ratio),
act_layer=act_layer, act_layer=act_layer,
@ -229,7 +231,8 @@ class ParallelScalingBlock(nn.Module):
init_values=None, init_values=None,
drop_path=0., drop_path=0.,
act_layer=nn.GELU, act_layer=nn.GELU,
norm_layer=nn.LayerNorm norm_layer=nn.LayerNorm,
ffn_layer=None, # NOTE: not used
): ):
super().__init__() super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads' assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@ -322,7 +325,8 @@ class ParallelThingsBlock(nn.Module):
attn_drop=0., attn_drop=0.,
drop_path=0., drop_path=0.,
act_layer=nn.GELU, act_layer=nn.GELU,
norm_layer=nn.LayerNorm norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
): ):
super().__init__() super().__init__()
self.num_parallel = num_parallel self.num_parallel = num_parallel
@ -345,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
]))) ])))
self.ffns.append(nn.Sequential(OrderedDict([ self.ffns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)), ('norm', norm_layer(dim)),
('mlp', Mlp( ('mlp', ffn_layer(
dim, dim,
hidden_features=int(dim * mlp_ratio), hidden_features=int(dim * mlp_ratio),
act_layer=act_layer, act_layer=act_layer,
@ -409,6 +413,7 @@ class VisionTransformer(nn.Module):
norm_layer: Optional[Callable] = None, norm_layer: Optional[Callable] = None,
act_layer: Optional[Callable] = None, act_layer: Optional[Callable] = None,
block_fn: Callable = Block, block_fn: Callable = Block,
ffn_layer: Callable = Mlp,
): ):
""" """
Args: Args:
@ -484,7 +489,8 @@ class VisionTransformer(nn.Module):
attn_drop=attn_drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[i], drop_path=dpr[i],
norm_layer=norm_layer, norm_layer=norm_layer,
act_layer=act_layer act_layer=act_layer,
ffn_layer=ffn_layer,
) )
for i in range(depth)]) for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@ -808,6 +814,25 @@ def _convert_openai_clip(state_dict, model):
return out_dict return out_dict
def _convert_dinov2(state_dict, model):
import re
out_dict = {}
for k, v in state_dict.items():
if k == "mask_token":
continue
elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
out_dict[k.replace("w12", "fc1")] = v
continue
elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
out_dict[k.replace("w3", "fc2")] = v
continue
out_dict[k] = v
return out_dict
def checkpoint_filter_fn( def checkpoint_filter_fn(
state_dict, state_dict,
model, model,
@ -824,6 +849,9 @@ def checkpoint_filter_fn(
if 'visual.class_embedding' in state_dict: if 'visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model) return _convert_openai_clip(state_dict, model)
if "mask_token" in state_dict:
return _convert_dinov2(state_dict, model)
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k: if 'patch_embed.proj.weight' in k:
O, I, H, W = model.patch_embed.proj.weight.shape O, I, H, W = model.patch_embed.proj.weight.shape
@ -1043,6 +1071,20 @@ default_cfgs = generate_default_cfgs({
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
hf_hub_id='timm/', hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
# DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
'vit_small_patch14_dinov2': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
'vit_base_patch14_dinov2': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
'vit_large_patch14_dinov2': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
'vit_giant_patch14_dinov2': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
# ViT ImageNet-21K-P pretraining by MILL # ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil.in21k': _cfg( 'vit_base_patch16_224_miil.in21k': _cfg(
@ -1857,6 +1899,66 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
return model return model
@register_model
def vit_small_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-S/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=384, depth=12, num_heads=6,
init_values=1.0, img_size=518,
)
model = _create_vision_transformer(
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-B/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=768, depth=12, num_heads=12,
init_values=1.0, img_size=518,
)
model = _create_vision_transformer(
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_large_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-L/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16,
init_values=1.0, img_size=518,
)
model = _create_vision_transformer(
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-G/14 for DINOv2
"""
# The hidden_features of SwiGLU is calculated by:
# hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
# When embed_dim=1536, hidden_features=4096
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
model_args = dict(
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
)
model = _create_vision_transformer(
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
register_model_deprecations(__name__, { register_model_deprecations(__name__, {
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k', 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',