Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Supporting aimv2 encoders

Commit: e752b5d07c
Parent: 790decc89b
@@ -132,7 +132,8 @@ class SwiGLU(nn.Module):
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        nn.init.ones_(self.fc1_g.bias)
+        if self.fc1_g.bias is not None:
+            nn.init.ones_(self.fc1_g.bias)
         nn.init.normal_(self.fc1_g.weight, std=1e-6)
 
     def forward(self, x):
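For context, the gated-MLP pattern this init targets looks roughly like the sketch below. It is a standalone illustration (TinySwiGLU is not timm's class), showing why the new None check matters when the gate projection is built without a bias.

```python
import torch
import torch.nn as nn


class TinySwiGLU(nn.Module):
    # Standalone sketch of a SwiGLU-style gated MLP; fc1_g / fc1_x mirror the names in the diff.
    def __init__(self, dim: int, hidden: int, bias: bool = True):
        super().__init__()
        self.fc1_g = nn.Linear(dim, hidden, bias=bias)  # gate branch
        self.fc1_x = nn.Linear(dim, hidden, bias=bias)  # value branch
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(hidden, dim, bias=bias)

    def init_weights(self):
        # Same idea as the patched init: gate weights near zero, gate bias ones,
        # but only touch the bias when the layer actually has one.
        if self.fc1_g.bias is not None:
            nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        return self.fc2(self.act(self.fc1_g(x)) * self.fc1_x(x))


m = TinySwiGLU(8, 16, bias=False)   # bias-free build, as the aimv2 entrypoint below requests
m.init_weights()                    # without the guard, nn.init.ones_(None) would fail here
print(m(torch.randn(2, 8)).shape)   # torch.Size([2, 8])
```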
@@ -44,7 +44,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    SwiGLU, get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -65,6 +65,7 @@ class Attention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             norm_layer: nn.Module = nn.LayerNorm,
@@ -80,7 +81,7 @@ class Attention(nn.Module):
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
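To make the effect of the new flag concrete, here is a hedged, self-contained sketch of an attention module with the same proj_bias plumbing; it is an illustration, not timm's Attention class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAttention(nn.Module):
    # Illustration of the proj_bias plumbing added in the diff (standalone, simplified).
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)  # output projection bias is now configurable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        x = F.scaled_dot_product_attention(q, k, v)  # requires torch >= 2.0
        return self.proj(x.transpose(1, 2).reshape(B, N, C))


attn = TinyAttention(64, qkv_bias=False, proj_bias=False)
print(attn.proj.bias is None)              # True: bias-free projection, as the aimv2 config requests
print(attn(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])
```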
@@ -130,6 +131,7 @@ class Block(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -145,6 +147,7 @@ class Block(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -157,6 +160,7 @@ class Block(nn.Module):
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@@ -176,6 +180,7 @@ class ResPostBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -192,6 +197,7 @@ class ResPostBlock(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -203,6 +209,7 @@ class ResPostBlock(nn.Module):
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.norm2 = norm_layer(dim)
@@ -236,6 +243,7 @@ class ParallelScalingBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -266,11 +274,11 @@ class ParallelScalingBlock(nn.Module):
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim)
+        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)
 
         self.mlp_drop = nn.Dropout(proj_drop)
         self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias)
 
         self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -330,6 +338,7 @@ class ParallelThingsBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -350,6 +359,7 @@ class ParallelThingsBlock(nn.Module):
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
                     norm_layer=norm_layer,
@@ -363,6 +373,7 @@ class ParallelThingsBlock(nn.Module):
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    bias=proj_bias,
                     drop=proj_drop,
                 )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
@@ -433,6 +444,7 @@ class VisionTransformer(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             class_token: bool = True,
             pos_embed: str = 'learn',
@@ -452,6 +464,7 @@ class VisionTransformer(nn.Module):
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
             fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
+            embed_norm_layer: Optional[LayerType] = None,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
@@ -483,6 +496,7 @@ class VisionTransformer(nn.Module):
             weight_init: Weight initialization scheme.
             fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
+            embed_norm_layer: Normalization layer to use / override in patch embed module.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -493,6 +507,7 @@ class VisionTransformer(nn.Module):
         assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        embed_norm_layer = get_norm_layer(embed_norm_layer)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
         self.num_classes = num_classes
@@ -510,6 +525,8 @@ class VisionTransformer(nn.Module):
         if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
+        if embed_norm_layer is not None:
+            embed_args['norm_layer'] = embed_norm_layer
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
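A minimal sketch of what the embed_norm_layer override amounts to: the chosen norm constructor is handed to the patch-embed module, which applies it to the patch tokens. TinyPatchEmbed below is a simplified stand-in, not timm's PatchEmbed; the aimv2 entrypoint passes timm's RmsNorm here, while the sketch uses nn.LayerNorm to stay dependency-light.

```python
from functools import partial

import torch
import torch.nn as nn


class TinyPatchEmbed(nn.Module):
    # Simplified patch embedding that accepts an optional norm_layer, in the spirit of timm's PatchEmbed.
    def __init__(self, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 1024, norm_layer=None):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                   # B, C, H', W'
        x = x.flatten(2).transpose(1, 2)   # B, N, C
        return self.norm(x)


embed_args = {}
embed_norm_layer = partial(nn.LayerNorm, eps=1e-5)   # the aimv2 model uses RmsNorm instead
if embed_norm_layer is not None:
    embed_args['norm_layer'] = embed_norm_layer      # same override pattern as the diff
patch_embed = TinyPatchEmbed(**embed_args)
print(patch_embed(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 256, 1024])
```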
@@ -539,7 +556,7 @@ class VisionTransformer(nn.Module):
             self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             block_fn(
                 dim=embed_dim,
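For reference, the decay rule touched above spreads per-block drop-path rates linearly from 0 to drop_path_rate; a quick arithmetic check (values are illustrative and independent of the device='cpu' change):

```python
import torch

drop_path_rate, depth = 0.1, 24
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print(round(dpr[0], 4), round(dpr[-1], 4))  # 0.0 0.1 -> first block never drops, last drops at the full rate
print(round(dpr[1], 6))                     # 0.004348, i.e. drop_path_rate / (depth - 1)
```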
@@ -547,6 +564,7 @@ class VisionTransformer(nn.Module):
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_norm=qk_norm,
+                proj_bias=proj_bias,
                 init_values=init_values,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
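If the new argument behaves like the existing bias flags, setting it at the VisionTransformer level should propagate to every attention output projection and MLP layer. A hedged check, assuming a timm build that already contains this commit:

```python
import timm

# Assumes a timm version that includes this commit (proj_bias plumbed through VisionTransformer).
model = timm.create_model('vit_tiny_patch16_224', pretrained=False, proj_bias=False)
blk = model.blocks[0]
print(blk.attn.proj.bias is None)  # True: attention output projection built without a bias
print(blk.mlp.fc1.bias is None)    # True: the MLP also receives bias=proj_bias
```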
@@ -1128,6 +1146,31 @@ def _convert_dinov2(
     return out_dict
 
 
+def _convert_aimv2(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    #import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        k = k.replace('norm_1', 'norm1')
+        k = k.replace('norm_2', 'norm2')
+        k = k.replace('preprocessor.patchifier.', 'patch_embed.')
+        k = k.replace('preprocessor.pos_embed', 'pos_embed')
+        k = k.replace('trunk.', '')
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
+        k = k.replace('post_trunk_norm.', 'norm.')
+        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w12", "fc1")] = v
+        #     continue
+        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w3", "fc2")] = v
+        #     continue
+        out_dict[k] = v
+    return out_dict
+
+
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
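To illustrate what the converter does, here is a hedged, standalone copy of its key remapping applied to a toy state dict (tensor shapes follow the ViT-L/14 config registered below: embed_dim 1024, mlp hidden 1024 * 2.75 = 2816):

```python
import torch


def convert_aimv2_keys(state_dict):
    # Standalone copy of the string replacements shown in the diff (illustration only).
    out = {}
    for k, v in state_dict.items():
        k = k.replace('norm_1', 'norm1')
        k = k.replace('norm_2', 'norm2')
        k = k.replace('preprocessor.patchifier.', 'patch_embed.')
        k = k.replace('preprocessor.pos_embed', 'pos_embed')
        k = k.replace('trunk.', '')
        k = k.replace('mlp.fc1', 'mlp.fc1_g')
        k = k.replace('mlp.fc3', 'mlp.fc1_x')
        k = k.replace('post_trunk_norm.', 'norm.')
        out[k] = v
    return out


toy = {
    'preprocessor.patchifier.proj.weight': torch.zeros(1024, 3, 14, 14),
    'trunk.blocks.0.norm_1.weight': torch.zeros(1024),
    'trunk.blocks.0.mlp.fc1.weight': torch.zeros(2816, 1024),  # maps to timm's fc1_g (gate slot)
    'trunk.blocks.0.mlp.fc3.weight': torch.zeros(2816, 1024),  # maps to timm's fc1_x (value slot)
    'trunk.post_trunk_norm.weight': torch.zeros(1024),
}
print(list(convert_aimv2_keys(toy)))
# ['patch_embed.proj.weight', 'blocks.0.norm1.weight',
#  'blocks.0.mlp.fc1_g.weight', 'blocks.0.mlp.fc1_x.weight', 'norm.weight']
```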
@@ -1159,6 +1202,8 @@ def checkpoint_filter_fn(
             # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
             out_dict['head.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'preprocessor.patchifier.proj.weight' in state_dict:
+        state_dict = _convert_aimv2(state_dict, model)
 
     if prefix:
         # filter on & remove prefix string from keys
@@ -2119,6 +2164,12 @@ default_cfgs = {
         input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
     ),
 
+    'vit_large_patch14_aimv2_224': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 224, 224), crop_pct=1.0,
+        num_classes=0),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
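A hedged usage note on the config above: resolved through timm's standard data helpers it should yield a 224x224 eval transform with CLIP normalization and crop_pct=1.0, assuming a timm build that registers this entry.

```python
import timm
from timm.data import resolve_data_config, create_transform

# Assumes a timm build containing this commit; pretrained=False avoids any hub download.
model = timm.create_model('vit_large_patch14_aimv2_224', pretrained=False)
cfg = resolve_data_config({}, model=model)
print(cfg['input_size'], cfg['crop_pct'])  # (3, 224, 224) 1.0
print(cfg['mean'])                         # OPENAI_CLIP_MEAN, per the default cfg above
transform = create_transform(**cfg)        # eval transform matching the pretrained cfg
```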
@@ -3390,6 +3441,21 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran
     return model
 
 
+@register_model
+def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
+    """
+    rms_norm = partial(RmsNorm, eps=1e-5)
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
+        qkv_bias=False, proj_bias=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test
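Once registered, the new encoder is reachable through the usual factory call; a hedged usage sketch (assumes a timm build containing this commit; loading pretrained weights additionally relies on the checkpoint-filter path above and the Hugging Face hub entry in the default cfg).

```python
import torch
import timm

# num_classes=0 drops any classifier head, matching the encoder-only default cfg.
model = timm.create_model('vit_large_patch14_aimv2_224', pretrained=False, num_classes=0)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1024]) -> avg-pooled features (global_pool='avg', embed_dim=1024)
```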