mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add intern300m vit w/ converted timm weights. Fix #2300
This commit is contained in:
parent
60f517c883
commit
a1f379e712
@ -710,10 +710,12 @@ def checkpoint_filter_fn(state_dict, model):
|
||||
def _create_davit(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
strict = True
|
||||
|
||||
strict = kwargs.pop('pretrained_strict', True)
|
||||
if variant.endswith('_fl'):
|
||||
# FIXME cleaner approach to missing head norm?
|
||||
strict = False
|
||||
|
||||
model = build_model_with_cfg(
|
||||
DaVit,
|
||||
variant,
|
||||
|
@ -438,6 +438,7 @@ class VisionTransformer(nn.Module):
|
||||
no_embed_class: bool = False,
|
||||
reg_tokens: int = 0,
|
||||
pre_norm: bool = False,
|
||||
final_norm: bool = True,
|
||||
fc_norm: Optional[bool] = None,
|
||||
dynamic_img_size: bool = False,
|
||||
dynamic_img_pad: bool = False,
|
||||
@ -471,7 +472,9 @@ class VisionTransformer(nn.Module):
|
||||
class_token: Use class token.
|
||||
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
||||
reg_tokens: Number of register tokens.
|
||||
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
||||
pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
|
||||
final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
|
||||
fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
||||
drop_rate: Head dropout rate.
|
||||
pos_drop_rate: Position embedding dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
|
||||
@ -554,7 +557,7 @@ class VisionTransformer(nn.Module):
|
||||
for i in range(depth)])
|
||||
self.feature_info = [
|
||||
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
|
||||
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
||||
self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
|
||||
|
||||
# Classifier Head
|
||||
if global_pool == 'map':
|
||||
@ -566,7 +569,7 @@ class VisionTransformer(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.attn_pool = None
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
@ -2051,6 +2054,12 @@ default_cfgs = {
|
||||
'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
|
||||
input_size=(3, 256, 256)),
|
||||
|
||||
'vit_intern300m_patch14_448.ogvl_dist': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
|
||||
),
|
||||
|
||||
'test_vit.r160_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 160, 160), crop_pct=0.95),
|
||||
@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
|
||||
_filter_fn = checkpoint_filter_fn
|
||||
|
||||
# FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
|
||||
strict = True
|
||||
strict = kwargs.pop('pretrained_strict', True)
|
||||
if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
|
||||
strict = False
|
||||
|
||||
@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=1024, depth=24, num_heads=16,
|
||||
init_values=0.1, final_norm=False, dynamic_img_size=True,
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT Test
|
||||
|
Loading…
x
Reference in New Issue
Block a user