mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update rw models, fix heads
This commit is contained in:
parent
f2086f51a0
commit
c2da12c7e1
@ -8,10 +8,10 @@ InceptionNeXt (https://github.com/sail-sg/inceptionnext)
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
|
||||
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
@ -122,6 +122,7 @@ class MlpHead(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_classes=1000,
|
||||
pool_type='avg',
|
||||
act_layer=nn.GELU,
|
||||
mlp_ratio=4,
|
||||
norm_layer=LayerNorm,
|
||||
@ -130,17 +131,25 @@ class MlpHead(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
hidden_features = int(mlp_ratio * dim)
|
||||
self.pool_type = pool_type
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
|
||||
self.act = act_layer()
|
||||
self.norm = norm_layer(hidden_features)
|
||||
self.norm2 = norm_layer(hidden_features)
|
||||
self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
|
||||
self.head_dropout = nn.Dropout(drop_rate)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, pre_logits: bool = False):
|
||||
if self.pool_type == 'avg':
|
||||
x = x.mean((1, 2))
|
||||
x = self.norm1(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.norm(x)
|
||||
x = self.norm2(x)
|
||||
x = self.head_dropout(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
@ -208,7 +217,7 @@ class MambaOutStage(nn.Module):
|
||||
expansion_ratio=8 / 3,
|
||||
kernel_size=7,
|
||||
conv_ratio=1.0,
|
||||
downsample: bool = False,
|
||||
downsample: str = '',
|
||||
ls_init_value: Optional[float] = None,
|
||||
norm_layer=LayerNorm,
|
||||
act_layer=nn.GELU,
|
||||
@ -218,8 +227,10 @@ class MambaOutStage(nn.Module):
|
||||
dim_out = dim_out or dim
|
||||
self.grad_checkpointing = False
|
||||
|
||||
if downsample:
|
||||
if downsample == 'conv':
|
||||
self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
|
||||
elif downsample == 'conv_nf':
|
||||
self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
|
||||
else:
|
||||
assert dim == dim_out
|
||||
self.downsample = nn.Identity()
|
||||
@ -276,10 +287,10 @@ class MambaOut(nn.Module):
|
||||
kernel_size=7,
|
||||
stem_mid_norm=True,
|
||||
ls_init_value=None,
|
||||
downsample='conv',
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
output_norm=LayerNorm,
|
||||
head_fn=MlpHead,
|
||||
head_fn='default',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@ -312,7 +323,7 @@ class MambaOut(nn.Module):
|
||||
depth=depths[i],
|
||||
kernel_size=kernel_size,
|
||||
conv_ratio=conv_ratio,
|
||||
downsample=i > 0,
|
||||
downsample=downsample if i > 0 else '',
|
||||
ls_init_value=ls_init_value,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
@ -322,9 +333,25 @@ class MambaOut(nn.Module):
|
||||
prev_dim = dim
|
||||
cur += depths[i]
|
||||
|
||||
self.norm = output_norm(prev_dim)
|
||||
|
||||
self.head = head_fn(prev_dim, num_classes, drop_rate=drop_rate)
|
||||
if head_fn == 'default':
|
||||
# specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
|
||||
self.head = MlpHead(
|
||||
prev_dim,
|
||||
num_classes,
|
||||
pool_type='avg',
|
||||
drop_rate=drop_rate,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
else:
|
||||
# more typical norm -> pool -> fc -> act -> fc
|
||||
self.head = ClNormMlpClassifierHead(
|
||||
prev_dim,
|
||||
num_classes,
|
||||
hidden_size=int(prev_dim * 4),
|
||||
pool_type='avg',
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
@ -336,7 +363,7 @@ class MambaOut(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'norm'}
|
||||
return {}
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
@ -345,9 +372,7 @@ class MambaOut(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
x = x.mean((1, 2))
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
@ -366,6 +391,10 @@ def checkpoint_filter_fn(state_dict, model):
|
||||
k = k.replace('downsample_layers.0.', 'stem.')
|
||||
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
|
||||
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
||||
if k.startswith('norm.'):
|
||||
k = k.replace('norm.', 'head.norm1.')
|
||||
elif k.startswith('head.norm.'):
|
||||
k = k.replace('head.norm.', 'head.norm2.')
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
@ -443,7 +472,9 @@ def mambaout_small_rw(pretrained=False, **kwargs):
|
||||
depths=[3, 4, 27, 3],
|
||||
dims=[96, 192, 384, 576],
|
||||
stem_mid_norm=False,
|
||||
downsample='conv_nf',
|
||||
ls_init_value=1e-6,
|
||||
head_fn='norm_mlp',
|
||||
)
|
||||
return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
@ -455,5 +486,6 @@ def mambaout_base_rw(pretrained=False, **kwargs):
|
||||
dims=(128, 256, 512, 768),
|
||||
stem_mid_norm=False,
|
||||
ls_init_value=1e-6,
|
||||
head_fn='norm_mlp',
|
||||
)
|
||||
return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
Loading…
x
Reference in New Issue
Block a user