diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py
index a57ba8f3..3c9900a0 100644
--- a/timm/models/mambaout.py
+++ b/timm/models/mambaout.py
@@ -8,10 +8,10 @@ InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 from typing import Optional
 
 import torch
-import torch.nn as nn
+from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -122,6 +122,7 @@ class MlpHead(nn.Module):
         self,
         dim,
         num_classes=1000,
+        pool_type='avg',
         act_layer=nn.GELU,
         mlp_ratio=4,
         norm_layer=LayerNorm,
@@ -130,17 +131,25 @@ class MlpHead(nn.Module):
     ):
         super().__init__()
         hidden_features = int(mlp_ratio * dim)
+        self.pool_type = pool_type
+
+        self.norm1 = norm_layer(dim)
         self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
         self.act = act_layer()
-        self.norm = norm_layer(hidden_features)
+        self.norm2 = norm_layer(hidden_features)
         self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)
 
-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
+        if self.pool_type == 'avg':
+            x = x.mean((1, 2))
+        x = self.norm1(x)
         x = self.fc1(x)
         x = self.act(x)
-        x = self.norm(x)
+        x = self.norm2(x)
         x = self.head_dropout(x)
+        if pre_logits:
+            return x
         x = self.fc2(x)
         return x
 
@@ -208,7 +217,7 @@ class MambaOutStage(nn.Module):
         expansion_ratio=8 / 3,
         kernel_size=7,
         conv_ratio=1.0,
-        downsample: bool = False,
+        downsample: str = '',
         ls_init_value: Optional[float] = None,
         norm_layer=LayerNorm,
         act_layer=nn.GELU,
@@ -218,8 +227,10 @@
         dim_out = dim_out or dim
         self.grad_checkpointing = False
 
-        if downsample:
+        if downsample == 'conv':
             self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
+        elif downsample == 'conv_nf':
+            self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
         else:
             assert dim == dim_out
             self.downsample = nn.Identity()
@@ -276,10 +287,10 @@ class MambaOut(nn.Module):
         kernel_size=7,
         stem_mid_norm=True,
         ls_init_value=None,
+        downsample='conv',
         drop_path_rate=0.,
         drop_rate=0.,
-        output_norm=LayerNorm,
-        head_fn=MlpHead,
+        head_fn='default',
         **kwargs,
     ):
         super().__init__()
@@ -312,7 +323,7 @@
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
-                downsample=i > 0,
+                downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
@@ -322,9 +333,25 @@
             prev_dim = dim
             cur += depths[i]
 
-        self.norm = output_norm(prev_dim)
-
-        self.head = head_fn(prev_dim, num_classes, drop_rate=drop_rate)
+        if head_fn == 'default':
+            # specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
+            self.head = MlpHead(
+                prev_dim,
+                num_classes,
+                pool_type='avg',
+                drop_rate=drop_rate,
+                norm_layer=norm_layer,
+            )
+        else:
+            # more typical norm -> pool -> fc -> act -> fc
+            self.head = ClNormMlpClassifierHead(
+                prev_dim,
+                num_classes,
+                hidden_size=int(prev_dim * 4),
+                pool_type='avg',
+                norm_layer=norm_layer,
+                drop_rate=drop_rate,
+            )
 
         self.apply(self._init_weights)
 
@@ -336,7 +363,7 @@
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'norm'}
+        return {}
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -345,9 +372,7 @@ class MambaOut(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        x = x.mean((1, 2))
-        x = self.norm(x)
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
     def forward(self, x):
@@ -366,6 +391,10 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm1.')
+        elif k.startswith('head.norm.'):
+            k = k.replace('head.norm.', 'head.norm2.')
         out_dict[k] = v
     return out_dict
 
@@ -443,7 +472,9 @@ def mambaout_small_rw(pretrained=False, **kwargs):
         depths=[3, 4, 27, 3],
         dims=[96, 192, 384, 576],
         stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
 
@@ -455,5 +486,6 @@ def mambaout_base_rw(pretrained=False, **kwargs):
         dims=(128, 256, 512, 768),
         stem_mid_norm=False,
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
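---

A quick end-to-end check of the new head plumbing (a minimal sketch, assuming the patch above is applied to a local timm checkout; the model name, the 'default'/'norm_mlp' head_fn values, and the int(prev_dim * 4) hidden width all come from the diff itself):

    import torch
    import timm

    x = torch.randn(2, 3, 224, 224)

    # mambaout_base_rw now defaults to head_fn='norm_mlp', i.e. the more
    # conventional ClNormMlpClassifierHead (norm -> pool -> fc -> act -> fc).
    model = timm.create_model('mambaout_base_rw', pretrained=False)
    print(model(x).shape)  # torch.Size([2, 1000])

    # pre_logits is now resolved inside the head (forward_head only forwards
    # the flag), so this should yield the hidden MLP features:
    # int(768 * 4) = 3072 wide given the base_rw dims.
    feats = model.forward_head(model.forward_features(x), pre_logits=True)
    print(feats.shape)  # torch.Size([2, 3072])

    # Overriding head_fn at create time exercises the two-norm MlpHead branch
    # instead; the classifier output shape is unchanged.
    model_mlp = timm.create_model('mambaout_base_rw', pretrained=False, head_fn='default')
    assert model_mlp(x).shape == (2, 1000)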