mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Mambaout tweaks
This commit is contained in:
parent
4542cf03f9
commit
91e743f2dd
@ -12,7 +12,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
|
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model
|
from ._registry import register_model
|
||||||
@ -318,10 +318,12 @@ class MambaOut(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
|
self.output_fmt = 'NHWC'
|
||||||
if not isinstance(depths, (list, tuple)):
|
if not isinstance(depths, (list, tuple)):
|
||||||
depths = [depths] # it means the model has only one stage
|
depths = [depths] # it means the model has only one stage
|
||||||
if not isinstance(dims, (list, tuple)):
|
if not isinstance(dims, (list, tuple)):
|
||||||
dims = [dims]
|
dims = [dims]
|
||||||
|
act_layer = get_act_layer(act_layer)
|
||||||
|
|
||||||
num_stage = len(depths)
|
num_stage = len(depths)
|
||||||
self.num_stage = num_stage
|
self.num_stage = num_stage
|
||||||
@ -456,7 +458,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
return {
|
return {
|
||||||
'url': url,
|
'url': url,
|
||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
'crop_pct': 1.0, 'interpolation': 'bicubic',
|
'crop_pct': 1.0, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
|
||||||
**kwargs
|
**kwargs
|
||||||
@ -477,6 +479,7 @@ default_cfgs = {
|
|||||||
'mambaout_small_rw': _cfg(),
|
'mambaout_small_rw': _cfg(),
|
||||||
'mambaout_base_slim_rw': _cfg(),
|
'mambaout_base_slim_rw': _cfg(),
|
||||||
'mambaout_base_plus_rw': _cfg(),
|
'mambaout_base_plus_rw': _cfg(),
|
||||||
|
'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -554,9 +557,26 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
|
|||||||
depths=(3, 4, 27, 3),
|
depths=(3, 4, 27, 3),
|
||||||
dims=(128, 256, 512, 768),
|
dims=(128, 256, 512, 768),
|
||||||
expansion_ratio=3.0,
|
expansion_ratio=3.0,
|
||||||
|
conv_ratio=1.5,
|
||||||
stem_mid_norm=False,
|
stem_mid_norm=False,
|
||||||
downsample='conv_nf',
|
downsample='conv_nf',
|
||||||
ls_init_value=1e-6,
|
ls_init_value=1e-6,
|
||||||
|
act_layer='silu',
|
||||||
head_fn='norm_mlp',
|
head_fn='norm_mlp',
|
||||||
)
|
)
|
||||||
return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def test_mambaout(pretrained=False, **kwargs):
    """Create the ``test_mambaout`` model variant.

    The defaults below are merged with ``kwargs`` (caller-supplied values
    win on conflict) and forwarded to ``_create_mambaout``.

    Args:
        pretrained: Passed through unchanged to ``_create_mambaout``.
        **kwargs: Overrides / extra arguments for the model constructor.

    Returns:
        Whatever ``_create_mambaout`` returns for the ``'test_mambaout'`` variant.
    """
    defaults = {
        'depths': (1, 1, 3, 1),
        'dims': (16, 32, 48, 64),
        'expansion_ratio': 3,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-4,
        'act_layer': 'silu',
        'head_fn': 'norm_mlp',
    }
    # Later unpacking wins, so explicit caller kwargs override the defaults.
    merged_args = {**defaults, **kwargs}
    return _create_mambaout('test_mambaout', pretrained=pretrained, **merged_args)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user