Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Add to 'abswin' hiera models for train trials
commit 1a05ed29a1
parent 0cbf4fa586
@@ -33,7 +33,7 @@ from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple
 
 from ._registry import generate_default_cfgs, register_model
@@ -486,6 +486,8 @@ class Hiera(nn.Module):
         self.num_classes = num_classes
         self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
+        if isinstance(img_size, int):
+            img_size = to_2tuple(img_size)
 
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
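For context (not part of the commit): a minimal sketch of what the added normalization does, assuming the stock timm.layers.to_2tuple helper. A plain int img_size is expanded to an (H, W) pair before the per-axis token-shape computation below; the patch_stride value is only illustrative.

from timm.layers import to_2tuple

img_size = 256                      # an int is now accepted here
if isinstance(img_size, int):
    img_size = to_2tuple(img_size)  # -> (256, 256)
# e.g. with a (4, 4) patch stride the token grid becomes [64, 64]
tokens_spatial_shape = [i // s for i, s in zip(img_size, (4, 4))]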
@@ -895,6 +897,15 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         num_classes=0,
     ),
+
+    "hiera_small_abswin_256.untrained": _cfg(
+        # hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95,
+    ),
+    "hiera_base_abswin_256.untrained": _cfg(
+        # hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95,
+    ),
 })
 
 
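A hedged usage note, not part of the diff: with this commit applied, the two `.untrained` entries above should resolve to a 256x256 eval configuration. The snippet assumes the standard timm.create_model / timm.data.resolve_data_config API and that no pretrained weights exist yet for these variants.

import timm
from timm.data import resolve_data_config

model = timm.create_model('hiera_small_abswin_256', pretrained=False)  # untrained, per the cfg above
cfg = resolve_data_config({}, model=model)
print(cfg['input_size'], cfg['crop_pct'])  # expected: (3, 256, 256) 0.95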
@@ -934,36 +945,48 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
     )
 
 
 @register_model
-def hiera_tiny_224(pretrained = False, **kwargs):
+def hiera_tiny_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
     return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def hiera_small_224(pretrained = False, **kwargs):
+def hiera_small_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
     return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def hiera_base_224(pretrained = False, **kwargs):
+def hiera_base_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
     return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def hiera_base_plus_224(pretrained = False, **kwargs):
+def hiera_base_plus_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
     return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def hiera_large_224(pretrained = False, **kwargs):
+def hiera_large_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
     return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
-def hiera_huge_224(pretrained = False, **kwargs):
+def hiera_huge_224(pretrained=False, **kwargs):
     model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
     return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_small_abswin_256(pretrained=False, **kwargs):
+    model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16))
+    return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def hiera_base_abswin_256(pretrained=False, **kwargs):
+    model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16))
+    return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))