pytorch-image-models/timm/models/mambaout.py

643 lines
20 KiB
Python

"""
MambaOut models for image classification.
Some implementations are modified from:
timm (https://github.com/rwightman/pytorch-image-models),
MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
from collections import OrderedDict
from typing import Optional
import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
class Stem(nn.Module):
    r""" Convolutional stem: two stride-2 3x3 convs, i.e. an overall 4x spatial reduction.

    Code modified from InternImage: https://github.com/OpenGVLab/InternImage

    Takes an NCHW image tensor and returns channels-last (NHWC) features; the final norm
    is applied after permuting to channels-last.
    """

    def __init__(
            self,
            in_chs=3,
            out_chs=96,
            mid_norm: bool = True,
            act_layer=nn.GELU,
            norm_layer=LayerNorm,
    ):
        super().__init__()
        half_chs = out_chs // 2
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_chs, half_chs, **conv_kwargs)
        # optional norm between the two convs (disabled by the *_rw variants via stem_mid_norm=False)
        self.norm1 = norm_layer(half_chs) if mid_norm else None
        self.act = act_layer()
        self.conv2 = nn.Conv2d(half_chs, out_chs, **conv_kwargs)
        self.norm2 = norm_layer(out_chs)

    def forward(self, x):
        y = self.conv1(x)
        if self.norm1 is not None:
            # norm_layer is applied channels-last: NCHW -> NHWC -> norm -> NCHW
            y = self.norm1(y.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        y = self.conv2(self.act(y))
        # leave the output channels-last for the downstream stages
        return self.norm2(y.permute(0, 2, 3, 1))
class DownsampleNormFirst(nn.Module):
    """ Stride-2 3x3 conv downsample with the norm applied *before* the conv.

    Input and output are channels-last (NHWC); the conv itself runs on a temporary NCHW view.
    """

    def __init__(
            self,
            in_chs=96,
            out_chs=198,  # NOTE(review): 198 looks like it may be a typo for 192; callers here always pass it
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.norm = norm_layer(in_chs)
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        normed = self.norm(x)
        # NHWC -> NCHW for the conv, then back to NHWC
        return self.conv(normed.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
class Downsample(nn.Module):
    """ Stride-2 3x3 conv downsample with the norm applied *after* the conv.

    Input and output are channels-last (NHWC); the conv itself runs on a temporary NCHW view.
    """

    def __init__(
            self,
            in_chs=96,
            out_chs=198,  # NOTE(review): 198 looks like it may be a typo for 192; callers here always pass it
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
        self.norm = norm_layer(out_chs)

    def forward(self, x):
        # NHWC -> NCHW for the conv, back to NHWC, then the channels-last norm
        downsampled = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return self.norm(downsampled)
class MlpHead(nn.Module):
    """ MLP classification head.

    forward: [mean-pool dims 1, 2] -> norm -> pre_logits (fc -> act -> norm, or Identity)
    -> dropout -> fc. Pooling only happens when ``pool_type == 'avg'``; dims 1 and 2 are
    H and W for the channels-last features MambaOut produces.

    Args:
        in_features: Input channel count.
        num_classes: Number of output classes; 0 (or less) replaces the final fc with Identity.
        pool_type: 'avg' mean-pools dims (1, 2); any other value skips pooling.
        act_layer: Activation inside the pre-logits MLP.
        mlp_ratio: Hidden width multiplier; None (or a ratio yielding 0 hidden units)
            disables the pre-logits MLP.
        norm_layer: Norm used on the input and inside the pre-logits MLP.
        drop_rate: Dropout applied before the final fc.
        bias: Whether the final fc has a bias. Also honoured by ``reset``.
    """

    def __init__(
            self,
            in_features,
            num_classes=1000,
            pool_type='avg',
            act_layer=nn.GELU,
            mlp_ratio=4,
            norm_layer=LayerNorm,
            drop_rate=0.,
            bias=True,
    ):
        super().__init__()
        if mlp_ratio is not None:
            hidden_size = int(mlp_ratio * in_features)
        else:
            hidden_size = None
        self.pool_type = pool_type
        self.in_features = in_features
        self.hidden_size = hidden_size or in_features
        # remembered so reset() rebuilds the classifier with the same bias setting
        self.use_bias = bias

        self.norm = norm_layer(in_features)
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(in_features, hidden_size)),
                ('act', act_layer()),
                ('norm', norm_layer(hidden_size))
            ]))
            self.num_features = hidden_size
        else:
            self.num_features = in_features
            self.pre_logits = nn.Identity()

        self.fc = nn.Linear(self.num_features, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
        self.head_dropout = nn.Dropout(drop_rate)

    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
        """ Replace the final fc for ``num_classes`` outputs, optionally changing the pool type.

        If ``reset_other`` is True, the input norm and the pre-logits MLP are also replaced with
        Identity, so the new fc consumes ``in_features`` directly.
        """
        if pool_type is not None:
            self.pool_type = pool_type
        if reset_other:
            self.norm = nn.Identity()
            self.pre_logits = nn.Identity()
            self.num_features = self.in_features
        # previously always created with bias=True, ignoring the constructor's `bias` argument
        if num_classes > 0:
            self.fc = nn.Linear(self.num_features, num_classes, bias=self.use_bias)
        else:
            self.fc = nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        if self.pool_type == 'avg':
            x = x.mean((1, 2))
        x = self.norm(x)
        x = self.pre_logits(x)
        x = self.head_dropout(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
class GatedConvBlock(nn.Module):
    r""" Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083

    Operates on channels-last (NHWC) input. ``fc1`` expands ``dim`` to ``2 * hidden`` channels,
    which are split into three parts:

    - a gate ``g`` (``hidden`` channels);
    - a pass-through part ``i`` (``hidden - conv_channels`` channels);
    - a part ``c`` (``conv_channels`` channels) that goes through a depthwise conv.

    The result ``fc2(act(g) * cat(i, c))`` is optionally layer-scaled and drop-path'd, then added
    to the input.

    Args:
        dim: Input/output channel count.
        expansion_ratio: ``hidden = int(expansion_ratio * dim)``.
        kernel_size: Depthwise conv kernel size. Padding is ``kernel_size // 2``, so odd kernels
            preserve the spatial size.
        conv_ratio: Controls the number of channels that go through the depthwise convolution
            (``conv_channels = int(conv_ratio * dim)``). Convolving only part of the channels
            can improve practical efficiency. The idea of partial channels is from ShuffleNet V2
            (https://arxiv.org/abs/1807.11164) and is also used by InceptionNeXt
            (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667).
        ls_init_value: If not None, enables LayerScale on the residual branch, initialized to
            this value.
        norm_layer: Norm applied to the input before ``fc1``.
        act_layer: Activation applied to the gate ``g``.
        drop_path: Stochastic depth rate for the residual branch.
    """

    def __init__(
            self,
            dim,
            expansion_ratio=8 / 3,
            kernel_size=7,
            conv_ratio=1.0,
            ls_init_value=None,
            norm_layer=LayerNorm,
            act_layer=nn.GELU,
            drop_path=0.,
            **kwargs
    ):
        super().__init__()
        self.norm = norm_layer(dim)
        hidden = int(expansion_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = act_layer()
        conv_channels = int(conv_ratio * dim)
        # split sizes for fc1's output: gate, pass-through, conv (sum == 2 * hidden)
        self.split_indices = (hidden, hidden - conv_channels, conv_channels)
        self.conv = nn.Conv2d(
            conv_channels,
            conv_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=conv_channels
        )
        self.fc2 = nn.Linear(hidden, dim)
        # pass the configured init value through; previously LayerScale(dim) silently used its default
        self.ls = LayerScale(dim, ls_init_value) if ls_init_value is not None else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x  # [B, H, W, C]
        x = self.norm(x)
        x = self.fc1(x)
        g, i, c = torch.split(x, self.split_indices, dim=-1)
        c = c.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        c = self.conv(c)
        c = c.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
        x = self.ls(x)
        x = self.drop_path(x)
        return x + shortcut
class MambaOutStage(nn.Module):
    """ One MambaOut stage: an optional downsample followed by ``depth`` GatedConvBlocks.

    Input and output are channels-last (NHWC).

    Args:
        dim: Input channel count.
        dim_out: Output channel count; defaults to ``dim``.
        depth: Number of GatedConvBlocks.
        downsample: 'conv' -> ``Downsample`` (conv then norm); 'conv_nf' -> ``DownsampleNormFirst``
            (norm then conv). Any other value means no downsample, which requires dim == dim_out.
        drop_path: A single rate for all blocks, or a per-block list/tuple indexed by block position.
        expansion_ratio, kernel_size, conv_ratio, ls_init_value, norm_layer, act_layer:
            Forwarded to every GatedConvBlock.

    Raises:
        ValueError: If no downsample is used but ``dim != dim_out``.
    """

    def __init__(
            self,
            dim,
            dim_out: Optional[int] = None,
            depth: int = 4,
            expansion_ratio=8 / 3,
            kernel_size=7,
            conv_ratio=1.0,
            downsample: str = '',
            ls_init_value: Optional[float] = None,
            norm_layer=LayerNorm,
            act_layer=nn.GELU,
            drop_path=0.,
    ):
        super().__init__()
        dim_out = dim_out or dim
        self.grad_checkpointing = False

        if downsample == 'conv':
            self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
        elif downsample == 'conv_nf':
            self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
        else:
            # explicit check instead of `assert`, which is stripped under `python -O`
            if dim != dim_out:
                raise ValueError(
                    f'dim ({dim}) must equal dim_out ({dim_out}) when no downsample layer is used '
                    f'(downsample={downsample!r})'
                )
            self.downsample = nn.Identity()

        self.blocks = nn.Sequential(*[
            GatedConvBlock(
                dim=dim_out,
                expansion_ratio=expansion_ratio,
                kernel_size=kernel_size,
                conv_ratio=conv_ratio,
                ls_init_value=ls_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                drop_path=drop_path[j] if isinstance(drop_path, (list, tuple)) else drop_path,
            )
            for j in range(depth)
        ])

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
class MambaOut(nn.Module):
    r""" MambaOut
    A PyTorch impl of : `MambaOut: Do We Really Need Mamba for Vision?` -
      https://arxiv.org/abs/2405.07992
    (code structure derived from the MetaFormer baselines implementation).

    Features stay channels-last (NHWC) from the stem through every stage into the head.

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        global_pool (str): Pooling type passed to the classifier head. Default: 'avg'.
        depths (int, list or tuple): Number of blocks at each stage. Default: (3, 3, 9, 3).
        dims (int, list or tuple): Feature dimension at each stage. Default: (96, 192, 384, 576).
        norm_layer: Norm layer used in stem, downsample, blocks and head. Default: LayerNorm.
        act_layer: Activation class or name, resolved with ``get_act_layer``. Default: nn.GELU.
        conv_ratio (float): Depthwise-conv channels per block as a multiple of the block dim. Default: 1.0.
        expansion_ratio (float): Hidden-width multiplier of each gated block. Default: 8/3.
        kernel_size (int): Depthwise conv kernel size. Default: 7.
        stem_mid_norm (bool): Apply a norm between the two stem convs. Default: True.
        ls_init_value (float, optional): LayerScale init value; None disables LayerScale. Default: None.
        downsample (str): Downsample used before every stage except the first:
            'conv' (conv -> norm) or 'conv_nf' (norm -> conv). Default: 'conv'.
        drop_path_rate (float): Stochastic depth rate, increased linearly across blocks. Default: 0.
        drop_rate (float): Dropout rate in the classifier head. Default: 0.
        head_fn (str): 'default' selects ``MlpHead`` (norm -> pool -> fc -> act -> norm -> fc);
            any other value selects ``ClNormMlpClassifierHead``. Default: 'default'.
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            depths=(3, 3, 9, 3),
            dims=(96, 192, 384, 576),
            norm_layer=LayerNorm,
            act_layer=nn.GELU,
            conv_ratio=1.0,
            expansion_ratio=8/3,
            kernel_size=7,
            stem_mid_norm=True,
            ls_init_value=None,
            downsample='conv',
            drop_path_rate=0.,
            drop_rate=0.,
            head_fn='default',
    ):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.output_fmt = 'NHWC'
        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # it means the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]
        act_layer = get_act_layer(act_layer)

        num_stage = len(depths)
        self.num_stage = num_stage
        self.feature_info = []

        self.stem = Stem(
            in_chans,
            dims[0],
            mid_norm=stem_mid_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        prev_dim = dims[0]
        # per-block drop-path rates, linearly increasing from 0 to drop_path_rate, grouped per stage
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        curr_stride = 4  # the stem reduces resolution by 4
        self.stages = nn.Sequential()
        for i in range(num_stage):
            dim = dims[i]
            # stage 0 keeps the stem resolution; every later stage downsamples by 2
            stride = 2 if i > 0 else 1
            curr_stride *= stride
            stage = MambaOutStage(
                dim=prev_dim,
                dim_out=dim,
                depth=depths[i],
                kernel_size=kernel_size,
                conv_ratio=conv_ratio,
                expansion_ratio=expansion_ratio,
                downsample=downsample if i > 0 else '',
                ls_init_value=ls_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                drop_path=dp_rates[i],
            )
            self.stages.append(stage)
            prev_dim = dim
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]

        if head_fn == 'default':
            # specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
            self.head = MlpHead(
                prev_dim,
                num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                norm_layer=norm_layer,
            )
        else:
            # more typical norm -> pool -> fc -> act -> fc
            self.head = ClNormMlpClassifierHead(
                prev_dim,
                num_classes,
                hidden_size=int(prev_dim * 4),
                pool_type=global_pool,
                norm_layer=norm_layer,
                drop_rate=drop_rate,
            )
        self.num_features = prev_dim
        self.head_hidden_size = self.head.num_features

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal weights and zero bias for every conv / linear layer
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        """ Stem + stages; returns channels-last (NHWC) features. """
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """ Apply the head; with ``pre_logits=True`` the final fc is skipped. """
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """ Remap an original-format MambaOut checkpoint to this module's parameter names.

    A top-level ``'model'`` entry is unwrapped first. Checkpoints already in this module's
    layout (detected via ``stem.conv1.weight``) are returned unchanged. ``model`` is not used.

    Remapping:
        downsample_layers.0.*      -> stem.*
        stages.{i}.{j}.*           -> stages.{i}.blocks.{j}.*
        downsample_layers.{i}.*    -> stages.{i}.downsample.*
        norm.*                     -> head.norm.*
        head.fc1.* / head.norm.*   -> head.pre_logits.fc.* / head.pre_logits.norm.*
        head.fc2.*                 -> head.fc.*
    """
    import re

    if 'model' in state_dict:
        state_dict = state_dict['model']
    if 'stem.conv1.weight' in state_dict:
        return state_dict

    # dots escaped so they only match literal '.' separators (an unescaped '.' matches any char,
    # e.g. 'stages.123.x' would otherwise be split into stage 1 / block 3)
    block_re = re.compile(r'stages\.([0-9]+)\.([0-9]+)')
    downsample_re = re.compile(r'downsample_layers\.([0-9]+)')

    out_dict = {}
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.', 'stem.')
        k = block_re.sub(r'stages.\1.blocks.\2', k)
        k = downsample_re.sub(r'stages.\1.downsample', k)
        # remap head names
        if k.startswith('norm.'):
            # this is moving to head since it's after the pooling
            k = k.replace('norm.', 'head.norm.')
        elif k.startswith('head.'):
            k = k.replace('head.fc1.', 'head.pre_logits.fc.')
            k = k.replace('head.norm.', 'head.pre_logits.norm.')
            k = k.replace('head.fc2.', 'head.fc.')
        out_dict[k] = v
    return out_dict
