Add to 'abswin' hiera models for train trials

This commit is contained in:
Ross Wightman 2024-07-19 11:05:31 -07:00
parent 0cbf4fa586
commit 1a05ed29a1

View File

@ -33,7 +33,7 @@ from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple
from ._registry import generate_default_cfgs, register_model
@ -486,6 +486,8 @@ class Hiera(nn.Module):
self.num_classes = num_classes
self.grad_checkpointing = False
norm_layer = get_norm_layer(norm_layer)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
self.patch_stride = patch_stride
self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
@ -895,6 +897,15 @@ default_cfgs = generate_default_cfgs({
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_small_abswin_256.untrained": _cfg(
#hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95,
),
"hiera_base_abswin_256.untrained": _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95,
),
})
@ -934,36 +945,48 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
)
@register_model
def hiera_tiny_224(pretrained = False, **kwargs):
def hiera_tiny_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_small_224(pretrained = False, **kwargs):
def hiera_small_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_base_224(pretrained = False, **kwargs):
def hiera_base_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_base_plus_224(pretrained = False, **kwargs):
def hiera_base_plus_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_large_224(pretrained = False, **kwargs):
def hiera_large_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_huge_224(pretrained = False, **kwargs):
def hiera_huge_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_small_abswin_256(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, abs_pos_size=(16, 16))
return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def hiera_base_abswin_256(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, abs_pos_size=(16, 16))
return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))