Working on support for siglip (w/ attn pool) vit backbone, and adding registers (reg tokens)
parent 054c763fca
commit 82cc53237e
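A hedged sketch of how the new options are exercised once this branch is installed (entrypoint names are taken from the model registrations added at the bottom of the diff; no pretrained weights are assumed):

    import timm
    import torch

    # SigLIP-style ViT-B/16 trunk: class_token=False, use_attn_pool=True (see the note
    # on the use_attn_pool gate further down); num_classes=0 returns pre-logits features.
    siglip = timm.create_model('vit_base_patch16_siglip_224', pretrained=False, num_classes=0)

    # ViT-B/16 with 8 register tokens kept alongside the class token.
    reg = timm.create_model('vit_base_patch16_reg8_224', pretrained=False, num_classes=0)

    x = torch.randn(1, 3, 224, 224)
    print(siglip(x).shape, reg(x).shape)  # both (1, 768) feature vectors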
@@ -160,7 +160,11 @@ def load_pretrained(
         state_dict = pretrained_loc  # pretrained_loc is the actual state dict for this override
     elif load_from == 'file':
         _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
-        state_dict = load_state_dict(pretrained_loc)
+        if pretrained_cfg.get('custom_load', False):
+            model.load_pretrained(pretrained_loc)
+            return
+        else:
+            state_dict = load_state_dict(pretrained_loc)
     elif load_from == 'url':
         _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
         if pretrained_cfg.get('custom_load', False):
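The file branch now mirrors the url branch: when the resolved pretrained_cfg has custom_load=True, the checkpoint path is handed to the model's own load_pretrained() (for VisionTransformer this wraps the numpy .npz loader, _load_weights) instead of going through load_state_dict(). A hedged sketch of a config that would take this path (field names as used in the cfg added later in this diff; the path is illustrative):

    # illustrative pretrained_cfg fields routing a local big_vision .npz through custom load
    cfg = dict(
        file='/path/to/webli_en_b16_224.npz',  # local checkpoint path (illustrative)
        custom_load=True,                      # -> model.load_pretrained(path) instead of load_state_dict(path)
        num_classes=0,
    )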
@@ -377,6 +377,72 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)


+class AttentionPoolLatent(nn.Module):
+    """ Attention pooling w/ latent query
+    """
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int = None,
+            embed_dim: int = None,
+            num_heads: int = 8,
+            mlp_ratio: float = 4.0,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            latent_size: int = 1,
+            latent_dim: int = None,
+            pos_embed: str = '',
+            pool_type: str = 'token',
+            norm_layer: Optional[nn.Module] = None,
+            drop: float = 0.0,
+    ):
+        super().__init__()
+        embed_dim = embed_dim or in_features
+        out_features = out_features or in_features
+        assert embed_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.pool = pool_type
+
+        if pos_embed == 'abs':
+            spatial_len = self.feat_size
+            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+        else:
+            self.pos_embed = None
+
+        self.latent_dim = latent_dim or embed_dim
+        latent_size = latent_size or self.feat_size
+        self.latent_len = latent_size
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
+
+        self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm)
+        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+
+    def init_weights(self):
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+
+    def forward(self, x):
+        B, N, _ = x.shape
+
+        if self.pos_embed is not None:
+            # FIXME interpolate
+            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        latent_q = self.latent.expand(B, -1, -1)
+        x = self.attn(torch.cat([latent_q, x], dim=1))
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == 'token':
+            x = x[:, 0]
+        elif self.pool == 'avg':
+            x = x.mean(1)
+        return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer

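AttentionPoolLatent concatenates a learned latent ("probe") token onto the input sequence, runs full self-attention over the combined sequence plus a residual MLP, then reads out the latent position(s). A minimal shape-level sketch of that flow (plain torch, using nn.MultiheadAttention as a stand-in for timm's Attention block; this is not the class above):

    import torch
    import torch.nn as nn

    B, N, D, H = 2, 196, 768, 8
    x = torch.randn(B, N, D)                     # patch tokens from the ViT trunk
    latent = nn.Parameter(torch.zeros(1, 1, D))  # learned latent query / probe

    attn = nn.MultiheadAttention(D, H, batch_first=True)
    norm = nn.LayerNorm(D)
    mlp = nn.Sequential(nn.Linear(D, 4 * D), nn.GELU(), nn.Linear(4 * D, D))

    q = latent.expand(B, -1, -1)
    seq = torch.cat([q, x], dim=1)   # prepend latent, attend over latent + patches
    out, _ = attn(seq, seq, seq)
    out = out + mlp(norm(out))       # residual MLP, as in the forward() above
    pooled = out[:, 0]               # pool_type='token': take the latent position
    print(pooled.shape)              # torch.Size([2, 768])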
@@ -401,8 +467,10 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            reg_tokens: int = 0,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
+            use_attn_pool: bool = False,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
@@ -432,6 +500,8 @@ class VisionTransformer(nn.Module):
             qkv_bias: Enable bias for qkv projections if True.
             init_values: Layer-scale init values (layer-scale enabled if not None).
             class_token: Use class token.
+            no_embed_class: Don't include position embeddings for class (or reg) tokens.
+            reg_tokens: Number of register tokens.
             fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
@@ -445,7 +515,7 @@ class VisionTransformer(nn.Module):
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'token')
-        assert class_token or global_pool != 'token'
+        assert class_token or use_attn_pool or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
@@ -454,7 +524,10 @@ class VisionTransformer(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
-        self.no_embed_class = no_embed_class
+        self.num_prefix_tokens += reg_tokens
+        self.num_reg_tokens = reg_tokens
+        self.has_class_token = class_token
+        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False

@@ -474,6 +547,7 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
         self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
@@ -506,6 +580,14 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
+        if use_attn_pool == 'map':
+            self.attn_pool = AttentionPoolLatent(
+                self.embed_dim,
+                num_heads=num_heads,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.attn_pool = None
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
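One thing to note in this revision: use_attn_pool is declared as bool = False, but the constructor only instantiates the pool when use_attn_pool == 'map', so passing True leaves self.attn_pool as None (the siglip entrypoint added below passes True). A quick hedged check against this revision's constructor defaults (assumes this branch is importable):

    # only the string 'map' builds the latent attention pool in this revision
    m = VisionTransformer(class_token=False, global_pool='', use_attn_pool='map', num_classes=0)
    print(type(m.attn_pool).__name__)  # AttentionPoolLatent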
@@ -566,18 +648,26 @@ class VisionTransformer(nn.Module):
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed

+        to_cat = []
+        if self.cls_token is not None:
+            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+        if self.reg_token is not None:
+            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
             x = x + pos_embed
-            if self.cls_token is not None:
-                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
         else:
             # original timm, JAX, and deit vit impl
             # pos_embed has entry for class token, concat then add
-            if self.cls_token is not None:
-                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
             x = x + pos_embed

         return self.pos_drop(x)

     def _intermediate_layers(
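With both a class token and register tokens, num_prefix_tokens covers all prefix tokens (1 + reg_tokens), and _pos_embed above prepends the reg tokens right after the class token. A hedged sketch of separating prefix and patch tokens from forward_features output (assumes this branch is installed; no pretrained weights):

    import timm
    import torch

    m = timm.create_model('vit_base_patch16_reg8_224', pretrained=False)
    feats = m.forward_features(torch.randn(1, 3, 224, 224))  # (1, 1 cls + 8 reg + 196 patches, 768)
    prefix = feats[:, :m.num_prefix_tokens]                  # (1, 9, 768) class + register tokens
    patches = feats[:, m.num_prefix_tokens:]                 # (1, 196, 768) patch tokens
    print(prefix.shape, patches.shape)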
@@ -605,7 +695,7 @@ class VisionTransformer(nn.Module):
             x: torch.Tensor,
             n: Union[int, Sequence] = 1,
             reshape: bool = False,
-            return_class_token: bool = False,
+            return_prefix_tokens: bool = False,
             norm: bool = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
         """ Intermediate layer accessor (NOTE: This is a WIP experiment).
@@ -615,7 +705,7 @@ class VisionTransformer(nn.Module):
         outputs = self._intermediate_layers(x, n)
         if norm:
             outputs = [self.norm(out) for out in outputs]
-        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
+        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
         outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

         if reshape:
@@ -625,8 +715,8 @@ class VisionTransformer(nn.Module):
                 for out in outputs
             ]

-        if return_class_token:
-            return tuple(zip(outputs, class_tokens))
+        if return_prefix_tokens:
+            return tuple(zip(outputs, prefix_tokens))
         return tuple(outputs)

     def forward_features(self, x):
@@ -642,8 +732,12 @@ class VisionTransformer(nn.Module):
         return x

     def forward_head(self, x, pre_logits: bool = False):
-        if self.global_pool:
-            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        if self.attn_pool is not None:
+            x = self.attn_pool(x)
+        elif self.global_pool == 'avg':
+            x = x[:, self.num_prefix_tokens:].mean(dim=1)
+        elif self.global_pool:
+            x = x[:, 0]  # class token
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
@@ -767,6 +861,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     elif 'params/embedding/kernel' in w:
         prefix = 'params/'
         big_vision = True
+    elif 'params/img/embedding/kernel' in w:
+        prefix = 'params/img/'
+        big_vision = True

     if hasattr(model.patch_embed, 'backbone'):
         # hybrid
@@ -823,13 +920,31 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if (isinstance(model.head, nn.Linear) and
+            f'{prefix}head/bias' in w and
+            model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
     # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+    if model.attn_pool is not None:
+        block_prefix = f'{prefix}MAPHead_0/'
+        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
+        model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
+        model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+        model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        model.attn_pool.attn.qkv.weight.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+        model.attn_pool.attn.qkv.bias.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+        model.attn_pool.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+        model.attn_pool.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+        for r in range(2):
+            getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
+            getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
+
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
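The MAP-head remapping above pulls the latent ('probe'), LayerNorm, MultiHeadDotProductAttention, and MlpBlock arrays from the MAPHead_0 scope of a big_vision checkpoint; for SigLIP checkpoints the image tower is expected under the params/img/ prefix added in the earlier _load_weights hunk. A hedged way to eyeball those keys in an .npz checkpoint (path is illustrative):

    import numpy as np

    w = np.load('/path/to/webli_en_b16_224.npz')  # illustrative local path
    print([k for k in w.files if 'MAPHead_0' in k][:8])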
@@ -1493,6 +1608,12 @@ default_cfgs = generate_default_cfgs({
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    'vit_base_patch16_siglip_224': _cfg(
+        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
+        custom_load=True,
+        # hf_hub_id='timm/',
+        num_classes=0),
 })


@@ -2119,6 +2240,59 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, no_embed_class=True,
+        reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=8,
+        class_token=True, no_embed_class=True, reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8,
+        class_token=True, no_embed_class=True, reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_reg8_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=8, global_pool='avg', reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
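As committed, vit_medium_patch16_reg8_224 is registered twice, and both medium entrypoints pass 'vit_base_patch16_reg8_224' as the variant name to _create_vision_transformer, which reads like a copy/paste slip in this WIP commit. A hedged sketch of the presumably intended medium entrypoint (not what this commit contains):

    @register_model
    def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
        model_args = dict(
            patch_size=16, embed_dim=512, depth=12, num_heads=8,
            class_token=True, no_embed_class=True, reg_tokens=8,
        )
        model = _create_vision_transformer(
            'vit_medium_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs))
        return model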