mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add maxvit_rmlp_nano_rw_256 model def & weights, make window/grid size dynamic wrt img_size by default
This commit is contained in:
parent e6a4361306
commit 7f1b223c02
@@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 import math
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
 
@@ -112,6 +112,9 @@ default_cfgs = {
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
@@ -136,14 +139,23 @@ class MaxxVitTransformerCfg:
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    window_size: Tuple[int, int] = (7, 7)
-    grid_size: Tuple[int, int] = (7, 7)
+    partition_stride: int = 32
+    window_size: Optional[Tuple[int, int]] = None
+    grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
     norm_layer_cl: str = 'layernorm'
     norm_eps: float = 1e-6
 
+    def __post_init__(self):
+        if self.grid_size is not None:
+            self.grid_size = to_2tuple(self.grid_size)
+        if self.window_size is not None:
+            self.window_size = to_2tuple(self.window_size)
+            if self.grid_size is None:
+                self.grid_size = self.window_size
+
 
 @dataclass
 class MaxxVitConvCfg:
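The `__post_init__` hook keeps the two partition fields consistent. A minimal sketch of the resulting behavior (assuming the dataclass is importable as `timm.models.maxxvit.MaxxVitTransformerCfg`; `to_2tuple` is already imported in the file):

    from timm.models.maxxvit import MaxxVitTransformerCfg

    # An int window_size is normalized to a 2-tuple, and grid_size falls
    # back to the same value when left unset.
    cfg = MaxxVitTransformerCfg(window_size=8)
    assert cfg.window_size == (8, 8) and cfg.grid_size == (8, 8)

    # When both stay None, sizing is deferred: cfg_window_size() (added
    # further below) derives them from img_size // partition_stride at
    # model build time.
    cfg = MaxxVitTransformerCfg()
    assert cfg.window_size is None and cfg.grid_size is None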
@@ -249,7 +261,7 @@ def _rw_max_cfg(
         conv_norm_layer='',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         dim_head=32,
         rel_pos_type='bias',
         rel_pos_dim=512,
@@ -274,8 +286,7 @@ def _rw_max_cfg(
             expand_first=False,
             pool_type=pool_type,
             dim_head=dim_head,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -291,7 +302,7 @@ def _next_cfg(
         conv_norm_layer_cl='layernorm',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -308,8 +319,7 @@ def _next_cfg(
         transformer_cfg=MaxxVitTransformerCfg(
             expand_first=False,
             pool_type=pool_type,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -462,14 +472,14 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_tiny_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -483,14 +493,21 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
+    ),
+    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(1, 2, 3, 1),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
     ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('PM',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
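The new rmlp variant reuses the nano geometry and differs only in its relative position encoding. A hedged sketch of what the splatted helper produces (`_rw_max_cfg` is internal; the dict shape with `conv_cfg`/`transformer_cfg` keys is an assumption inferred from how `MaxxVitCfg(**_rw_max_cfg(...))` is used above):

    # Inside maxxvit.py: the helper returns the cfg kwargs MaxxVitCfg splats in.
    cfgs = _rw_max_cfg(rel_pos_type='mlp')
    transformer_cfg = cfgs['transformer_cfg']
    assert transformer_cfg.rel_pos_type == 'mlp'  # MLP rel-pos, the 'rmlp' in the name
    assert transformer_cfg.window_size is None    # resolved from img_size at build time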
@@ -498,7 +515,7 @@ model_cfgs = dict(
         block_type=('M',) * 4,
         stem_width=(32, 64),
         weight_init='normal',
-        **_next_cfg(window_size=8),
+        **_next_cfg(),
     ),
 
     # Trying to be like the MaxViT paper configs
@@ -1437,6 +1454,15 @@ class Stem(nn.Module):
         return x
 
 
+def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
+    if cfg.window_size is not None:
+        assert cfg.grid_size
+        return cfg
+    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
+    return cfg
+
+
 class MaxxVit(nn.Module):
     """ CoaTNet + MaxVit base model.
 
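`cfg_window_size` is what makes window/grid size dynamic with respect to `img_size`: when the cfg leaves them as `None`, both are set to `img_size // partition_stride` via `dataclasses.replace`. A quick worked example with the new defaults:

    # With partition_stride=32 (the new default), a 256x256 input yields 8x8
    # windows/grids -- matching the old hard-coded window_size=8 for the
    # *_256 models above.
    img_size = (256, 256)
    partition_stride = 32
    partition_size = img_size[0] // partition_stride, img_size[1] // partition_stride
    assert partition_size == (8, 8)

    # A 224x224 input gets 7x7 partitions instead, with no cfg change needed.
    assert (224 // partition_stride, 224 // partition_stride) == (7, 7)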
@@ -1455,6 +1481,7 @@ class MaxxVit(nn.Module):
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = cfg.embed_dim[-1]
@@ -1488,7 +1515,7 @@ class MaxxVit(nn.Module):
                 depth=cfg.depths[i],
                 block_types=cfg.block_type[i],
                 conv_cfg=cfg.conv_cfg,
-                transformer_cfg=cfg.transformer_cfg,
+                transformer_cfg=transformer_cfg,
                 feat_size=feat_size,
                 drop_path=dpr[i],
             )]
@@ -1671,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
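With the entry point registered and weights listed in the default cfg above, the new model should be creatable through the usual timm factory. A usage sketch (the `img_size` override assumes the factory forwards it to the `MaxxVit` constructor, and the feature dim of 512 comes from `embed_dim[-1]` in the cfg):

    import timm
    import torch

    # Pretrained at 256x256: windows/grids resolve to 256 // 32 = 8.
    model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True)
    out = model(torch.randn(1, 3, 256, 256))
    print(out.shape)  # torch.Size([1, 1000])

    # Because window/grid size now follows img_size, the same cfg builds
    # cleanly at other resolutions divisible by partition_stride,
    # e.g. 224 -> 7x7 partitions.
    model_224 = timm.create_model('maxvit_rmlp_nano_rw_256', img_size=224, num_classes=0)
    feats = model_224(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # pooled features, e.g. torch.Size([1, 512])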