mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update experimental vit model configs
This commit is contained in:
parent
7d3c2dc993
commit
87fec3dc14
@ -1723,7 +1723,12 @@ default_cfgs = {
|
|||||||
input_size=(3, 256, 256)),
|
input_size=(3, 256, 256)),
|
||||||
'vit_medium_patch16_reg4_gap_256': _cfg(
|
'vit_medium_patch16_reg4_gap_256': _cfg(
|
||||||
input_size=(3, 256, 256)),
|
input_size=(3, 256, 256)),
|
||||||
'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
|
'vit_base_patch16_reg4_gap_256': _cfg(
|
||||||
|
input_size=(3, 256, 256)),
|
||||||
|
'vit_so150m_patch16_reg4_gap_256': _cfg(
|
||||||
|
input_size=(3, 256, 256)),
|
||||||
|
'vit_so150m_patch16_reg4_map_256': _cfg(
|
||||||
|
input_size=(3, 256, 256)),
|
||||||
}
|
}
|
||||||
|
|
||||||
_quick_gelu_cfgs = [
|
_quick_gelu_cfgs = [
|
||||||
@ -2623,13 +2628,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
|
|||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
|
||||||
no_embed_class=True, global_pool='avg', reg_tokens=8,
|
no_embed_class=True, global_pool='avg', reg_tokens=4,
|
||||||
)
|
)
|
||||||
model = _create_vision_transformer(
|
model = _create_vision_transformer(
|
||||||
'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
|
model_args = dict(
|
||||||
|
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
|
||||||
|
class_token=False, reg_tokens=4, global_pool='map',
|
||||||
|
)
|
||||||
|
model = _create_vision_transformer(
|
||||||
|
'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
|
model_args = dict(
|
||||||
|
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
|
||||||
|
class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
|
||||||
|
)
|
||||||
|
model = _create_vision_transformer(
|
||||||
|
'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user