Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add features_only, other bits to mambaout, define different base alternatives
commit 4542cf03f9
parent c2da12c7e1

@@ -5,6 +5,7 @@ timm (https://github.com/rwightman/pytorch-image-models),
 MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
+from collections import OrderedDict
 from typing import Optional

 import torch
@@ -120,7 +121,7 @@ class MlpHead(nn.Module):

     def __init__(
             self,
-            dim,
+            in_features,
             num_classes=1000,
             pool_type='avg',
             act_layer=nn.GELU,
@@ -130,27 +131,47 @@ class MlpHead(nn.Module):
             bias=True,
     ):
         super().__init__()
-        hidden_features = int(mlp_ratio * dim)
+        if mlp_ratio is not None:
+            hidden_size = int(mlp_ratio * in_features)
+        else:
+            hidden_size = None
         self.pool_type = pool_type
+        self.in_features = in_features
+        self.hidden_size = hidden_size or in_features

-        self.norm1 = norm_layer(dim)
-        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
-        self.act = act_layer()
-        self.norm2 = norm_layer(hidden_features)
-        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+                ('norm', norm_layer(hidden_size))
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.num_features = in_features
+            self.pre_logits = nn.Identity()
+
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)

+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        if reset_other:
+            self.norm = nn.Identity()
+            self.pre_logits = nn.Identity()
+            self.num_features = self.in_features
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward(self, x, pre_logits: bool = False):
         if self.pool_type == 'avg':
             x = x.mean((1, 2))
-        x = self.norm1(x)
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.norm2(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
         x = self.head_dropout(x)
         if pre_logits:
             return x
-        x = self.fc2(x)
+        x = self.fc(x)
         return x


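Illustrative sketch (not part of the diff): when a hidden size is configured, the reworked head routes pooled features through an fc -> act -> norm pre-logits stack before the final classifier, and collapses to nn.Identity otherwise, so forward(x, pre_logits=True) returns the features feeding head.fc. A standalone approximation of that wiring, with the feature sizes chosen purely for illustration:

from collections import OrderedDict

import torch
import torch.nn as nn


def build_pre_logits(in_features, hidden_size=None):
    # Mirrors the new head wiring: optional fc/act/norm stack ahead of the classifier.
    if hidden_size:
        stack = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(in_features, hidden_size)),
            ('act', nn.GELU()),
            ('norm', nn.LayerNorm(hidden_size)),
        ]))
        return stack, hidden_size
    return nn.Identity(), in_features


pre_logits, num_features = build_pre_logits(768, int(768 * 4.0))
x = torch.randn(2, 7, 7, 768).mean((1, 2))  # NHWC features, 'avg' pool over H and W
print(pre_logits(x).shape, num_features)    # torch.Size([2, 3072]) 3072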
@@ -284,6 +305,7 @@ class MambaOut(nn.Module):
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
             conv_ratio=1.0,
+            expansion_ratio=8/3,
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
@@ -303,6 +325,7 @@ class MambaOut(nn.Module):

         num_stage = len(depths)
         self.num_stage = num_stage
+        self.feature_info = []

         self.stem = Stem(
             in_chans,
@@ -313,16 +336,20 @@ class MambaOut(nn.Module):
         )
         prev_dim = dims[0]
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        self.stages = nn.ModuleList()
         cur = 0
+        curr_stride = 4
+        self.stages = nn.Sequential()
         for i in range(num_stage):
             dim = dims[i]
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            curr_stride *= stride
             stage = MambaOutStage(
                 dim=prev_dim,
                 dim_out=dim,
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
+                expansion_ratio=expansion_ratio,
                 downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
@@ -331,6 +358,8 @@ class MambaOut(nn.Module):
             )
             self.stages.append(stage)
             prev_dim = dim
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+            self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]
             cur += depths[i]

         if head_fn == 'default':
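With the stem at stride 4, stage 0 applying stride 1 and each later stage stride 2, the recorded reductions work out to 4, 8, 16, 32. Illustrative sketch, assuming this feature_info ends up exposed through timm's standard features_only interface (which is what the commit title suggests it is for):

import timm
import torch

# Assumption: create_model(..., features_only=True) is wired up for mambaout by this work.
model = timm.create_model('mambaout_base', pretrained=False, features_only=True)
print(model.feature_info.reduction())        # expected reductions: 4, 8, 16, 32
feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)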
@@ -352,6 +381,8 @@ class MambaOut(nn.Module):
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
             )
+        self.num_features = prev_dim
+        self.hidden_size = self.head.num_features

         self.apply(self._init_weights)

@@ -362,13 +393,31 @@ class MambaOut(nn.Module):
             nn.init.constant_(m.bias, 0)

     @torch.jit.ignore
-    def no_weight_decay(self):
-        return {}
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)), # blocks
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
+
     def forward_features(self, x):
         x = self.stem(x)
-        for s in self.stages:
-            x = s(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x, pre_logits: bool = False):
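Illustrative sketch of the standard timm hooks these methods serve (fine-tuning to a new class count, toggling per-stage gradient checkpointing); everything here beyond the methods shown above is just the usual timm entry points, not something specific to this commit:

import timm
import torch

model = timm.create_model('mambaout_base', pretrained=False)
model.set_grad_checkpointing(True)        # sets the per-stage flag assigned above
model.reset_classifier(num_classes=10)    # head.reset() swaps head.fc for a 10-way linear
print(model.get_classifier())
print(model(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 10])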
@@ -391,10 +440,14 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        # remap head names
         if k.startswith('norm.'):
-            k = k.replace('norm.', 'head.norm1.')
-        elif k.startswith('head.norm.'):
-            k = k.replace('head.norm.', 'head.norm2.')
+            # this is moving to head since it's after the pooling
+            k = k.replace('norm.', 'head.norm.')
+        elif k.startswith('head.'):
+            k = k.replace('head.fc1.', 'head.pre_logits.fc.')
+            k = k.replace('head.norm.', 'head.pre_logits.norm.')
+            k = k.replace('head.fc2.', 'head.fc.')
         out_dict[k] = v

     return out_dict
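To see what the remapping does, a small standalone sketch applying the same rules to a few representative keys (the example key names are illustrative, not taken from a real checkpoint):

import re


def remap_key(k):
    # Same substitutions as checkpoint_filter_fn above, applied to a single key.
    k = k.replace('downsample_layers.0.', 'stem.')
    k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
    k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
    if k.startswith('norm.'):
        k = k.replace('norm.', 'head.norm.')
    elif k.startswith('head.'):
        k = k.replace('head.fc1.', 'head.pre_logits.fc.')
        k = k.replace('head.norm.', 'head.pre_logits.norm.')
        k = k.replace('head.fc2.', 'head.fc.')
    return k


for key in ('downsample_layers.1.conv.weight', 'stages.2.5.fc1.weight', 'norm.weight', 'head.fc2.bias'):
    print(key, '->', remap_key(key))
# downsample_layers.1.conv.weight -> stages.1.downsample.conv.weight
# stages.2.5.fc1.weight -> stages.2.blocks.5.fc1.weight
# norm.weight -> head.norm.weight
# head.fc2.bias -> head.fc.bias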
@@ -405,7 +458,7 @@ def _cfg(url='', **kwargs):
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': 1.0, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
         **kwargs
     }

@@ -422,7 +475,8 @@ default_cfgs = {
     'mambaout_base': _cfg(
         url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
     'mambaout_small_rw': _cfg(),
-    'mambaout_base_rw': _cfg(),
+    'mambaout_base_slim_rw': _cfg(),
+    'mambaout_base_plus_rw': _cfg(),
 }


@@ -480,12 +534,29 @@ def mambaout_small_rw(pretrained=False, **kwargs):


 @register_model
-def mambaout_base_rw(pretrained=False, **kwargs):
+def mambaout_base_slim_rw(pretrained=False, **kwargs):
     model_args = dict(
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
+        expansion_ratio=2.5,
+        conv_ratio=1.25,
         stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
         head_fn='norm_mlp',
     )
-    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def mambaout_base_plus_rw(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(3, 4, 27, 3),
+        dims=(128, 256, 512, 768),
+        expansion_ratio=3.0,
+        stem_mid_norm=False,
+        downsample='conv_nf',
+        ls_init_value=1e-6,
+        head_fn='norm_mlp',
+    )
+    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
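Illustrative sketch: the two newly registered variants have empty _cfg() entries (no pretrained URL yet), so for now they can only be instantiated with random weights:

import timm
import torch

for name in ('mambaout_base_slim_rw', 'mambaout_base_plus_rw'):
    model = timm.create_model(name, pretrained=False)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'{name}: {n_params / 1e6:.1f}M params,',
          model(torch.randn(1, 3, 224, 224)).shape)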