def _cfg(url='', **kwargs):
    """ Build a default pretrained-config dict for a MambaOut weight tag.

    Any ``kwargs`` override (or extend) the defaults below.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        test_input_size=(3, 288, 288),
        pool_size=(7, 7),
        crop_pct=1.0,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1',
        classifier='head.fc',
    )
    cfg.update(kwargs)
    return cfg
default_cfgs = generate_default_cfgs({
    # original weights
    'mambaout_femto.in1k': _cfg(hf_hub_id='timm/'),
    'mambaout_kobe.in1k': _cfg(hf_hub_id='timm/'),
    'mambaout_tiny.in1k': _cfg(hf_hub_id='timm/'),
    'mambaout_small.in1k': _cfg(hf_hub_id='timm/'),
    'mambaout_base.in1k': _cfg(hf_hub_id='timm/'),

    # timm experiments below
    'mambaout_small_rw.sw_e450_in1k': _cfg(hf_hub_id='timm/'),
    'mambaout_base_short_rw.sw_e500_in1k': _cfg(hf_hub_id='timm/', crop_pct=0.95, test_crop_pct=1.0),
    'mambaout_base_tall_rw.sw_e500_in1k': _cfg(hf_hub_id='timm/', crop_pct=0.95, test_crop_pct=1.0),
    'mambaout_base_wide_rw.sw_e500_in1k': _cfg(hf_hub_id='timm/', crop_pct=0.95, test_crop_pct=1.0),
    'mambaout_base_plus_rw.sw_e150_in12k_ft_in1k': _cfg(hf_hub_id='timm/'),
    # 384x384 weights: final feature map is 384 / 32 = 12 per side
    'mambaout_base_plus_rw.sw_e150_r384_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384),
        test_input_size=(3, 384, 384),
        crop_mode='squash',
        pool_size=(12, 12),
    ),
    'mambaout_base_plus_rw.sw_e150_in12k': _cfg(hf_hub_id='timm/', num_classes=11821),

    # tiny config for unit tests, no pretrained weights
    'test_mambaout': _cfg(input_size=(3, 160, 160), test_input_size=(3, 192, 192), pool_size=(5, 5)),
})
def _create_mambaout(variant, pretrained=False, **kwargs):
    """ Build a MambaOut ``variant`` through timm's model builder.

    Pretrained checkpoints are passed through ``checkpoint_filter_fn``; feature extraction
    exposes out_indices 0-3.
    """
    feature_cfg = dict(out_indices=(0, 1, 2, 3), flatten_sequential=True)
    return build_model_with_cfg(
        MambaOut,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
# a series of MambaOut models
@register_model
def mambaout_femto(pretrained=False, **kwargs):
    """ MambaOut-Femto: depths (3, 3, 9, 3), dims (48, 96, 192, 288). """
    model_args = {'depths': (3, 3, 9, 3), 'dims': (48, 96, 192, 288)}
    model_args.update(kwargs)
    return _create_mambaout('mambaout_femto', pretrained=pretrained, **model_args)
# Kobe Memorial Version with 24 Gated CNN blocks
@register_model
def mambaout_kobe(pretrained=False, **kwargs):
    """ MambaOut-Kobe: depths [3, 3, 15, 3] (24 blocks), dims [48, 96, 192, 288]. """
    model_args = {'depths': [3, 3, 15, 3], 'dims': [48, 96, 192, 288]}
    model_args.update(kwargs)
    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **model_args)
@register_model
def mambaout_tiny(pretrained=False, **kwargs):
    """ MambaOut-Tiny: depths [3, 3, 9, 3], dims [96, 192, 384, 576]. """
    model_args = {'depths': [3, 3, 9, 3], 'dims': [96, 192, 384, 576]}
    model_args.update(kwargs)
    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **model_args)
@register_model
def mambaout_small(pretrained=False, **kwargs):
    """ MambaOut-Small: depths [3, 4, 27, 3], dims [96, 192, 384, 576]. """
    model_args = {'depths': [3, 4, 27, 3], 'dims': [96, 192, 384, 576]}
    model_args.update(kwargs)
    return _create_mambaout('mambaout_small', pretrained=pretrained, **model_args)
@register_model
def mambaout_base(pretrained=False, **kwargs):
    """ MambaOut-Base: depths [3, 4, 27, 3], dims [128, 256, 512, 768]. """
    model_args = {'depths': [3, 4, 27, 3], 'dims': [128, 256, 512, 768]}
    model_args.update(kwargs)
    return _create_mambaout('mambaout_base', pretrained=pretrained, **model_args)
@register_model
def mambaout_small_rw(pretrained=False, **kwargs):
    """ timm variant of MambaOut-Small.

    Differences from the original: no stem mid-norm, norm-first downsample, LayerScale (1e-6)
    and the 'norm_mlp' head.
    """
    model_args = {
        'depths': [3, 4, 27, 3],
        'dims': [96, 192, 384, 576],
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'head_fn': 'norm_mlp',
    }
    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **{**model_args, **kwargs})
@register_model
def mambaout_base_short_rw(pretrained=False, **kwargs):
    """ timm MambaOut-Base variant: shallower third stage (25 blocks), wider blocks.

    Uses expansion 3.0 and conv_ratio 1.25.
    """
    model_args = {
        'depths': (3, 3, 25, 3),
        'dims': (128, 256, 512, 768),
        'expansion_ratio': 3.0,
        'conv_ratio': 1.25,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'head_fn': 'norm_mlp',
    }
    return _create_mambaout('mambaout_base_short_rw', pretrained=pretrained, **{**model_args, **kwargs})
@register_model
def mambaout_base_tall_rw(pretrained=False, **kwargs):
    """ timm MambaOut-Base variant: deeper third stage (30 blocks).

    Uses expansion 2.5 and conv_ratio 1.25.
    """
    model_args = {
        'depths': (3, 4, 30, 3),
        'dims': (128, 256, 512, 768),
        'expansion_ratio': 2.5,
        'conv_ratio': 1.25,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'head_fn': 'norm_mlp',
    }
    return _create_mambaout('mambaout_base_tall_rw', pretrained=pretrained, **{**model_args, **kwargs})
@register_model
def mambaout_base_wide_rw(pretrained=False, **kwargs):
    """ timm MambaOut-Base variant: wider blocks with SiLU activation.

    Uses expansion 3.0 and conv_ratio 1.5.
    """
    model_args = {
        'depths': (3, 4, 27, 3),
        'dims': (128, 256, 512, 768),
        'expansion_ratio': 3.0,
        'conv_ratio': 1.5,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'act_layer': 'silu',
        'head_fn': 'norm_mlp',
    }
    return _create_mambaout('mambaout_base_wide_rw', pretrained=pretrained, **{**model_args, **kwargs})
@register_model
def mambaout_base_plus_rw(pretrained=False, **kwargs):
    """ timm MambaOut-Base variant: deep (30-block third stage) and wide, with SiLU activation.

    Uses expansion 3.0 and conv_ratio 1.5.
    """
    model_args = {
        'depths': (3, 4, 30, 3),
        'dims': (128, 256, 512, 768),
        'expansion_ratio': 3.0,
        'conv_ratio': 1.5,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'act_layer': 'silu',
        'head_fn': 'norm_mlp',
    }
    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **{**model_args, **kwargs})
@register_model
def test_mambaout(pretrained=False, **kwargs):
    """ Very small MambaOut config intended for quick tests (dims 16-64, 6 blocks total). """
    model_args = {
        'depths': (1, 1, 3, 1),
        'dims': (16, 32, 48, 64),
        'expansion_ratio': 3,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-4,
        'act_layer': 'silu',
        'head_fn': 'norm_mlp',
    }
    return _create_mambaout('test_mambaout', pretrained=pretrained, **{**model_args, **kwargs})