Add maxvit_rmlp_nano_rw_256 model def & weights, make window/grid size dynamic wrt img_size by default

commit 7f1b223c02 (parent e6a4361306)
@@ -39,7 +39,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 import math
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
 
@@ -112,6 +112,9 @@ default_cfgs = {
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -136,14 +139,23 @@ class MaxxVitTransformerCfg:
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    window_size: Tuple[int, int] = (7, 7)
-    grid_size: Tuple[int, int] = (7, 7)
+    partition_stride: int = 32
+    window_size: Optional[Tuple[int, int]] = None
+    grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
     norm_layer_cl: str = 'layernorm'
     norm_eps: float = 1e-6
 
+    def __post_init__(self):
+        if self.grid_size is not None:
+            self.grid_size = to_2tuple(self.grid_size)
+        if self.window_size is not None:
+            self.window_size = to_2tuple(self.window_size)
+            if self.grid_size is None:
+                self.grid_size = self.window_size
+
 
 @dataclass
 class MaxxVitConvCfg:
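In short: window_size and grid_size no longer default to a hard-coded (7, 7); both default to None, a new partition_stride field (default 32) carries what is needed to derive them from the input resolution, and __post_init__ normalizes any explicitly passed value. A minimal standalone sketch of that normalization, using a hypothetical trimmed-down stand-in for MaxxVitTransformerCfg and a simplified to_2tuple (not the real timm definitions):

# Sketch only: a trimmed stand-in for MaxxVitTransformerCfg showing the new
# Optional window/grid handling; not the full timm class.
from dataclasses import dataclass
from typing import Optional, Tuple, Union


def to_2tuple(x):
    # Simplified version of timm's to_2tuple helper: broadcast a scalar to a pair.
    return tuple(x) if isinstance(x, (tuple, list)) else (x, x)


@dataclass
class TransformerCfgSketch:
    partition_stride: int = 32
    window_size: Optional[Union[int, Tuple[int, int]]] = None
    grid_size: Optional[Union[int, Tuple[int, int]]] = None

    def __post_init__(self):
        # Same normalization as the diff: scalars become 2-tuples, and an
        # explicit window_size with no grid_size mirrors window -> grid.
        if self.grid_size is not None:
            self.grid_size = to_2tuple(self.grid_size)
        if self.window_size is not None:
            self.window_size = to_2tuple(self.window_size)
            if self.grid_size is None:
                self.grid_size = self.window_size


print(TransformerCfgSketch(window_size=8))  # window_size=(8, 8), grid_size=(8, 8)
print(TransformerCfgSketch())               # both None: resolved later from img_size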
@@ -249,7 +261,7 @@ def _rw_max_cfg(
         conv_norm_layer='',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         dim_head=32,
         rel_pos_type='bias',
         rel_pos_dim=512,
@@ -274,8 +286,7 @@ def _rw_max_cfg(
             expand_first=False,
             pool_type=pool_type,
             dim_head=dim_head,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -291,7 +302,7 @@ def _next_cfg(
         conv_norm_layer_cl='layernorm',
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
-        window_size=7,
+        window_size=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -308,8 +319,7 @@ def _next_cfg(
         transformer_cfg=MaxxVitTransformerCfg(
             expand_first=False,
             pool_type=pool_type,
-            window_size=to_2tuple(window_size),
-            grid_size=to_2tuple(window_size),
+            window_size=window_size,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -462,14 +472,14 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxvit_tiny_rw_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -483,14 +493,21 @@ model_cfgs = dict(
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
+    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(1, 2, 3, 1),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
+    ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('PM',) * 4,
         stem_width=(32, 64),
-        **_rw_max_cfg(window_size=8),
+        **_rw_max_cfg(),
     ),
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
@@ -498,7 +515,7 @@ model_cfgs = dict(
         block_type=('M',) * 4,
         stem_width=(32, 64),
         weight_init='normal',
-        **_next_cfg(window_size=8),
+        **_next_cfg(),
     ),
 
     # Trying to be like the MaxViT paper configs
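Dropping the explicit window_size=8 from these 256-input configs is intended to be behavior-preserving: with the new partition_stride=32 default, a 256x256 model derives 8x8 window/grid partitions at construction time (256 // 32 = 8), and the 224-input configs likewise resolve to 7x7 (224 // 32 = 7); see the cfg_window_size hunk and sketch below.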
@@ -1437,6 +1454,15 @@ class Stem(nn.Module):
         return x
 
 
+def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
+    if cfg.window_size is not None:
+        assert cfg.grid_size
+        return cfg
+    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
+    return cfg
+
+
 class MaxxVit(nn.Module):
     """ CoaTNet + MaxVit base model.
 
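cfg_window_size is where the dynamic default gets resolved: explicit window/grid sizes pass through untouched, otherwise each image dimension is integer-divided by partition_stride, and dataclasses.replace builds a new cfg (re-running __post_init__) instead of mutating the shared one. A sketch of the same logic, assuming the TransformerCfgSketch class from the earlier example:

# Sketch of the derivation done by cfg_window_size; reuses TransformerCfgSketch
# and to_2tuple from the example above.
from dataclasses import replace
from typing import Tuple


def cfg_window_size_sketch(cfg: TransformerCfgSketch, img_size: Tuple[int, int]):
    if cfg.window_size is not None:
        assert cfg.grid_size
        return cfg  # explicit sizes win; nothing to derive
    # Dynamic path: one window/grid cell per partition_stride pixels.
    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
    # replace() returns a fresh cfg and re-runs __post_init__, leaving the
    # original (possibly shared) cfg untouched.
    return replace(cfg, window_size=partition_size, grid_size=partition_size)


print(cfg_window_size_sketch(TransformerCfgSketch(), (256, 256)).window_size)  # (8, 8)
print(cfg_window_size_sketch(TransformerCfgSketch(), (224, 224)).window_size)  # (7, 7)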
@@ -1455,6 +1481,7 @@ class MaxxVit(nn.Module):
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = cfg.embed_dim[-1]
@@ -1488,7 +1515,7 @@ class MaxxVit(nn.Module):
                 depth=cfg.depths[i],
                 block_types=cfg.block_type[i],
                 conv_cfg=cfg.conv_cfg,
-                transformer_cfg=cfg.transformer_cfg,
+                transformer_cfg=transformer_cfg,
                 feat_size=feat_size,
                 drop_path=dpr[i],
             )]
@@ -1671,6 +1698,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
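With the entry point registered and the weight URL in default_cfgs, the new model should be creatable through the usual timm factory; a usage sketch (the printed shape assumes the default 1000-class head):

import torch
import timm

# pretrained=True pulls the maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth weights
# from the release URL registered in default_cfgs above.
model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000])

Because partitions are now derived from img_size, constructing the same config at a different resolution (e.g. passing img_size=224 through create_model) should resolve matching window/grid sizes automatically rather than requiring a new hard-coded config.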