From 4542cf03f9628fc5ec6292337c41cb14e07d5278 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 13 Sep 2024 11:25:04 -0700
Subject: [PATCH] Add features_only, other bits to mambaout, define different
 base alternatives

---
 timm/models/mambaout.py | 119 ++++++++++++++++++++++++++++++++--------
 1 file changed, 95 insertions(+), 24 deletions(-)

diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py
index 3c9900a0..5c472237 100644
--- a/timm/models/mambaout.py
+++ b/timm/models/mambaout.py
@@ -5,6 +5,7 @@ timm (https://github.com/rwightman/pytorch-image-models),
 MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
+from collections import OrderedDict
 from typing import Optional

 import torch
@@ -120,7 +121,7 @@ class MlpHead(nn.Module):

     def __init__(
             self,
-            dim,
+            in_features,
             num_classes=1000,
             pool_type='avg',
             act_layer=nn.GELU,
@@ -130,27 +131,47 @@ class MlpHead(nn.Module):
             bias=True,
     ):
         super().__init__()
-        hidden_features = int(mlp_ratio * dim)
+        if mlp_ratio is not None:
+            hidden_size = int(mlp_ratio * in_features)
+        else:
+            hidden_size = None
         self.pool_type = pool_type
+        self.in_features = in_features
+        self.hidden_size = hidden_size or in_features

-        self.norm1 = norm_layer(dim)
-        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
-        self.act = act_layer()
-        self.norm2 = norm_layer(hidden_features)
-        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+                ('norm', norm_layer(hidden_size))
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.num_features = in_features
+            self.pre_logits = nn.Identity()
+
+        self.fc = nn.Linear(self.num_features, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)

+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        if reset_other:
+            self.norm = nn.Identity()
+            self.pre_logits = nn.Identity()
+            self.num_features = self.in_features
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward(self, x, pre_logits: bool = False):
         if self.pool_type == 'avg':
             x = x.mean((1, 2))
-        x = self.norm1(x)
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.norm2(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
         x = self.head_dropout(x)
         if pre_logits:
             return x
-        x = self.fc2(x)
+        x = self.fc(x)
         return x


@@ -284,6 +305,7 @@ class MambaOut(nn.Module):
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
             conv_ratio=1.0,
+            expansion_ratio=8/3,
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
@@ -303,6 +325,7 @@ class MambaOut(nn.Module):

         num_stage = len(depths)
         self.num_stage = num_stage
+        self.feature_info = []

         self.stem = Stem(
             in_chans,
@@ -313,16 +336,20 @@
         )
         prev_dim = dims[0]
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        self.stages = nn.ModuleList()
         cur = 0
+        curr_stride = 4
+        self.stages = nn.Sequential()
         for i in range(num_stage):
             dim = dims[i]
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            curr_stride *= stride
             stage = MambaOutStage(
                 dim=prev_dim,
                 dim_out=dim,
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
+                expansion_ratio=expansion_ratio,
                 downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
@@ -331,6 +358,8 @@ class MambaOut(nn.Module):
             )
             self.stages.append(stage)
             prev_dim = dim
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+            self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]
             cur += depths[i]

         if head_fn == 'default':
@@ -352,6 +381,8 @@ class MambaOut(nn.Module):
             norm_layer=norm_layer,
             drop_rate=drop_rate,
         )
+        self.num_features = prev_dim
+        self.hidden_size = self.head.num_features

         self.apply(self._init_weights)

@@ -362,13 +393,31 @@ class MambaOut(nn.Module):
             nn.init.constant_(m.bias, 0)

     @torch.jit.ignore
-    def no_weight_decay(self):
-        return {}
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)
-        for s in self.stages:
-            x = s(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x, pre_logits: bool = False):
@@ -391,10 +440,14 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        # remap head names
         if k.startswith('norm.'):
-            k = k.replace('norm.', 'head.norm1.')
-        elif k.startswith('head.norm.'):
-            k = k.replace('head.norm.', 'head.norm2.')
+            # this is moving to head since it's after the pooling
+            k = k.replace('norm.', 'head.norm.')
+        elif k.startswith('head.'):
+            k = k.replace('head.fc1.', 'head.pre_logits.fc.')
+            k = k.replace('head.norm.', 'head.pre_logits.norm.')
+            k = k.replace('head.fc2.', 'head.fc.')
         out_dict[k] = v

     return out_dict
@@ -405,7 +458,7 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000,
         'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': 1.0, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
         **kwargs
     }

@@ -422,7 +475,8 @@ default_cfgs = {
     'mambaout_base': _cfg(
         url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
     'mambaout_small_rw': _cfg(),
-    'mambaout_base_rw': _cfg(),
+    'mambaout_base_slim_rw': _cfg(),
+    'mambaout_base_plus_rw': _cfg(),
 }


@@ -480,12 +534,29 @@ def mambaout_small_rw(pretrained=False, **kwargs):


 @register_model
-def mambaout_base_rw(pretrained=False, **kwargs):
+def mambaout_base_slim_rw(pretrained=False, **kwargs):
     model_args = dict(
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
+        expansion_ratio=2.5,
+        conv_ratio=1.25,
         stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
         head_fn='norm_mlp',
     )
-    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def mambaout_base_plus_rw(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(3, 4, 27, 3),
+        dims=(128, 256, 512, 768),
+        expansion_ratio=3.0,
+        stem_mid_norm=False,
+        downsample='conv_nf',
+        ls_init_value=1e-6,
+        head_fn='norm_mlp',
+    )
+    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
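
Usage sketch (not part of the patch): what the new feature_info metadata enables. This assumes the patch is applied, that mambaout_base keeps dims=(128, 256, 512, 768), and that _create_mambaout passes a feature_cfg through to build_model_with_cfg so features_only works; that wiring is not visible in the hunks above.

import torch
import timm

model = timm.create_model('mambaout_base', pretrained=False)

# feature_info registered by this patch: stage 0 inherits the stem's stride-4
# reduction, each later stage halves resolution (see the NOTE in __init__).
for info in model.feature_info:
    print(info['module'], info['num_chs'], info['reduction'])
# expected: stages.0 128 4 ... stages.3 768 32

# With features_only=True, timm flattens the new nn.Sequential stages and
# returns one map per stage; MambaOut runs channels-last internally, so the
# maps may come back NHWC rather than NCHW.
feat_model = timm.create_model('mambaout_base', pretrained=False, features_only=True)
feats = feat_model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])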
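
A second sketch covering the reworked head. The attribute names follow the remap in checkpoint_filter_fn (head.norm, head.pre_logits.fc/act/norm, head.fc); the 3072 width is illustrative, assuming mambaout_base uses the default MlpHead (mlp_ratio=4) on a 768-dim final stage.

import torch
import torch.nn as nn
import timm

model = timm.create_model('mambaout_base', pretrained=False)

# The final classifier now lives at head.fc, matching the updated
# 'classifier': 'head.fc' entry in _cfg.
assert isinstance(model.get_classifier(), nn.Linear)

# reset_classifier() delegates to MlpHead.reset(); num_classes=0 swaps the
# final nn.Linear for nn.Identity, so the model returns pooled pre-logit
# features at the MLP hidden width (model.head.num_features).
model.reset_classifier(0)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected (1, 3072) under the assumptions above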
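
Finally, the two smaller trainer-facing hooks. This sketch assumes MambaOutStage already exposes a grad_checkpointing flag that its forward() honors; the patch sets the flag but the stage implementation is not shown.

import timm

model = timm.create_model('mambaout_base', pretrained=False)

# group_matcher() feeds timm's layer-wise LR decay machinery. Coarse mode
# groups parameters per stage; fine mode additionally pins each stage's
# downsample to the same group as its first block.
print(model.group_matcher(coarse=True))

# set_grad_checkpointing() toggles activation checkpointing on every stage,
# trading recompute for lower training memory.
model.set_grad_checkpointing(True)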