# pytorch-image-models/timm/models/maxxvit.py (2316 lines, 82 KiB, Python)

""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch
This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.
99% of the implementation was done from papers, however last minute some adjustments were made
based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
There are multiple sets of models defined for both architectures. Typically, names with a
`_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit.
These configs work well and appear to be a bit faster / lower resource than the paper.
The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc.) are intended to
match the paper, BUT without any official pretrained weights it's difficult to confirm a 100% match.
Papers:
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@article{tu2022maxvit,
title={MaxViT: Multi-Axis Vision Transformer},
author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
journal={ECCV},
year={2022},
}
CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
@article{DBLP:journals/corr/abs-2106-04803,
author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
journal = {CoRR},
volume = {abs/2106.04803},
year = {2021}
}
Hacked together by / Copyright 2022, Ross Wightman
"""
import math
from collections import OrderedDict
# (revision timestamp 2023-02-04 03:28:57 +08:00 -- residue of the web history view, not code)
from dataclasses import dataclass, replace, field
from functools import partial
from typing import Callable, Optional, Union, Tuple, List
import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
@dataclass
class MaxxVitTransformerCfg:
    """Settings shared by the transformer / partition-attention blocks in this file.

    After construction ``window_size`` and ``grid_size`` are normalised with ``to_2tuple``
    (when given), and ``grid_size`` falls back to ``window_size`` if it was left unset.
    """
    dim_head: int = 32
    head_first: bool = True  # head ordering in qkv channel dim
    expand_ratio: float = 4.0
    expand_first: bool = True
    shortcut_bias: bool = True
    attn_bias: bool = True
    attn_drop: float = 0.
    proj_drop: float = 0.
    pool_type: str = 'avg2'
    rel_pos_type: str = 'bias'
    rel_pos_dim: int = 512  # for relative position types w/ MLP
    partition_ratio: int = 32
    window_size: Optional[Tuple[int, int]] = None
    grid_size: Optional[Tuple[int, int]] = None
    no_block_attn: bool = False  # disable window block attention for maxvit (ie only grid)
    use_nchw_attn: bool = False  # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
    init_values: Optional[float] = None
    act_layer: str = 'gelu'
    norm_layer: str = 'layernorm2d'
    norm_layer_cl: str = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        grid, window = self.grid_size, self.window_size
        if grid is not None:
            grid = to_2tuple(grid)
        if window is not None:
            window = to_2tuple(window)
        # an unset grid size inherits the (already normalised) window size
        if grid is None:
            grid = window
        self.grid_size = grid
        self.window_size = window
@dataclass
class MaxxVitConvCfg:
    """Settings for the convolutional blocks (``'mbconv'`` or ``'convnext'``).

    Fields left at their "empty" value (``norm_layer``, ``norm_layer_cl``, ``norm_eps``,
    ``downsample_pool_type``) are filled in by ``__post_init__`` according to ``block_type``.
    """
    block_type: str = 'mbconv'
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'
    padding: str = ''
    attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
    attn_layer: str = 'se'
    attn_act_layer: str = 'silu'
    attn_ratio: float = 0.25
    init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
    act_layer: str = 'gelu'
    norm_layer: str = ''
    norm_layer_cl: str = ''
    norm_eps: Optional[float] = None

    def __post_init__(self):
        # The two block families want different defaults; resolving them here keeps
        # explicit config definitions short.
        assert self.block_type in ('mbconv', 'convnext')
        is_convnext = self.block_type != 'mbconv'
        if not self.norm_layer:
            self.norm_layer = 'layernorm2d' if is_convnext else 'batchnorm2d'
        if is_convnext and not self.norm_layer_cl:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-6 if is_convnext else 1e-5
        if not self.downsample_pool_type:
            self.downsample_pool_type = self.pool_type
@dataclass
class MaxxVitCfg:
    """Top-level model configuration: per-stage layout plus the nested conv / transformer configs.

    ``embed_dim``, ``depths`` and ``block_type`` hold one entry per stage. The nested configs
    use ``default_factory`` so each ``MaxxVitCfg`` instance owns its own copies.
    """
    embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
    depths: Tuple[int, ...] = (2, 3, 5, 2)
    # per-stage block type(s); MaxxVitStage accepts 'C', 'T', 'M' and 'PM'
    block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
    stem_width: Union[int, Tuple[int, int]] = 64
    stem_bias: bool = False
    conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
    transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
    # None is the default, so the annotation must be Optional
    head_hidden_size: Optional[int] = None
    weight_init: str = 'vit_eff'
class Attention2d(nn.Module):
    """ Multi-head self-attention for 2D NCHW tensors.

    Q, K and V are produced by a single 1x1 conv and the result is projected back with
    another 1x1 conv. When ``torch.nn.functional.scaled_dot_product_attention`` is available
    it is used; otherwise attention is computed explicitly.

    Args:
        dim: Input channels.
        dim_out: Output channels (defaults to ``dim``).
        dim_head: Channels per attention head.
        bias: Use bias in the qkv and output projections.
        expand_first: If True the attention width is ``dim_out``, otherwise ``dim``.
        head_first: Selects how the qkv channels are split into heads / q, k, v (see ``forward``).
        rel_pos_cls: Optional callable building a relative-position module; called with ``num_heads``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.

    Raises:
        AssertionError: If the attention width is not divisible by ``dim_head``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Callable = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        # Same check as AttentionCl: a non-divisible width would be floored into num_heads and
        # the qkv tensor could no longer be split into heads correctly in forward().
        assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        """Apply attention over the H*W positions of ``x`` (B, C, H, W).

        ``shared_rel_pos`` is only used as the attention bias when the module has no
        ``rel_pos`` of its own.
        """
        B, C, H, W = x.shape

        # q, k, v each end up as (B, num_heads, dim_head, H*W)
        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fast_attn:
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos
            else:
                attn_bias = None
            # The functional op applies dropout whenever dropout_p > 0, regardless of module
            # mode, so the probability has to be zeroed explicitly outside of training.
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2),
                k.transpose(-1, -2),
                v.transpose(-1, -2),
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if self.rel_pos is not None:
                attn = self.rel_pos(attn)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class AttentionCl(nn.Module):
    """ Channels-last multi-head attention (B, ..., C).

    Q, K and V come from one linear layer and the result is projected with another. All
    dimensions between the batch and channel dims are flattened into the token axis for
    attention and restored on output. When ``torch.nn.functional.scaled_dot_product_attention``
    is available it is used; otherwise attention is computed explicitly.

    Args:
        dim: Input channels.
        dim_out: Output channels (defaults to ``dim``).
        dim_head: Channels per attention head.
        bias: Use bias in the qkv and output projections.
        expand_first: If True and ``dim_out > dim`` the attention width is ``dim_out``, else ``dim``.
        head_first: Selects how the qkv channels are split into heads / q, k, v (see ``forward``).
        rel_pos_cls: Optional callable building a relative-position module; called with ``num_heads``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.

    Raises:
        AssertionError: If the attention width is not divisible by ``dim_head``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            dim_head: int = 32,
            bias: bool = True,
            expand_first: bool = True,
            head_first: bool = True,
            rel_pos_cls: Callable = None,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first and dim_out > dim else dim
        assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME

        self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_attn, dim_out, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        """Apply attention over all non-batch, non-channel positions of ``x`` (B, ..., C)."""
        B = x.shape[0]
        restore_shape = x.shape[:-1]

        # q, k, v each end up as (B, num_heads, N, dim_head)
        if self.head_first:
            q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3)
        else:
            q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)

        if self.fast_attn:
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos
            else:
                attn_bias = None
            # The functional op applies dropout whenever dropout_p > 0, regardless of module
            # mode, so the probability has to be zeroed explicitly outside of training.
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if self.rel_pos is not None:
                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # merge heads back into the channel dim and restore the leading dims
        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class LayerScale(nn.Module):
    """Learnable per-channel scale, broadcast against the last dim of the input.

    ``gamma`` starts at ``init_values`` for every channel. With ``inplace=True`` the input
    tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class LayerScale2d(nn.Module):
    """Learnable per-channel scale for NCHW tensors.

    ``gamma`` is reshaped to (1, C, 1, 1) so it broadcasts over dim 1. With ``inplace=True``
    the input tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        scale = self.gamma.view(1, -1, 1, 1)
        if self.inplace:
            return x.mul_(scale)
        return x * scale
class Downsample2d(nn.Module):
    """ Pooling-based spatial downsample with an optional 1x1 channel projection.

    Supported ``pool_type`` values:
      * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
      * 'max2' - MaxPool2d w/ kernel_size = stride = 2
      * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
      * 'avg2' - AvgPool2d w/ kernel_size = stride = 2

    A non-empty ``padding`` overrides the per-mode padding above. The 1x1 conv is only
    created when ``dim != dim_out``.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            pool_type: str = 'avg2',
            padding: str = '',
            bias: bool = True,
    ):
        super().__init__()
        assert pool_type in ('max', 'max2', 'avg', 'avg2')
        if pool_type == 'max':
            pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1)
        elif pool_type == 'max2':
            pool = create_pool2d('max', 2, padding=padding or 0)  # kernel_size == stride == 2
        elif pool_type == 'avg':
            pool = create_pool2d(
                'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1)
        else:
            pool = create_pool2d('avg', 2, padding=padding or 0)
        self.pool = pool

        needs_proj = dim != dim_out
        self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) if needs_proj else nn.Identity()

    def forward(self, x):
        # spatial downsample first, then (optionally) change the channel count
        return self.expand(self.pool(x))
def _init_transformer(module, name, scheme=''):
if isinstance(module, (nn.Conv2d, nn.Linear)):
if scheme == 'normal':
nn.init.normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == 'trunc_normal':
trunc_normal_tf_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == 'xavier_normal':
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
else:
# vit like
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
class TransformerBlock2d(nn.Module):
    """ Transformer block (attention + MLP) with optional 2D downsampling, NCHW layout.

    With ``stride == 2`` both the shortcut and the normalised attention input are pooled;
    otherwise ``dim`` must equal ``dim_out``.

    Original author notes: some gains can be seen on GPU using a 1D / CL block, BUT w/ the
    need to switch back/forth to NCHW for spatial pooling, the benefit is minimal so ended up
    using just this variant for CoAt configs. This impl was faster on TPU w/ PT XLA than the
    1D experiment.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            rel_pos_cls: Callable = None,
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        def make_layer_scale():
            return LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if stride == 2:
            self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias)
            # normalise at full resolution, then pool before attention
            self.norm1 = nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)),
            ]))
        else:
            assert dim == dim_out
            self.shortcut = nn.Identity()
            self.norm1 = norm_layer(dim)

        self.attn = Attention2d(
            dim,
            dim_out,
            dim_head=cfg.dim_head,
            expand_first=cfg.expand_first,
            bias=cfg.attn_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop
        )
        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        self.norm2 = norm_layer(dim_out)
        self.mlp = ConvMlp(
            in_features=dim_out,
            hidden_features=int(dim_out * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

    def init_weights(self, scheme=''):
        """Apply ``_init_transformer`` with the given scheme to every submodule."""
        named_apply(partial(_init_transformer, scheme=scheme), self)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        residual = self.shortcut(x)
        attn_out = self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)
        x = residual + self.drop_path1(self.ls1(attn_out))

        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
def _init_conv(module, name, scheme=''):
if isinstance(module, nn.Conv2d):
if scheme == 'normal':
nn.init.normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == 'trunc_normal':
trunc_normal_tf_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == 'xavier_normal':
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
else:
# efficientnet like
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
fan_out //= module.groups
nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
if module.bias is not None:
nn.init.zeros_(module.bias)
def num_groups(group_size, channels):
    """Return the conv ``groups`` count for ``channels`` channels split into ``group_size``-wide groups.

    A falsy ``group_size`` (0 or None) means a normal conv with a single group;
    ``group_size == 1`` yields a depthwise conv (one group per channel).

    Raises:
        AssertionError: If ``channels`` is not divisible by ``group_size``.
    """
    if group_size:
        assert channels % group_size == 0
        return channels // group_size
    # 0 or None -> normal conv with 1 group
    return 1
class MbConvBlock(nn.Module):
    """ Pre-norm inverted-bottleneck conv block: 1x1 expand -> kxk (grouped) -> 1x1 project.

    The stride is applied in exactly one place chosen by ``cfg.stride_mode``: a pooling layer
    before the first conv ('pool'), the first 1x1 conv ('1x1'), or the kxk conv ('dw').
    Channel attention is applied either before or after the second norm (``cfg.attn_early``).
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: float = 0.
    ):
        super(MbConvBlock, self).__init__()
        norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
        expand_base = out_chs if cfg.expand_output else in_chs
        mid_chs = make_divisible(expand_base * cfg.expand_ratio)
        groups = num_groups(cfg.group_size, mid_chs)

        if stride == 2:
            self.shortcut = Downsample2d(
                in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding)
        else:
            self.shortcut = nn.Identity()

        mode = cfg.stride_mode
        assert mode in ('pool', '1x1', 'dw')
        # 'pool': not described in paper, experiment to find faster option that doesn't stride in 1x1
        #         (FIXME handle dilation of avg pool)
        # '1x1':  option described in paper; a strided 1x1 throws info away
        # 'dw':   stride in the kxk conv
        stride_pool = stride if mode == 'pool' else 1
        stride_1 = stride if mode == '1x1' else 1
        stride_2 = stride if mode == 'dw' else 1
        dilation_2 = dilation[0] if mode == 'dw' else dilation[1]

        self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act)
        if stride_pool > 1:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding)
        else:
            self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1)
        self.norm1 = norm_act_layer(mid_chs)

        self.conv2_kxk = create_conv2d(
            mid_chs, mid_chs, cfg.kernel_size,
            stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding)

        attn_kwargs = {}
        if isinstance(cfg.attn_layer, str) and cfg.attn_layer in ('se', 'eca'):
            attn_kwargs['act_layer'] = cfg.attn_act_layer
            attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))

        # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2)
        if cfg.attn_early:
            self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)
            self.norm2 = norm_act_layer(mid_chs)
            self.se = None
        else:
            self.se_early = None
            self.norm2 = norm_act_layer(mid_chs)
            self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)

        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        """Apply ``_init_conv`` with the given scheme to every submodule."""
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.down(self.pre_norm(x))

        # 1x1 expansion conv & norm-act
        x = self.norm1(self.conv1_1x1(x))

        # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
        x = self.conv2_kxk(x)
        if self.se_early is not None:
            x = self.se_early(x)
        x = self.norm2(x)
        if self.se is not None:
            x = self.se(x)

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
        return self.drop_path(x) + residual
class ConvNeXtBlock(nn.Module):
    """ ConvNeXt-style block: depthwise conv -> norm -> MLP -> layer-scale, with a residual.

    With ``conv_mlp=True`` everything stays NCHW (conv MLP, 2D norm / layer-scale); otherwise
    the tensor is permuted to channels-last around the norm / MLP / layer-scale. The stride is
    applied either by a pooling layer ('pool') or by the depthwise conv ('dw').
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 7,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            conv_mlp: bool = True,
            drop_path: float = 0.
    ):
        super().__init__()
        out_chs = out_chs or in_chs
        act_layer = get_act_layer(cfg.act_layer)
        if conv_mlp:
            norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
            mlp_layer = ConvMlp
        else:
            assert 'layernorm' in cfg.norm_layer
            norm_layer = LayerNorm
            mlp_layer = Mlp
        self.use_conv_mlp = conv_mlp

        if stride == 2:
            self.shortcut = Downsample2d(in_chs, out_chs)
        elif in_chs != out_chs:
            self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias)
        else:
            self.shortcut = nn.Identity()

        assert cfg.stride_mode in ('pool', 'dw')
        # FIXME handle dilation?
        use_pool = cfg.stride_mode == 'pool'
        stride_pool = stride if use_pool else 1
        stride_dw = 1 if use_pool else stride

        if stride_pool == 2:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type)
        else:
            self.down = nn.Identity()

        self.conv_dw = create_conv2d(
            in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1],
            depthwise=True, bias=cfg.output_bias)
        self.norm = norm_layer(out_chs)
        self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer)
        # layer-scale variant has to match the layout used around the MLP
        ls_cls = LayerScale2d if conv_mlp else LayerScale
        self.ls = ls_cls(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.down(x)
        x = self.conv_dw(x)

        channels_last = not self.use_conv_mlp
        if channels_last:
            x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.mlp(x)
        x = self.ls(x)
        if channels_last:
            x = x.permute(0, 3, 1, 2)

        return self.drop_path(x) + residual
def window_partition(x, window_size: List[int]):
    """Split a (B, H, W, C) tensor into non-overlapping windows.

    Returns a tensor of shape (-1, window_size[0], window_size[1], C); H and W must be
    divisible by the respective window dims.
    """
    B, H, W, C = x.shape
    win_h = window_size[0]
    win_w = window_size[1]
    _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    _assert(W % win_w == 0, '')
    tiled = x.view(B, H // win_h, win_h, W // win_w, win_w, C)
    # bring the two window-index dims together ahead of the in-window dims
    return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_h, win_w, C)
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: List[int], img_size: List[int]):
    """Reassemble windows shaped (-1, window_size[0], window_size[1], C) into (-1, H, W, C).

    ``img_size`` is the target (H, W).
    """
    H, W = img_size
    C = windows.shape[-1]
    win_h = window_size[0]
    win_w = window_size[1]
    tiled = windows.view(-1, H // win_h, W // win_w, win_h, win_w, C)
    # interleave window-index and in-window dims back into image rows / columns
    return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
def grid_partition(x, grid_size: List[int]):
    """Split a (B, H, W, C) tensor into grid partitions.

    Each output partition of shape (grid_size[0], grid_size[1], C) gathers positions spaced
    H // grid_size[0] rows and W // grid_size[1] columns apart. H and W must be divisible by
    the respective grid dims.
    """
    B, H, W, C = x.shape
    grid_h = grid_size[0]
    grid_w = grid_size[1]
    _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
    _assert(W % grid_w == 0, '')
    strided = x.view(B, grid_h, H // grid_h, grid_w, W // grid_w, C)
    # move the within-cell offsets forward so they fold into the batch dim
    return strided.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_h, grid_w, C)
@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse(windows, grid_size: List[int], img_size: List[int]):
    """Reassemble grid partitions shaped (-1, grid_size[0], grid_size[1], C) into (-1, H, W, C).

    ``img_size`` is the target (H, W).
    """
    H, W = img_size
    C = windows.shape[-1]
    grid_h = grid_size[0]
    grid_w = grid_size[1]
    strided = windows.view(-1, H // grid_h, W // grid_w, grid_h, grid_w, C)
    # put each grid index back in front of its within-cell offset
    return strided.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)
def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
    """Return a partially-applied relative-position class for ``cfg.rel_pos_type``.

    'mlp' -> RelPosMlp (with ``hidden_dim=cfg.rel_pos_dim``), 'bias' -> RelPosBias,
    'bias_tf' -> RelPosBiasTf; each is bound to ``window_size``. Any other value yields None.
    """
    kind = cfg.rel_pos_type
    if kind == 'mlp':
        return partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
    if kind == 'bias':
        return partial(RelPosBias, window_size=window_size)
    if kind == 'bias_tf':
        return partial(RelPosBiasTf, window_size=window_size)
    return None
class PartitionAttentionCl(nn.Module):
    """ Block (window) or grid partition attention followed by an MLP, each with a residual.

    Channels-last tensor layout. ``partition_type='block'`` uses ``cfg.window_size``; any
    other value uses ``cfg.grid_size``.
    """

    def __init__(
            self,
            dim: int,
            partition_type: str = 'block',
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        # channels-last block, so the channels-last norm layer is used
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        def make_layer_scale():
            return LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn = AttentionCl(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

    def _partition_attn(self, x):
        """Partition ``x``, attend within each partition, and restore the original layout."""
        img_size = x.shape[1:3]
        if self.partition_block:
            parts = window_partition(x, self.partition_size)
            parts = self.attn(parts)
            return window_reverse(parts, self.partition_size, img_size)
        parts = grid_partition(x, self.partition_size)
        parts = self.attn(parts)
        return grid_reverse(parts, self.partition_size, img_size)

    def forward(self, x):
        attn_out = self._partition_attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class ParallelPartitionAttention(nn.Module):
    """ Experimental. Block and grid partition attention run in parallel, then a single MLP.

    Each attention branch maps ``dim -> dim // 2``; the two outputs are concatenated on the
    channel (last) dim. Channels-last tensor layout. Requires an even ``dim`` and
    ``cfg.window_size == cfg.grid_size``.
    """

    def __init__(
            self,
            dim: int,
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        assert dim % 2 == 0
        # channels-last block, so the channels-last norm layer is used
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        def make_layer_scale():
            return LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        assert cfg.window_size == cfg.grid_size
        self.partition_size = to_2tuple(cfg.window_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        # both branches are configured identically
        branch_kwargs = dict(
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.norm1 = norm_layer(dim)
        self.attn_block = AttentionCl(dim, dim // 2, **branch_kwargs)
        self.attn_grid = AttentionCl(dim, dim // 2, **branch_kwargs)
        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            out_features=dim,
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

    def _partition_attn(self, x):
        """Run window and grid attention on the same input and concatenate the results."""
        img_size = x.shape[1:3]

        block_parts = self.attn_block(window_partition(x, self.partition_size))
        x_window = window_reverse(block_parts, self.partition_size, img_size)

        grid_parts = self.attn_grid(grid_partition(x, self.partition_size))
        x_grid = grid_reverse(grid_parts, self.partition_size, img_size)

        return torch.cat([x_window, x_grid], dim=-1)

    def forward(self, x):
        attn_out = self._partition_attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
def window_partition_nchw(x, window_size: List[int]):
    """Split a (B, C, H, W) tensor into non-overlapping windows.

    Returns a tensor of shape (-1, C, window_size[0], window_size[1]); H and W must be
    divisible by the respective window dims.
    """
    B, C, H, W = x.shape
    win_h = window_size[0]
    win_w = window_size[1]
    _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    _assert(W % win_w == 0, '')
    tiled = x.view(B, C, H // win_h, win_h, W // win_w, win_w)
    # window indices move ahead of the channel dim so they fold into the batch dim
    return tiled.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, win_h, win_w)
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]):
    """Reassemble windows shaped (-1, C, window_size[0], window_size[1]) into (-1, C, H, W).

    ``img_size`` is the target (H, W).
    """
    H, W = img_size
    C = windows.shape[1]
    win_h = window_size[0]
    win_w = window_size[1]
    tiled = windows.view(-1, H // win_h, W // win_w, C, win_h, win_w)
    # channel dim returns to position 1; window index / in-window dims are interleaved
    return tiled.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
def grid_partition_nchw(x, grid_size: List[int]):
    """Split a (B, C, H, W) tensor into grid partitions.

    Returns a tensor of shape (-1, C, grid_size[0], grid_size[1]); H and W must be divisible
    by the respective grid dims.
    """
    B, C, H, W = x.shape
    grid_h = grid_size[0]
    grid_w = grid_size[1]
    _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
    _assert(W % grid_w == 0, '')
    strided = x.view(B, C, grid_h, H // grid_h, grid_w, W // grid_w)
    # within-cell offsets move ahead of the channel dim so they fold into the batch dim
    return strided.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_h, grid_w)
@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
    """Reassemble grid partitions shaped (-1, C, grid_size[0], grid_size[1]) into (-1, C, H, W).

    ``img_size`` is the target (H, W).
    """
    H, W = img_size
    C = windows.shape[1]
    grid_h = grid_size[0]
    grid_w = grid_size[1]
    strided = windows.view(-1, H // grid_h, W // grid_w, C, grid_h, grid_w)
    # channel dim returns to position 1; each grid index goes in front of its within-cell offset
    return strided.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W)
class PartitionAttention2d(nn.Module):
    """ Block (window) or grid partition attention followed by a conv MLP, each with a residual.

    '2D' NCHW tensor layout. ``partition_type='block'`` uses ``cfg.window_size``; any other
    value uses ``cfg.grid_size``.
    """

    def __init__(
            self,
            dim: int,
            partition_type: str = 'block',
            cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        # NCHW block, so the 2D norm layer (cfg.norm_layer) is used
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        def make_layer_scale():
            return LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention2d(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        self.norm2 = norm_layer(dim)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

    def _partition_attn(self, x):
        """Partition ``x``, attend within each partition, and restore the original layout."""
        img_size = x.shape[-2:]
        if self.partition_block:
            parts = window_partition_nchw(x, self.partition_size)
            parts = self.attn(parts)
            return window_reverse_nchw(parts, self.partition_size, img_size)
        parts = grid_partition_nchw(x, self.partition_size)
        parts = self.attn(parts)
        return grid_reverse_nchw(parts, self.partition_size, img_size)

    def forward(self, x):
        attn_out = self._partition_attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class MaxxVitBlock(nn.Module):
    """ MaxVit block: conv block, then window-partition attention, then grid-partition attention.

    The attention sub-blocks run in NCHW or channels-last layout depending on
    ``transformer_cfg.use_nchw_attn``; the window stage is skipped entirely when
    ``transformer_cfg.no_block_attn`` is set.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            stride: int = 1,
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path: float = 0.,
    ):
        super().__init__()
        self.nchw_attn = transformer_cfg.use_nchw_attn

        conv_cls = MbConvBlock if conv_cfg.block_type != 'convnext' else ConvNeXtBlock
        self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)

        shared_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
        partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
        if transformer_cfg.no_block_attn:
            self.attn_block = None
        else:
            self.attn_block = partition_layer(**shared_kwargs)
        self.attn_grid = partition_layer(partition_type='grid', **shared_kwargs)

    def init_weights(self, scheme=''):
        """Initialise attention sub-blocks with ``_init_transformer`` and the conv with ``_init_conv``."""
        init_attn = partial(_init_transformer, scheme=scheme)
        if self.attn_block is not None:
            named_apply(init_attn, self.attn_block)
        named_apply(init_attn, self.attn_grid)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        # input / output are NCHW
        x = self.conv(x)

        to_channels_last = not self.nchw_attn
        if to_channels_last:
            x = x.permute(0, 2, 3, 1)  # to NHWC (channels-last)
        if self.attn_block is not None:
            x = self.attn_block(x)
        x = self.attn_grid(x)
        if to_channels_last:
            x = x.permute(0, 3, 1, 2)  # back to NCHW
        return x
class ParallelMaxxVitBlock(nn.Module):
    """ MaxVit block with parallel cat(window + grid) attention and one feed-forward.

    Experimental timm block. ``num_conv`` conv blocks are applied before attention; only the
    first one receives ``stride`` and changes the channel count from ``dim`` to ``dim_out``.
    """

    def __init__(
            self,
            dim,
            dim_out,
            stride=1,
            num_conv=2,
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            drop_path=0.,
    ):
        super().__init__()

        conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
        if num_conv > 1:
            convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)]
            # Build a separate module per extra conv. Multiplying a one-element list would
            # repeat the same instance and silently tie the weights when num_conv > 2.
            convs += [
                conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)
                for _ in range(num_conv - 1)
            ]
            self.conv = nn.Sequential(*convs)
        else:
            self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

    def init_weights(self, scheme=''):
        """Initialise the attention with ``_init_transformer`` and the conv(s) with ``_init_conv``."""
        named_apply(partial(_init_transformer, scheme=scheme), self.attn)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)  # NCHW -> channels-last for the attention block
        x = self.attn(x)
        x = x.permute(0, 3, 1, 2)  # back to NCHW
        return x
class MaxxVitStage(nn.Module):
    """ One stage: a sequence of ``depth`` blocks, the first of which applies ``stride``.

    Args:
        in_chs: Input channels of the stage.
        out_chs: Output channels; every block after the first maps ``out_chs -> out_chs``.
        stride: Stride of the first block (later blocks use stride 1).
        depth: Number of blocks.
        feat_size: Feature-map size passed to ``get_rel_pos_cls`` for 'T' blocks.
        block_types: Block type per block, or a single type extended to ``depth``. One of
            'C' (conv), 'T' (TransformerBlock2d), 'M' (MaxxVitBlock), 'PM' (ParallelMaxxVitBlock).
        transformer_cfg: Config for the attention blocks.
        conv_cfg: Config for the conv blocks.
        drop_path: Per-block drop-path rates, or a single rate applied to every block.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 2,
            depth: int = 4,
            feat_size: Tuple[int, int] = (14, 14),
            block_types: Union[str, Tuple[str]] = 'C',
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: Union[float, List[float]] = 0.,
    ):
        super().__init__()
        self.grad_checkpointing = False

        block_types = extend_tuple(block_types, depth)
        if isinstance(drop_path, (int, float)):
            # The signature allows (and defaults to) a scalar, but the loop below indexes
            # per block -- broadcast the scalar so the default no longer raises TypeError.
            drop_path = [drop_path] * len(block_types)

        blocks = []
        for i, t in enumerate(block_types):
            block_stride = stride if i == 0 else 1
            assert t in ('C', 'T', 'M', 'PM')
            if t == 'C':
                conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                blocks += [conv_cls(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    cfg=conv_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'T':
                rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size)
                blocks += [TransformerBlock2d(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    rel_pos_cls=rel_pos_cls,
                    cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'M':
                blocks += [MaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'PM':
                blocks += [ParallelMaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            # only the first block changes the channel count
            in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
class Stem(nn.Module):
    """Two-conv stem: stride-2 conv -> norm+act -> stride-1 conv."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            padding: str = '',
            bias: bool = False,
            act_layer: str = 'gelu',
            norm_layer: str = 'batchnorm2d',
            norm_eps: float = 1e-5,
    ):
        """
        Args:
            in_chs: Input channels of the first conv.
            out_chs: Output widths; a non-list/tuple value is expanded with `to_2tuple`.
            kernel_size: Kernel size of both convs.
            padding: Padding spec passed to `create_conv2d`.
            bias: Whether the convs carry a bias.
            act_layer: Activation name used for the fused norm+act layer.
            norm_layer: Norm name used for the fused norm+act layer.
            norm_eps: Epsilon of the norm layer.
        """
        super().__init__()
        widths = out_chs if isinstance(out_chs, (list, tuple)) else to_2tuple(out_chs)
        make_norm_act = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
        self.out_chs = widths[-1]
        self.stride = 2

        conv_kwargs = dict(padding=padding, bias=bias)
        self.conv1 = create_conv2d(in_chs, widths[0], kernel_size, stride=2, **conv_kwargs)
        self.norm1 = make_norm_act(widths[0])
        self.conv2 = create_conv2d(widths[0], widths[1], kernel_size, stride=1, **conv_kwargs)

    def init_weights(self, scheme=''):
        """Apply the `scheme` conv initialiser to this module tree."""
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        """conv1 -> norm1 -> conv2."""
        return self.conv2(self.norm1(self.conv1(x)))
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
    """Return a transformer cfg whose window and grid sizes are resolved.

    A cfg that already has `window_size` set is returned unchanged (its
    `grid_size` must be truthy). Otherwise both sizes are set to
    `img_size // cfg.partition_ratio` per dimension on a copy of the cfg.
    """
    if cfg.window_size is None:
        ratio = cfg.partition_ratio
        derived = (img_size[0] // ratio, img_size[1] // ratio)
        return replace(cfg, window_size=derived, grid_size=derived)
    assert cfg.grid_size
    return cfg
2023-01-08 08:29:12 +08:00
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
transformer_kwargs = {}
conv_kwargs = {}
base_kwargs = {}
for k, v in kwargs.items():
if k.startswith('transformer_'):
transformer_kwargs[k.replace('transformer_', '')] = v
elif k.startswith('conv_'):
conv_kwargs[k.replace('conv_', '')] = v
else:
base_kwargs[k] = v
cfg = replace(
cfg,
transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
**base_kwargs
)
return cfg
class MaxxVit(nn.Module):
    """ CoaTNet + MaxVit base model.

    Highly configurable for different block compositions, tensor layouts, pooling types.
    """

    def __init__(
            self,
            cfg: MaxxVitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            **kwargs,
    ):
        """
        Args:
            cfg: Model architecture config.
            img_size: Input image size; expanded with `to_2tuple` and used to derive the
                attention window / grid size and the per-stage feature size.
            in_chans: Number of input image channels (stem input).
            num_classes: Number of classifier outputs.
            global_pool: Pooling type passed to the classifier head.
            drop_rate: Dropout rate passed to the classifier head.
            drop_path_rate: Final drop-path rate; per-block rates ramp linearly from 0 to this value.
            **kwargs: Extra config overrides, applied to `cfg` through `_overlay_kwargs`.
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        if kwargs:
            cfg = _overlay_kwargs(cfg, **kwargs)
        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = cfg.embed_dim[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
            padding=cfg.conv_cfg.padding,
            bias=cfg.stem_bias,
            act_layer=cfg.conv_cfg.act_layer,
            norm_layer=cfg.conv_cfg.norm_layer,
            norm_eps=cfg.conv_cfg.norm_eps,
        )
        stride = self.stem.stride
        self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
        feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])

        num_stages = len(cfg.embed_dim)
        assert len(cfg.depths) == num_stages
        # one linearly increasing drop-path rate per block, grouped by stage
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        in_chs = self.stem.out_chs
        stages = []
        for i in range(num_stages):
            stage_stride = 2
            out_chs = cfg.embed_dim[i]
            # ceil-divide the feature size by the stage stride
            feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size])
            stages += [MaxxVitStage(
                in_chs,
                out_chs,
                depth=cfg.depths[i],
                block_types=cfg.block_type[i],
                conv_cfg=cfg.conv_cfg,
                transformer_cfg=transformer_cfg,
                feat_size=feat_size,
                drop_path=dpr[i],
            )]
            stride *= stage_stride
            in_chs = out_chs
            self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
        self.head_hidden_size = cfg.head_hidden_size
        if self.head_hidden_size:
            # the final norm is handed to the MLP head instead of being applied in forward_features
            self.norm = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=self.head_hidden_size,
                pool_type=global_pool,
                drop_rate=drop_rate,
                norm_layer=final_norm_layer,
            )
        else:
            # standard classifier head w/ norm, pooling, fc classifier
            self.norm = final_norm_layer(self.num_features)
            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        # Weight init (default PyTorch init works well for AdamW if scheme not set)
        assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
        if cfg.weight_init:
            named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)

    def _init_weights(self, module, name, scheme=''):
        """Call `module.init_weights` when the module defines it; `name` is unused."""
        if hasattr(module, 'init_weights'):
            try:
                module.init_weights(scheme=scheme)
            except TypeError:
                # retried without arguments, e.g. for an init_weights() that takes no `scheme`
                module.init_weights()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters whose name contains a relative-position table / MLP marker."""
        return {
            k for k, _ in self.named_parameters()
            if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex spec grouping parameters into stem, per-stage blocks and the final norm."""
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Set the `grad_checkpointing` flag on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the head's `fc` layer."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Record the new class count and reset the head with it (and the optional pooling type)."""
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        """stem -> stages -> norm."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Apply the classifier head, forwarding `pre_logits` to it."""
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        """forward_features followed by forward_head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def _rw_coat_cfg(
        stride_mode='pool',
        pool_type='avg2',
        conv_output_bias=False,
        conv_attn_early=False,
        conv_attn_act_layer='relu',
        conv_norm_layer='',
        transformer_shortcut_bias=True,
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        init_values=None,
        rel_pos_type='bias',
        rel_pos_dim=512,
):
    """Return `conv_cfg` / `transformer_cfg` kwargs for the 'RW' timm CoAtNet variants.

    'RW' timm variant models were created and trained before seeing
    https://github.com/google-research/maxvit

    Common differences for initial timm models:
      - pre-norm layer in MBConv included an activation after norm
      - mbconv expansion calculated from input instead of output chs
      - mbconv shortcut and final 1x1 conv did not have a bias
      - SE act layer was relu, not silu
      - mbconv uses silu in timm, not gelu
      - expansion in attention block done via output proj, not input proj

    Variable differences (evolved over training initial models):
      - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
      - SE attention was between conv2 and norm/act
      - default to avg pool for mbconv downsample instead of 1x1 or dw conv
      - transformer block shortcut has no bias
    """
    conv = MaxxVitConvCfg(
        stride_mode=stride_mode,
        pool_type=pool_type,
        pre_norm_act=True,
        expand_output=False,
        output_bias=conv_output_bias,
        attn_early=conv_attn_early,
        attn_act_layer=conv_attn_act_layer,
        act_layer='silu',
        norm_layer=conv_norm_layer,
    )
    transformer = MaxxVitTransformerCfg(
        expand_first=False,
        shortcut_bias=transformer_shortcut_bias,
        pool_type=pool_type,
        init_values=init_values,
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return {'conv_cfg': conv, 'transformer_cfg': transformer}
def _rw_max_cfg(
        stride_mode='dw',
        pool_type='avg2',
        conv_output_bias=False,
        conv_attn_ratio=1 / 16,
        conv_norm_layer='',
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        window_size=None,
        dim_head=32,
        init_values=None,
        rel_pos_type='bias',
        rel_pos_dim=512,
):
    """Return `conv_cfg` / `transformer_cfg` kwargs for the 'RW' timm MaxVit variants.

    'RW' timm variant models were created and trained before seeing
    https://github.com/google-research/maxvit

    Differences of initial timm models:
      - mbconv expansion calculated from input instead of output chs
      - mbconv shortcut and final 1x1 conv did not have a bias
      - mbconv uses silu in timm, not gelu
      - expansion in attention block done via output proj, not input proj
    """
    conv = MaxxVitConvCfg(
        stride_mode=stride_mode,
        pool_type=pool_type,
        expand_output=False,
        output_bias=conv_output_bias,
        attn_ratio=conv_attn_ratio,
        act_layer='silu',
        norm_layer=conv_norm_layer,
    )
    transformer = MaxxVitTransformerCfg(
        expand_first=False,
        pool_type=pool_type,
        dim_head=dim_head,
        window_size=window_size,
        init_values=init_values,
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return {'conv_cfg': conv, 'transformer_cfg': transformer}
def _next_cfg(
        stride_mode='dw',
        pool_type='avg2',
        conv_norm_layer='layernorm2d',
        conv_norm_layer_cl='layernorm',
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        window_size=None,
        no_block_attn=False,
        init_values=1e-6,
        rel_pos_type='mlp',  # MLP by default for maxxvit
        rel_pos_dim=512,
):
    """Return `conv_cfg` / `transformer_cfg` kwargs for experimental models using convnext instead of mbconv.

    `init_values` is expanded with `to_2tuple`; element 0 goes to the conv cfg
    and element 1 to the transformer cfg.
    """
    init_pair = to_2tuple(init_values)
    conv = MaxxVitConvCfg(
        block_type='convnext',
        stride_mode=stride_mode,
        pool_type=pool_type,
        expand_output=False,
        init_values=init_pair[0],
        norm_layer=conv_norm_layer,
        norm_layer_cl=conv_norm_layer_cl,
    )
    transformer = MaxxVitTransformerCfg(
        expand_first=False,
        pool_type=pool_type,
        window_size=window_size,
        no_block_attn=no_block_attn,  # enabled for MaxxViT-V2
        init_values=init_pair[1],
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return {'conv_cfg': conv, 'transformer_cfg': transformer}
def _tf_cfg():
    """Return the `conv_cfg` / `transformer_cfg` kwargs used by the `*_tf` model configs."""
    conv = MaxxVitConvCfg(
        norm_eps=1e-3,
        act_layer='gelu_tanh',
        padding='same',
    )
    transformer = MaxxVitTransformerCfg(
        norm_eps=1e-5,
        act_layer='gelu_tanh',
        head_first=False,  # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....)
        rel_pos_type='bias_tf',
    )
    return {'conv_cfg': conv, 'transformer_cfg': transformer}
# Architecture configs keyed by config name. `_create_maxxvit` looks a variant up here
# directly, or after dropping the variant's last '_'-separated token (e.g. the '_224' suffix).
model_cfgs = dict(
    # timm specific CoAtNet configs
    coatnet_pico_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 3, 5, 2),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            stride_mode='pool',
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        ),
    ),
    coatnet_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        )
    ),
    coatnet_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            # init_values=1e-6,
        ),
    ),
    coatnet_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
        ),
    ),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    coatnet_bn_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            transformer_norm_layer='batchnorm2d',
        )
    ),
    coatnet_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(
            conv_output_bias=True,
            conv_attn_ratio=0.25,
            rel_pos_type='mlp',
            rel_pos_dim=384,
        ),
    ),
    coatnet_rmlp_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
        ),
    ),
    coatnet_rmlp_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            pool_type='max',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            rel_pos_type='mlp',
            rel_pos_dim=384,  # was supposed to be 512, woops
        ),
    ),
    coatnet_rmlp_1_rw2=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
            rel_pos_dim=512,  # the intended 512 (coatnet_rmlp_1_rw above was built with 384)
        ),
    ),
    coatnet_rmlp_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),
    coatnet_rmlp_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),

    coatnet_nano_cc=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
        **_rw_coat_cfg(),
    ),
    coatnext_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(
            rel_pos_type='bias',
            init_values=(1e-5, None)
        ),
    ),

    # Trying to be like the CoAtNet paper configs
    coatnet_0=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 5, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_1=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_2=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=128,
        head_hidden_size=1024,
    ),
    coatnet_3=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_4=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_5=MaxxVitCfg(
        embed_dim=(256, 512, 1280, 2048),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=2048,
    ),

    # Experimental MaxVit configs
    maxvit_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(),
    ),
    maxvit_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_pm=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('PM',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),

    maxvit_rmlp_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(
            rel_pos_type='mlp',
            init_values=1e-6,
        ),
    ),
    maxvit_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        head_hidden_size=768,
        **_rw_max_cfg(
            rel_pos_type='mlp',
        ),
    ),

    maxxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(),
    ),
    maxxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_next_cfg(),
    ),
    maxxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        **_next_cfg(),
    ),

    maxxvitv2_nano_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        weight_init='normal',
        **_next_cfg(
            no_block_attn=True,
            rel_pos_type='bias',
        ),
    ),
    maxxvitv2_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 12, 2),
        block_type=('M',) * 4,
        stem_width=(64, 128),
        **_next_cfg(
            no_block_attn=True,
        ),
    ),
    maxxvitv2_rmlp_large_rw=MaxxVitCfg(
        embed_dim=(160, 320, 640, 1280),
        depths=(2, 6, 16, 2),
        block_type=('M',) * 4,
        stem_width=(80, 160),
        head_hidden_size=1280,
        **_next_cfg(
            no_block_attn=True,
        ),
    ),

    # Trying to be like the MaxViT paper configs
    maxvit_tiny_tf=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=512,
        **_tf_cfg(),
    ),
    maxvit_small_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_base_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_large_tf=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=128,
        stem_bias=True,
        head_hidden_size=1024,
        **_tf_cfg(),
    ),
    maxvit_xlarge_tf=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=192,
        stem_bias=True,
        head_hidden_size=1536,
        **_tf_cfg(),
    ),
)
def checkpoint_filter_fn(state_dict, model: nn.Module):
    """Return a copy of `state_dict` with conv2d / linear weights reshaped to fit `model`.

    A tensor is reshaped to the model's shape when its key exists in the model,
    its rank differs from the model's tensor, and the element counts match.
    All other entries are passed through untouched.
    """
    reference = model.state_dict()
    remapped = {}
    for name, tensor in state_dict.items():
        target = reference.get(name)
        needs_reshape = (
            target is not None
            and tensor.ndim != target.ndim
            and tensor.numel() == target.numel()
        )
        if needs_reshape:
            # adapt between conv2d / linear layers
            assert tensor.ndim in (2, 4)
            tensor = tensor.reshape(target.shape)
        remapped[name] = tensor
    return remapped
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Build a MaxxVit model for `variant` through `build_model_with_cfg`.

    When `cfg_variant` is not given, the architecture config is looked up in
    `model_cfgs` under `variant` itself if present, otherwise under `variant`
    with its last '_'-separated token removed.
    """
    if cfg_variant is None:
        # rpartition drops everything from the last '_' onwards
        cfg_variant = variant if variant in model_cfgs else variant.rpartition('_')[0]
    arch_cfg = model_cfgs[cfg_variant]
    return build_model_with_cfg(
        MaxxVit,
        variant,
        pretrained,
        model_cfg=arch_cfg,
        feature_cfg=dict(flatten_sequential=True),
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
# Pretrained-weight configs keyed by '<model name>.<pretrained tag>'; each value is built
# by `_cfg`, so only the entries that differ from its defaults are spelled out.
default_cfgs = generate_default_cfgs({
    # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
    'coatnet_pico_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
        crop_pct=0.9),
    'coatnet_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
    'coatnet_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
    ),

    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    #'coatnet_3_rw_224.untrained': _cfg(url=''),

    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        crop_pct=0.95),
    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
        crop_pct=0.9),
    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_cc_224.untrained': _cfg(url=''),
    'coatnext_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
        crop_pct=0.9),

    # ImageNet-12k pretrain CoAtNet
    'coatnet_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_3_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
    'coatnet_0_224.untrained': _cfg(url=''),
    'coatnet_1_224.untrained': _cfg(url=''),
    'coatnet_2_224.untrained': _cfg(url=''),
    'coatnet_3_224.untrained': _cfg(url=''),
    'coatnet_4_224.untrained': _cfg(url=''),
    'coatnet_5_224.untrained': _cfg(url=''),

    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
    'maxvit_tiny_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
    'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
        crop_pct=0.9,
    ),
    'maxvit_rmlp_small_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # timm specific MaxVit w/ ImageNet-12k pretrain
    'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
    ),

    # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
    'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
    'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),

    'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # MaxViT models ported from official Tensorflow impl
    'maxvit_tiny_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_tiny_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_tiny_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_small_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_base_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_large_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),

    'maxvit_base_tf_224.in21k': _cfg(
        url=''),
    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in21k': _cfg(
        url=''),
    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    # pool_size was missing here and fell back to the 224px default of (7, 7);
    # every other 512px entry declares (16, 16).
    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_224.in21k': _cfg(
        url=''),
    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})
