Update rw models, fix heads

Ross Wightman 2024-08-27 14:03:59 -07:00
parent f2086f51a0
commit c2da12c7e1


@@ -8,10 +8,10 @@ InceptionNeXt (https://github.com/sail-sg/inceptionnext)
from typing import Optional
import torch
import torch.nn as nn
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model
@@ -122,6 +122,7 @@ class MlpHead(nn.Module):
self,
dim,
num_classes=1000,
pool_type='avg',
act_layer=nn.GELU,
mlp_ratio=4,
norm_layer=LayerNorm,
@@ -130,17 +131,25 @@ class MlpHead(nn.Module):
):
super().__init__()
hidden_features = int(mlp_ratio * dim)
self.pool_type = pool_type
self.norm1 = norm_layer(dim)
self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
self.act = act_layer()
self.norm = norm_layer(hidden_features)
self.norm2 = norm_layer(hidden_features)
self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
self.head_dropout = nn.Dropout(drop_rate)
def forward(self, x):
def forward(self, x, pre_logits: bool = False):
if self.pool_type == 'avg':
x = x.mean((1, 2))
x = self.norm1(x)
x = self.fc1(x)
x = self.act(x)
x = self.norm(x)
x = self.norm2(x)
x = self.head_dropout(x)
if pre_logits:
return x
x = self.fc2(x)
return x
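
The reworked MlpHead now handles pooling itself and can return pre-logits features. A minimal usage sketch (the feature shape and the 576-dim width are illustrative, taken from the small_rw config below):

import torch

head = MlpHead(576, num_classes=1000, pool_type='avg')
x = torch.randn(2, 7, 7, 576)          # unpooled, channels-last feature map
logits = head(x)                       # pool -> norm1 -> fc1 -> act -> norm2 -> dropout -> fc2
embeds = head(x, pre_logits=True)      # same path, but stops before the final fc2
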
@@ -208,7 +217,7 @@ class MambaOutStage(nn.Module):
expansion_ratio=8 / 3,
kernel_size=7,
conv_ratio=1.0,
downsample: bool = False,
downsample: str = '',
ls_init_value: Optional[float] = None,
norm_layer=LayerNorm,
act_layer=nn.GELU,
@@ -218,8 +227,10 @@ class MambaOutStage(nn.Module):
dim_out = dim_out or dim
self.grad_checkpointing = False
if downsample:
if downsample == 'conv':
self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
elif downsample == 'conv_nf':
self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
else:
assert dim == dim_out
self.downsample = nn.Identity()
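
For reference, a sketch of how the new string-valued downsample argument selects the module; argument names not visible in this hunk (the positional dim / dim_out) are assumed from the surrounding code:

stage = MambaOutStage(96, 192, depth=3, downsample='conv_nf')  # DownsampleNormFirst
stage = MambaOutStage(96, 192, depth=3, downsample='conv')     # Downsample
stage = MambaOutStage(96, 96, depth=3, downsample='')          # nn.Identity, requires dim == dim_out
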
@@ -276,10 +287,10 @@ class MambaOut(nn.Module):
kernel_size=7,
stem_mid_norm=True,
ls_init_value=None,
downsample='conv',
drop_path_rate=0.,
drop_rate=0.,
output_norm=LayerNorm,
head_fn=MlpHead,
head_fn='default',
**kwargs,
):
super().__init__()
@@ -312,7 +323,7 @@ class MambaOut(nn.Module):
depth=depths[i],
kernel_size=kernel_size,
conv_ratio=conv_ratio,
downsample=i > 0,
downsample=downsample if i > 0 else '',
ls_init_value=ls_init_value,
norm_layer=norm_layer,
act_layer=act_layer,
@@ -322,9 +333,25 @@ class MambaOut(nn.Module):
prev_dim = dim
cur += depths[i]
self.norm = output_norm(prev_dim)
self.head = head_fn(prev_dim, num_classes, drop_rate=drop_rate)
if head_fn == 'default':
# specific to this model, unusual pool -> norm -> fc -> act -> norm -> fc combo
self.head = MlpHead(
prev_dim,
num_classes,
pool_type='avg',
drop_rate=drop_rate,
norm_layer=norm_layer,
)
else:
# more typical pool -> norm -> fc -> act -> fc
self.head = ClNormMlpClassifierHead(
prev_dim,
num_classes,
hidden_size=int(prev_dim * 4),
pool_type='avg',
norm_layer=norm_layer,
drop_rate=drop_rate,
)
self.apply(self._init_weights)
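
Any head_fn other than 'default' now selects the library's ClNormMlpClassifierHead; the updated rw configs below pass 'norm_mlp'. A hedged construction sketch using the small_rw depths/dims:

model = MambaOut(
    depths=(3, 4, 27, 3),
    dims=(96, 192, 384, 576),
    head_fn='norm_mlp',   # ClNormMlpClassifierHead with hidden_size = 4 * prev_dim
)
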
@@ -336,7 +363,7 @@ class MambaOut(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {'norm'}
return {}
def forward_features(self, x):
x = self.stem(x)
@@ -345,9 +372,7 @@ class MambaOut(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
x = x.mean((1, 2))
x = self.norm(x)
x = self.head(x)
x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
return x
def forward(self, x):
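
With pooling and the first norm folded into the head, pre-logits features now come from the head rather than the trunk. Given a constructed MambaOut instance model (input size illustrative):

x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)                      # unpooled feature map
pooled = model.forward_head(feats, pre_logits=True)    # pooled + normalized, before the classifier
logits = model.forward_head(feats)
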
@@ -366,6 +391,10 @@ def checkpoint_filter_fn(state_dict, model):
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
if k.startswith('norm.'):
k = k.replace('norm.', 'head.norm1.')
elif k.startswith('head.norm.'):
k = k.replace('head.norm.', 'head.norm2.')
out_dict[k] = v
return out_dict
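
Standalone illustration of the two new remapping branches applied to an older checkpoint's keys (key names are examples consistent with the old layout):

remapped = []
for k in ['norm.weight', 'norm.bias', 'head.norm.weight', 'head.fc1.weight']:
    if k.startswith('norm.'):
        k = k.replace('norm.', 'head.norm1.')       # trunk-level norm moves into the head
    elif k.startswith('head.norm.'):
        k = k.replace('head.norm.', 'head.norm2.')  # old head norm becomes the second head norm
    remapped.append(k)
# remapped == ['head.norm1.weight', 'head.norm1.bias', 'head.norm2.weight', 'head.fc1.weight']
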
@@ -443,7 +472,9 @@ def mambaout_small_rw(pretrained=False, **kwargs):
depths=[3, 4, 27, 3],
dims=[96, 192, 384, 576],
stem_mid_norm=False,
downsample='conv_nf',
ls_init_value=1e-6,
head_fn='norm_mlp',
)
return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -455,5 +486,6 @@ def mambaout_base_rw(pretrained=False, **kwargs):
dims=(128, 256, 512, 768),
stem_mid_norm=False,
ls_init_value=1e-6,
head_fn='norm_mlp',
)
return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
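
Assuming a timm build that includes these updated registrations, the rw variants can be exercised end to end:

import timm
import torch

model = timm.create_model('mambaout_base_rw', pretrained=False)
logits = model(torch.randn(1, 3, 224, 224))    # (1, 1000)
feats = model.forward_head(model.forward_features(torch.randn(1, 3, 224, 224)), pre_logits=True)
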