Prepping weight push, benchmarking.
parent 2bfa5e5d74
commit 3582ca499e
@@ -718,10 +718,10 @@ def checkpoint_filter_fn(
             continue
 
         # FIXME here while import new weights, to remove
-        # if k == 'cls_token':
-        #     print('DEBUG: cls token -> reg')
-        #     k = 'reg_token'
-        #     #v = v + state_dict['pos_embed'][0, :]
+        if k == 'cls_token':
+            print('DEBUG: cls token -> reg')
+            k = 'reg_token'
+            #v = v + state_dict['pos_embed'][0, :]
 
         if 'patch_embed.proj.weight' in k:
             _, _, H, W = model.patch_embed.proj.weight.shape
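Note: this hunk turns on the previously commented-out remap of a checkpoint's cls_token onto the model's register token while the new weights are imported; the FIXME marks it (debug print included) as temporary scaffolding. A minimal sketch of the remap in isolation, with an illustrative helper name and shapes that are not part of the commit:

import torch

def remap_cls_to_reg(state_dict):
    # Standalone version of the temporary remap above: checkpoints trained
    # with a class token load into models that use a register token instead.
    out = {}
    for k, v in state_dict.items():
        if k == 'cls_token':
            k = 'reg_token'
        out[k] = v
    return out

sd = {'cls_token': torch.zeros(1, 1, 384), 'pos_embed': torch.zeros(1, 197, 384)}
print(remap_cls_to_reg(sd).keys())  # dict_keys(['reg_token', 'pos_embed'])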
@@ -951,26 +951,26 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),
 
-    'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+        file='vit_medium_gap1_rope-in1k-20230920-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+        file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
+    'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+        file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
-    'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_base_gap1_rope-in1k-20230930-5.pth',
+        file='vit_base_gap1_rope-in1k-20230930-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
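Note: the four rope configs gain the sbb_in1k pretrained tag, and their file= entries go live so weights load from local checkpoints ahead of the hub push (hf_hub_id stays commented out for now). A hedged usage sketch, assuming the local .pth referenced above is present:

import timm

# After the weight push the same name should resolve via the hub rather
# than the local file= path.
model = timm.create_model('vit_medium_patch16_rope_reg1_gap_256.sbb_in1k', pretrained=True)
model.eval()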
@@ -428,6 +428,7 @@ class VisionTransformer(nn.Module):
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            repr_size = False,
     ) -> None:
         """
         Args:
@@ -536,6 +537,14 @@ class VisionTransformer(nn.Module):
             )
         else:
             self.attn_pool = None
+        if repr_size:
+            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
+            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
+            embed_dim = repr_size
+            print(self.repr)
+        else:
+            self.repr = nn.Identity()
+
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
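Note: repr_size doubles as a flag and a width: True keeps the embedding width, an int sets an explicit pre-logits width, anything falsy leaves an Identity. The head is still built from self.embed_dim, so an int repr_size only lines up when it equals the embedding width. A sketch of just the resolution logic, with an illustrative function name:

import torch.nn as nn

def build_repr(repr_size, embed_dim):
    # Mirrors the constructor logic above: bool True -> keep embed_dim,
    # int -> explicit pre-logits width, falsy -> identity pass-through.
    if repr_size:
        repr_size = embed_dim if isinstance(repr_size, bool) else repr_size
        return nn.Sequential(nn.Linear(embed_dim, repr_size), nn.Tanh()), repr_size
    return nn.Identity(), embed_dim

print(build_repr(True, 384))  # (Linear(384 -> 384) + Tanh, 384)
print(build_repr(256, 384))   # (Linear(384 -> 256) + Tanh, 256)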
@@ -752,6 +761,7 @@ class VisionTransformer(nn.Module):
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
+        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
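Note: the representation layer sits between pooling and fc_norm, so even pre_logits output passes through it. A minimal stand-in for the resulting head ordering, assuming the Linear+Tanh module from the constructor hunk:

import torch
import torch.nn as nn

repr_layer = nn.Sequential(nn.Linear(384, 384), nn.Tanh())
fc_norm = nn.LayerNorm(384)
head = nn.Linear(384, 1000)

x = torch.randn(2, 196, 384)  # (batch, tokens, embed_dim)
x = x.mean(dim=1)             # global average pool stand-in
x = repr_layer(x)             # new: applied before fc_norm
x = fc_norm(x)
logits = head(x)              # -> (2, 1000)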
@@ -1790,23 +1800,40 @@ default_cfgs = {
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
 
-    'vit_wee_patch16_reg1_gap_256': _cfg(
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_little_patch16_reg4_gap_256': _cfg(
-        #file='',
+    'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='./vit_pwee-in1k-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg1_gap_256': _cfg(
-        #file='vit_medium_gap1-in1k-20231118-8.pth',
+    'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_little_patch16-in1k-8a.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg4_gap_256': _cfg(
-        #file='vit_medium_gap4-in1k-20231115-8.pth',
+    'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap1-in1k-20231118-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg1_gap_256': _cfg(
-        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+    'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap4-in1k-20231115-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256': _cfg(
-        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_mp_patch16_reg4-in12k-8.pth',
+        num_classes=11821,
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_betwixt_gap4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
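Note: the non-rope configs get the same treatment: sbb pretrained tags, live local file= entries, new pwee/mediumd rows, and in12k variants carrying num_classes=11821 for the ImageNet-12k label set. A hedged sketch of using an in12k entry as a fine-tuning starting point, assuming its local checkpoint is present:

import timm
import torch

# num_classes overrides the 11821-way in12k head for a downstream task;
# the backbone weights still load from the checkpoint.
model = timm.create_model(
    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k', pretrained=True, num_classes=100)
out = model(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 100])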
@@ -1906,6 +1933,14 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Small (ViT-S/16)
+    """
+    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
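Note: vit_small_patch16_gap_224 is the first consumer of the new repr_size flag: standard ViT-S/16 geometry with average pooling in place of a class token and a Tanh pre-logits layer, reusing the existing vit_small_patch16_224 pretrained configs.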
@@ -2755,10 +2790,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
 def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
         class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
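Note: besides splitting vit_wee (plain blocks) from the new vit_pwee (ParallelScalingBlock), this hunk fixes a copy-paste bug: the builder passed 'vit_medium_patch16_reg1_gap_256' as the variant name, so the wrong default config was looked up. The next hunk applies the same variant-name fix to vit_little. A quick sanity check that a registered name now resolves to its own config (attribute layout assumed from recent timm):

import timm

m = timm.create_model('vit_pwee_patch16_reg1_gap_256', pretrained=False)
print(m.default_cfg)  # should reference vit_pwee_patch16_reg1_gap_256, not a medium cfg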
@@ -2769,7 +2815,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2795,6 +2841,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
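Note: vit_mediumd is a deeper take on the medium width: embed_dim 512 as in vit_medium, but depth 20 instead of 12, with 8 heads and the same reg4 register-token/GAP layout. A rough comparison, assuming both names are registered as above:

import timm

medium = timm.create_model('vit_medium_patch16_reg4_gap_256', pretrained=False)
mediumd = timm.create_model('vit_mediumd_patch16_reg4_gap_256', pretrained=False)
# depth 20 vs 12 at the same width; mediumd should be noticeably larger
print(sum(p.numel() for p in medium.parameters()), sum(p.numel() for p in mediumd.parameters()))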
@@ -1 +1 @@
-__version__ = '1.0.0.dev0'
+__version__ = '1.0.1.dev0'