mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

eva02 non-CLIP weights on HF hub, add initial eva02 clip model configs w/ postnorm variant & attn LN

This commit is contained in:
parent ac67098147
commit 0737bd3ec8
@@ -316,9 +316,17 @@ def generate_readme(model_card: dict, model_name: str):
     readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
     if 'details' in model_card and 'Dataset' in model_card['details']:
         readme_text += 'datasets:\n'
-        readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
+        if isinstance(model_card['details']['Dataset'], (tuple, list)):
+            for d in model_card['details']['Dataset']:
+                readme_text += f"- {d.lower()}\n"
+        else:
+            readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
         if 'Pretrain Dataset' in model_card['details']:
-            readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
+            if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
+                for d in model_card['details']['Pretrain Dataset']:
+                    readme_text += f"- {d.lower()}\n"
+            else:
+                readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
     readme_text += "---\n"
     readme_text += f"# Model card for {model_name}\n"
     if 'description' in model_card:
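Note: the datasets front-matter now accepts either a single dataset name or a tuple/list of names for both 'Dataset' and 'Pretrain Dataset'. A minimal sketch of the resulting YAML header, using a hypothetical model_card dict (not from this commit):

    # Hypothetical model_card illustrating the tuple/list handling above.
    model_card = {
        'license': 'apache-2.0',
        'details': {
            'Dataset': 'ImageNet-1k',
            'Pretrain Dataset': ('ImageNet-22k', 'Merged-38M'),
        },
    }
    lines = ['---', f"license: {model_card['license']}", 'datasets:']
    for key in ('Dataset', 'Pretrain Dataset'):
        val = model_card['details'][key]
        names = val if isinstance(val, (tuple, list)) else [val]
        lines += [f"- {n.lower()}" for n in names]
    lines.append('---')
    print('\n'.join(lines))
    # ---
    # license: apache-2.0
    # datasets:
    # - imagenet-1k
    # - imagenet-22k
    # - merged-38m
    # ---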
@@ -55,7 +55,20 @@ class EvaAttention(nn.Module):
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             attn_head_dim: Optional[int] = None,
+            norm_layer: Optional[Callable] = None,
     ):
+        """
+
+        Args:
+            dim:
+            num_heads:
+            qkv_bias:
+            qkv_fused:
+            attn_drop:
+            proj_drop:
+            attn_head_dim:
+            norm_layer:
+        """
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -82,6 +95,7 @@ class EvaAttention(nn.Module):
             self.q_bias = self.k_bias = self.v_bias = None

         self.attn_drop = nn.Dropout(attn_drop)
+        self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
         self.proj = nn.Linear(all_head_dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)

@@ -124,6 +138,7 @@ class EvaAttention(nn.Module):
             x = attn @ v

         x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
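Note on the new attn LN: passing a norm_layer into EvaAttention inserts a LayerNorm over the concatenated head outputs, between (attn @ v) and the output projection, which is the EVA-CLIP style "inner attention LayerNorm". A standalone sketch of that ordering (shapes and module names here are illustrative, not the timm implementation):

    import torch
    import torch.nn as nn

    B, N, num_heads, head_dim = 2, 197, 12, 64
    all_head_dim = num_heads * head_dim

    x = torch.randn(B, num_heads, N, head_dim)   # stand-in for attn @ v
    norm = nn.LayerNorm(all_head_dim)            # the inner attn LN; nn.Identity() when disabled
    proj = nn.Linear(all_head_dim, all_head_dim)

    x = x.transpose(1, 2).reshape(B, N, all_head_dim)
    x = norm(x)     # applied before the projection, as in the forward() change above
    x = proj(x)
    print(x.shape)  # torch.Size([2, 197, 768])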
@@ -138,16 +153,36 @@ class EvaBlock(nn.Module):
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             mlp_ratio: float = 4.,
-            scale_mlp: bool = False,
             swiglu_mlp: bool = False,
+            scale_mlp: bool = False,
+            scale_attn_inner: bool = False,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
             init_values: Optional[float] = None,
             act_layer: Callable = nn.GELU,
-            norm_layer: Callable = nn.LayerNorm,
+            norm_layer: Callable = LayerNorm,
             attn_head_dim: Optional[int] = None,
     ):
+        """
+
+        Args:
+            dim:
+            num_heads:
+            qkv_bias:
+            qkv_fused:
+            mlp_ratio:
+            swiglu_mlp:
+            scale_mlp:
+            scale_attn_inner:
+            proj_drop:
+            attn_drop:
+            drop_path:
+            init_values:
+            act_layer:
+            norm_layer:
+            attn_head_dim:
+        """
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = EvaAttention(
@@ -158,6 +193,7 @@ class EvaBlock(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             attn_head_dim=attn_head_dim,
+            norm_layer=norm_layer if scale_attn_inner else None,
         )
         self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -204,6 +240,96 @@ class EvaBlock(nn.Module):
         return x


+class EvaBlockPostNorm(nn.Module):
+    """ EVA block w/ post-norm and support for swiglu, MLP norm scale, ROPE. """
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int,
+            qkv_bias: bool = True,
+            qkv_fused: bool = True,
+            mlp_ratio: float = 4.,
+            swiglu_mlp: bool = False,
+            scale_mlp: bool = False,
+            scale_attn_inner: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            init_values: Optional[float] = None,  # ignore for post-norm
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = nn.LayerNorm,
+            attn_head_dim: Optional[int] = None,
+    ):
+        """
+
+        Args:
+            dim:
+            num_heads:
+            qkv_bias:
+            qkv_fused:
+            mlp_ratio:
+            swiglu_mlp:
+            scale_mlp:
+            scale_attn_inner:
+            proj_drop:
+            attn_drop:
+            drop_path:
+            init_values:
+            act_layer:
+            norm_layer:
+            attn_head_dim:
+        """
+        super().__init__()
+        self.attn = EvaAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qkv_fused=qkv_fused,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            attn_head_dim=attn_head_dim,
+            norm_layer=norm_layer if scale_attn_inner else None,
+        )
+        self.norm1 = norm_layer(dim)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        hidden_features = int(dim * mlp_ratio)
+        if swiglu_mlp:
+            if scale_mlp:
+                # when norm in SwiGLU used, an impl with separate fc for gate & x is used
+                self.mlp = SwiGLU(
+                    in_features=dim,
+                    hidden_features=hidden_features,
+                    norm_layer=norm_layer if scale_mlp else None,
+                    drop=proj_drop,
+                )
+            else:
+                # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
+                self.mlp = GluMlp(
+                    in_features=dim,
+                    hidden_features=hidden_features * 2,
+                    norm_layer=norm_layer if scale_mlp else None,
+                    act_layer=nn.SiLU,
+                    gate_last=False,
+                    drop=proj_drop,
+                )
+        else:
+            self.mlp = Mlp(
+                in_features=dim,
+                hidden_features=hidden_features,
+                act_layer=act_layer,
+                norm_layer=norm_layer if scale_mlp else None,
+                drop=proj_drop,
+            )
+        self.norm2 = norm_layer(dim)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
+        x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask)))
+        x = x + self.drop_path2(self.norm2(self.mlp(x)))
+        return x
+
+
 class Eva(nn.Module):
     """ Eva Vision Transformer w/ Abs & Rotary Pos Embed

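For reference, the new EvaBlockPostNorm differs from EvaBlock mainly in where the norms sit relative to the residual add (and it drops the init_values/gamma scaling). A schematic sketch of the two orderings, leaving out rope, attn_mask and drop path:

    # pre-norm (EvaBlock): normalize the inputs to each sub-block
    def pre_norm_forward(x, attn, mlp, norm1, norm2):
        x = x + attn(norm1(x))
        x = x + mlp(norm2(x))
        return x

    # post-norm (EvaBlockPostNorm): normalize sub-block outputs before the residual add
    def post_norm_forward(x, attn, mlp, norm1, norm2):
        x = x + norm1(attn(x))
        x = x + norm2(mlp(x))
        return x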
@@ -227,6 +353,7 @@ class Eva(nn.Module):
             mlp_ratio: float = 4.,
             swiglu_mlp: bool = False,
             scale_mlp: bool = False,
+            scale_attn_inner: bool = False,
             drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
@@ -234,9 +361,38 @@ class Eva(nn.Module):
             init_values: Optional[float] = None,
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
+            use_post_norm: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
+        """
+
+        Args:
+            img_size:
+            patch_size:
+            in_chans:
+            num_classes:
+            global_pool:
+            embed_dim:
+            depth:
+            num_heads:
+            qkv_bias:
+            qkv_fused:
+            mlp_ratio:
+            swiglu_mlp:
+            scale_mlp:
+            scale_attn_inner:
+            drop_rate:
+            attn_drop_rate:
+            drop_path_rate:
+            norm_layer:
+            init_values:
+            use_abs_pos_emb:
+            use_rot_pos_emb:
+            use_post_norm:
+            ref_feat_shape:
+            head_init_scale:
+        """
         super().__init__()
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -268,15 +424,17 @@ class Eva(nn.Module):
             self.rope = None

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+        block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
         self.blocks = nn.ModuleList([
-            EvaBlock(
+            block_fn(
                 dim=embed_dim,
                 num_heads=num_heads,
                 qkv_bias=qkv_bias,
                 qkv_fused=qkv_fused,
                 mlp_ratio=mlp_ratio,
-                scale_mlp=scale_mlp,
                 swiglu_mlp=swiglu_mlp,
+                scale_mlp=scale_mlp,
+                scale_attn_inner=scale_attn_inner,
                 proj_drop=drop_rate,
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
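The dpr list above is the usual linear stochastic depth schedule, so deeper blocks get larger drop path rates; for example:

    import torch

    drop_path_rate, depth = 0.1, 4
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    print([round(r, 3) for r in dpr])  # [0.0, 0.033, 0.067, 0.1]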
@@ -387,10 +545,24 @@ def checkpoint_filter_fn(
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
-    no_qkv = 'blocks.0.attn.q_proj.weight' in state_dict
-    mim_weights = 'mask_token' in state_dict
+    # prefix for loading OpenCLIP compatible weights
+    if 'visual.trunk.pos_embed' in state_dict:
+        prefix = 'visual.trunk.'
+    elif 'visual.pos_embed' in state_dict:
+        prefix = 'visual.'
+    else:
+        prefix = ''
+    mim_weights = prefix + 'mask_token' in state_dict
+    no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict

+    len_prefix = len(prefix)
     for k, v in state_dict.items():
+        if prefix:
+            if k.startswith(prefix):
+                k = k[len_prefix:]
+            else:
+                continue
+
         if 'rope' in k:
             # fixed embedding no need to load buffer from checkpoint
             continue
@@ -418,6 +590,7 @@ def checkpoint_filter_fn(
             )

         k = k.replace('mlp.ffn_ln', 'mlp.norm')
+        k = k.replace('attn.inner_attn_ln', 'attn.norm')
         k = k.replace('mlp.w12', 'mlp.fc1')
         k = k.replace('mlp.w1', 'mlp.fc1_g')
         k = k.replace('mlp.w2', 'mlp.fc1_x')
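Together, the prefix detection and the rename table let one filter serve plain EVA/EVA02 checkpoints as well as OpenCLIP-style image towers. A toy illustration of the idea (keys below are made up, not from a real checkpoint):

    # Made-up keys; illustrates prefix stripping + renaming only.
    state_dict = {
        'visual.trunk.pos_embed': 'tensor...',
        'visual.trunk.blocks.0.attn.inner_attn_ln.weight': 'tensor...',
        'text.transformer.weight': 'tensor...',  # non-visual keys get skipped
    }
    prefix = 'visual.trunk.' if 'visual.trunk.pos_embed' in state_dict else ''

    out = {}
    for k, v in state_dict.items():
        if prefix:
            if not k.startswith(prefix):
                continue
            k = k[len(prefix):]
        k = k.replace('attn.inner_attn_ln', 'attn.norm')
        out[k] = v
    print(sorted(out))  # ['blocks.0.attn.norm.weight', 'pos_embed']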
@@ -457,12 +630,13 @@ def _cfg(url='', **kwargs):
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
         'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
+        'license': 'mit', **kwargs
     }


 default_cfgs = generate_default_cfgs({

     # EVA 01 CLIP fine-tuned on imagenet-1k
     'eva_giant_patch14_224.clip_ft_in1k': _cfg(
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
         hf_hub_id='timm/',
@@ -471,6 +645,8 @@ default_cfgs = generate_default_cfgs({
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
         hf_hub_id='timm/',
         input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),

     # MIM EVA 01 pretrain, ft on in22k -> in1k
     'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
         hf_hub_id='timm/',
@@ -482,79 +658,113 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),

     # in22k or m38m MIM pretrain w/ intermediate in22k fine-tune and final in1k fine-tune
     'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
+        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
     ),
     'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
+        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
     ),
     'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
         input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
     ),

     # in22k or m3m MIM pretrain w/ in1k fine-tune
     'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 336, 336), crop_pct=1.0,
     ),
     'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 336, 336), crop_pct=1.0,
     ),
     'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0,
     ),
     'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0,
     ),
     'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02',
-        hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0,
     ),

     # in22k or m3m MIM pretrain w/ in22k fine-tune
     'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
     ),
     'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
     ),
     'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
     ),

     # in22k or m38m MIM pretrain
     'eva02_tiny_patch14_224.mim_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
+        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
+        hf_hub_id='timm/',
         num_classes=0,
     ),
     'eva02_small_patch14_224.mim_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
+        hf_hub_id='timm/',
         num_classes=0,
     ),
     'eva02_base_patch14_224.mim_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
+        hf_hub_id='timm/',
         num_classes=0,
     ),
     'eva02_large_patch14_224.mim_in22k': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
+        hf_hub_id='timm/',
        num_classes=0,
     ),
     'eva02_large_patch14_224.mim_m38m': _cfg(
-        hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
+        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
+        hf_hub_id='timm/',
         num_classes=0,
     ),
+
+    # EVA01 and EVA02 CLIP image towers
+    'eva_giant_patch14_224.clip': _cfg(
+        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
+        num_classes=1024,
+    ),
+    'eva02_base_patch14_clip_224.clip': _cfg(
+        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+        num_classes=512,
+    ),
+    'eva02_large_patch14_clip_224.clip': _cfg(
+        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+        num_classes=768,
+    ),
+    'eva02_enormous_patch14_clip_224.clip': _cfg(
+        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
+        num_classes=1024,
+    ),
+    'eva02_enormous_patch14_clip_224.pretrain': _cfg(
+        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
+        num_classes=0,
+    ),

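With the EVA02 classification and MIM weights now pointed at hf_hub_id='timm/', they resolve through the normal timm pretrained path; for example (needs network access, model name taken from the cfgs above):

    import timm
    import torch

    model = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', pretrained=True)
    model = model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 448, 448))
    print(out.shape)  # torch.Size([1, 1000])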
@@ -632,8 +842,8 @@ def eva02_base_patch14_224(pretrained=False, **kwargs):
         num_heads=12,
         qkv_fused=False,
         mlp_ratio=4 * 2 / 3,
-        scale_mlp=True,
         swiglu_mlp=True,
+        scale_mlp=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16), # 224/14
     )
@@ -651,8 +861,8 @@ def eva02_large_patch14_224(pretrained=False, **kwargs):
         num_heads=16,
         mlp_ratio=4 * 2 / 3,
         qkv_fused=False,
-        scale_mlp=True,
         swiglu_mlp=True,
+        scale_mlp=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16), # 224/14
     )
@@ -704,8 +914,8 @@ def eva02_base_patch14_448(pretrained=False, **kwargs):
         num_heads=12,
         qkv_fused=False,
         mlp_ratio=4 * 2 / 3,
-        scale_mlp=True,
         swiglu_mlp=True,
+        scale_mlp=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16), # 224/14
     )
@@ -723,10 +933,71 @@ def eva02_large_patch14_448(pretrained=False, **kwargs):
         num_heads=16,
         mlp_ratio=4 * 2 / 3,
         qkv_fused=False,
-        scale_mlp=True,
         swiglu_mlp=True,
+        scale_mlp=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16), # 224/14
     )
     model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
+
+
+@register_model
+def eva02_base_patch16_clip_224(pretrained=False, **kwargs):
+    # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base
+    model_kwargs = dict(
+        img_size=224,
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        qkv_fused=False,
+        mlp_ratio=4 * 2 / 3,
+        swiglu_mlp=True,
+        scale_mlp=True,
+        scale_attn_inner=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(16, 16), # 224/14
+        global_pool='token',
+    )
+    model = _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+
+
+@register_model
+def eva02_large_patch14_clip_224(pretrained=False, **kwargs):
+    # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large
+    model_kwargs = dict(
+        img_size=224,
+        patch_size=14,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4 * 2 / 3,
+        qkv_fused=False,
+        swiglu_mlp=True,
+        scale_mlp=True,
+        scale_attn_inner=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(16, 16), # 224/14
+        global_pool='token',
+    )
+    model = _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+
+
+@register_model
+def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs):
+    # A EVA-CLIP specific variant that uses residual post-norm in blocks
+    model_kwargs = dict(
+        img_size=224,
+        patch_size=14,
+        embed_dim=1792,
+        depth=64,
+        num_heads=16,
+        mlp_ratio=15360 / 1792,
+        use_post_norm=True,
+        global_pool='token',
+    )
+    model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
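These new CLIP-tower configs have no timm-hosted weights wired up yet (the QuanSun/EVA-CLIP references in the cfgs above are commented out), so for now they mainly define the architectures; a usage sketch without pretrained weights:

    import timm
    import torch

    # pretrained=False: no weights are published for this config in this commit
    model = timm.create_model('eva02_base_patch16_clip_224', pretrained=False, num_classes=0)
    feats = model(torch.randn(1, 3, 224, 224))
    print(type(model.blocks[0]).__name__)  # EvaBlock (inner attn LN enabled via scale_attn_inner)
    print(feats.shape)                     # torch.Size([1, 768])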
@@ -1216,22 +1216,22 @@ default_cfgs = generate_default_cfgs({
     # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
     'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
-        hf_hub_id='timm/',
+        hf_hub_id='timm/', license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 196, 196), crop_pct=1.0),
     'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
-        hf_hub_id='timm/',
+        hf_hub_id='timm/', license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
     'eva_large_patch14_196.in22k_ft_in1k': _cfg(
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
-        hf_hub_id='timm/',
+        hf_hub_id='timm/', license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 196, 196), crop_pct=1.0),
     'eva_large_patch14_336.in22k_ft_in1k': _cfg(
         # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
-        hf_hub_id='timm/',
+        hf_hub_id='timm/', license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),