@register_model
def coatnet_pico_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_pico_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_pico_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_nano_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_nano_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_nano_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_0_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_0_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_0_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_1_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_1_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_1_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_2_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_2_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_2_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_3_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_3_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_3_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_bn_0_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_bn_0_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_bn_0_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_nano_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_nano_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_0_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_0_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_1_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_1_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_1_rw2_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_1_rw2_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_1_rw2_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_2_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_2_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_2_rw_384`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_2_rw_384'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_rmlp_3_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_rmlp_3_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_nano_cc_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_nano_cc_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_nano_cc_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnext_nano_rw_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnext_nano_rw_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnext_nano_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_0_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_0_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_0_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_1_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_1_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_1_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_2_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_2_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_2_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_3_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_3_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_3_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_4_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_4_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_4_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def coatnet_5_224(pretrained=False, **kwargs):
    """Entrypoint for `coatnet_5_224`; forwards `pretrained` and `kwargs` to `_create_maxxvit`."""
    variant = 'coatnet_5_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_pico_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_pico_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_pico_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_nano_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_nano_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_rw_224(pretrained=False, **kwargs):
    """Build the `maxvit_tiny_rw_224` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_tiny_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_tiny_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_tiny_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_pico_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_pico_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_nano_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_tiny_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_tiny_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_small_rw_224` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_small_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_small_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_small_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_base_rw_224` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_base_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs):
    """Build the `maxvit_rmlp_base_rw_384` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_rmlp_base_rw_384'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
    """Build the `maxvit_tiny_pm_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxvit_tiny_pm_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
    """Build the `maxxvit_rmlp_nano_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvit_rmlp_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
    """Build the `maxxvit_rmlp_tiny_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvit_rmlp_tiny_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
    """Build the `maxxvit_rmlp_small_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvit_rmlp_small_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
    """Build the `maxxvitv2_nano_rw_256` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvitv2_nano_rw_256'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
    """Build the `maxxvitv2_rmlp_base_rw_224` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvitv2_rmlp_base_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
    """Build the `maxxvitv2_rmlp_base_rw_384` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvitv2_rmlp_base_rw_384'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
    """Build the `maxxvitv2_rmlp_large_rw_224` model by delegating to `_create_maxxvit`.

    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant = 'maxxvitv2_rmlp_large_rw_224'
    return _create_maxxvit(variant, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_tf_224(pretrained=False, **kwargs):
    """Build the `maxvit_tiny_tf_224` model by delegating to `_create_maxxvit`.

    'maxvit_tiny_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_tiny_tf_224', 'maxvit_tiny_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_tf_384(pretrained=False, **kwargs):
    """Build the `maxvit_tiny_tf_384` model by delegating to `_create_maxxvit`.

    'maxvit_tiny_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_tiny_tf_384', 'maxvit_tiny_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_tf_512(pretrained=False, **kwargs):
    """Build the `maxvit_tiny_tf_512` model by delegating to `_create_maxxvit`.

    'maxvit_tiny_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_tiny_tf_512', 'maxvit_tiny_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_small_tf_224(pretrained=False, **kwargs):
    """Build the `maxvit_small_tf_224` model by delegating to `_create_maxxvit`.

    'maxvit_small_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_small_tf_224', 'maxvit_small_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_small_tf_384(pretrained=False, **kwargs):
    """Build the `maxvit_small_tf_384` model by delegating to `_create_maxxvit`.

    'maxvit_small_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_small_tf_384', 'maxvit_small_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_small_tf_512(pretrained=False, **kwargs):
    """Build the `maxvit_small_tf_512` model by delegating to `_create_maxxvit`.

    'maxvit_small_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_small_tf_512', 'maxvit_small_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_base_tf_224(pretrained=False, **kwargs):
    """Build the `maxvit_base_tf_224` model by delegating to `_create_maxxvit`.

    'maxvit_base_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_base_tf_224', 'maxvit_base_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_base_tf_384(pretrained=False, **kwargs):
    """Build the `maxvit_base_tf_384` model by delegating to `_create_maxxvit`.

    'maxvit_base_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_base_tf_384', 'maxvit_base_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_base_tf_512(pretrained=False, **kwargs):
    """Build the `maxvit_base_tf_512` model by delegating to `_create_maxxvit`.

    'maxvit_base_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_base_tf_512', 'maxvit_base_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_large_tf_224(pretrained=False, **kwargs):
    """Build the `maxvit_large_tf_224` model by delegating to `_create_maxxvit`.

    'maxvit_large_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_large_tf_224', 'maxvit_large_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_large_tf_384(pretrained=False, **kwargs):
    """Build the `maxvit_large_tf_384` model by delegating to `_create_maxxvit`.

    'maxvit_large_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_large_tf_384', 'maxvit_large_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_large_tf_512(pretrained=False, **kwargs):
    """Build the `maxvit_large_tf_512` model by delegating to `_create_maxxvit`.

    'maxvit_large_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_large_tf_512', 'maxvit_large_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_xlarge_tf_224(pretrained=False, **kwargs):
    """Build the `maxvit_xlarge_tf_224` model by delegating to `_create_maxxvit`.

    'maxvit_xlarge_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_xlarge_tf_224', 'maxvit_xlarge_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_xlarge_tf_384(pretrained=False, **kwargs):
    """Build the `maxvit_xlarge_tf_384` model by delegating to `_create_maxxvit`.

    'maxvit_xlarge_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    variant, cfg_name = 'maxvit_xlarge_tf_384', 'maxvit_xlarge_tf'
    return _create_maxxvit(variant, cfg_name, pretrained=pretrained, **kwargs)
@register_model
def maxvit_xlarge_tf_512(pretrained=False, **kwargs):
    """Build the `maxvit_xlarge_tf_512` model by delegating to `_create_maxxvit`.

    'maxvit_xlarge_tf' is passed as the second positional argument (presumably
    the name of the architecture config -- confirm against `_create_maxxvit`).
    `pretrained` and all extra keyword arguments are forwarded unchanged.
    """
    # A stray blame timestamp that sat between the signature and the return
    # statement has been removed; it was not valid Python.
    return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)