Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

Commit c2da12c7e1: Update rw models, fix heads
Parent: f2086f51a0
@@ -8,10 +8,10 @@ InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 from typing import Optional

 import torch
-import torch.nn as nn
+from torch import nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -122,6 +122,7 @@ class MlpHead(nn.Module):
             self,
             dim,
             num_classes=1000,
+            pool_type='avg',
             act_layer=nn.GELU,
             mlp_ratio=4,
             norm_layer=LayerNorm,
@@ -130,17 +131,25 @@ class MlpHead(nn.Module):
     ):
         super().__init__()
         hidden_features = int(mlp_ratio * dim)
+        self.pool_type = pool_type
+
+        self.norm1 = norm_layer(dim)
         self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
         self.act = act_layer()
-        self.norm = norm_layer(hidden_features)
+        self.norm2 = norm_layer(hidden_features)
         self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)

-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
+        if self.pool_type == 'avg':
+            x = x.mean((1, 2))
+        x = self.norm1(x)
         x = self.fc1(x)
         x = self.act(x)
-        x = self.norm(x)
+        x = self.norm2(x)
         x = self.head_dropout(x)
+        if pre_logits:
+            return x
         x = self.fc2(x)
         return x

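For orientation (not part of the commit), a minimal standalone sketch of what the reworked MlpHead computes — pool, then norm1, fc1, act, norm2, dropout, and an optional pre_logits early exit before fc2 — using plain nn.LayerNorm in place of timm's LayerNorm and assuming NHWC feature maps:

import torch
import torch.nn as nn

class MlpHeadSketch(nn.Module):
    # pool -> norm1 -> fc1 -> act -> norm2 -> dropout -> (pre_logits?) -> fc2
    def __init__(self, dim, num_classes=1000, pool_type='avg', mlp_ratio=4, drop_rate=0.):
        super().__init__()
        hidden = int(mlp_ratio * dim)
        self.pool_type = pool_type
        self.norm1 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.norm2 = nn.LayerNorm(hidden)
        self.fc2 = nn.Linear(hidden, num_classes)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x, pre_logits: bool = False):
        if self.pool_type == 'avg':
            x = x.mean((1, 2))  # NHWC (B, H, W, C) -> pooled (B, C)
        x = self.drop(self.norm2(self.act(self.fc1(self.norm1(x)))))
        return x if pre_logits else self.fc2(x)

head = MlpHeadSketch(576, num_classes=1000)
print(head(torch.randn(2, 7, 7, 576)).shape)  # torch.Size([2, 1000])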
@@ -208,7 +217,7 @@ class MambaOutStage(nn.Module):
             expansion_ratio=8 / 3,
             kernel_size=7,
             conv_ratio=1.0,
-            downsample: bool = False,
+            downsample: str = '',
             ls_init_value: Optional[float] = None,
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
@@ -218,8 +227,10 @@ class MambaOutStage(nn.Module):
         dim_out = dim_out or dim
         self.grad_checkpointing = False

-        if downsample:
+        if downsample == 'conv':
             self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
+        elif downsample == 'conv_nf':
+            self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
         else:
             assert dim == dim_out
             self.downsample = nn.Identity()
@@ -276,10 +287,10 @@ class MambaOut(nn.Module):
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
+            downsample='conv',
             drop_path_rate=0.,
             drop_rate=0.,
-            output_norm=LayerNorm,
-            head_fn=MlpHead,
+            head_fn='default',
             **kwargs,
     ):
         super().__init__()
@@ -312,7 +323,7 @@ class MambaOut(nn.Module):
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
-                downsample=i > 0,
+                downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
@@ -322,9 +333,25 @@ class MambaOut(nn.Module):
             prev_dim = dim
             cur += depths[i]

-        self.norm = output_norm(prev_dim)
-        self.head = head_fn(prev_dim, num_classes, drop_rate=drop_rate)
+        if head_fn == 'default':
+            # specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
+            self.head = MlpHead(
+                prev_dim,
+                num_classes,
+                pool_type='avg',
+                drop_rate=drop_rate,
+                norm_layer=norm_layer,
+            )
+        else:
+            # more typical norm -> pool -> fc -> act -> fc
+            self.head = ClNormMlpClassifierHead(
+                prev_dim,
+                num_classes,
+                hidden_size=int(prev_dim * 4),
+                pool_type='avg',
+                norm_layer=norm_layer,
+                drop_rate=drop_rate,
+            )

         self.apply(self._init_weights)

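The alternate head selected when head_fn is not 'default' can be exercised on its own. A hedged sketch, not part of the commit: the import and constructor arguments mirror the hunk above, while the channels-last (NHWC) input shape is an assumption about what the MambaOut trunk feeds the head:

import torch
from timm.layers import ClNormMlpClassifierHead, LayerNorm

# "more typical norm -> pool -> fc -> act -> fc" head, per the comment in the commit
head = ClNormMlpClassifierHead(576, 1000, hidden_size=576 * 4, pool_type='avg',
                               norm_layer=LayerNorm, drop_rate=0.)
features = torch.randn(2, 7, 7, 576)   # assumed channels-last stage-4 output
print(head(features).shape)            # expected: torch.Size([2, 1000])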
@@ -336,7 +363,7 @@ class MambaOut(nn.Module):

     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'norm'}
+        return {}

     def forward_features(self, x):
         x = self.stem(x)
@@ -345,9 +372,7 @@ class MambaOut(nn.Module):
         return x

     def forward_head(self, x, pre_logits: bool = False):
-        x = x.mean((1, 2))
-        x = self.norm(x)
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x

     def forward(self, x):
@@ -366,6 +391,10 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm1.')
+        elif k.startswith('head.norm.'):
+            k = k.replace('head.norm.', 'head.norm2.')
         out_dict[k] = v

     return out_dict
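Existing checkpoints keep the final norm at the model top level, so the added branch folds it into head.norm1 and renames the old head norm to head.norm2. A small standalone sketch, not part of the commit, with illustrative example keys:

import re

def remap_key(k: str) -> str:
    # same substitutions as checkpoint_filter_fn in the hunk above
    k = k.replace('downsample_layers.0.', 'stem.')
    k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
    k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
    if k.startswith('norm.'):
        k = k.replace('norm.', 'head.norm1.')
    elif k.startswith('head.norm.'):
        k = k.replace('head.norm.', 'head.norm2.')
    return k

print(remap_key('norm.weight'))                       # head.norm1.weight
print(remap_key('head.norm.bias'))                    # head.norm2.bias
print(remap_key('downsample_layers.2.conv.weight'))   # stages.2.downsample.conv.weight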
@@ -443,7 +472,9 @@ def mambaout_small_rw(pretrained=False, **kwargs):
         depths=[3, 4, 27, 3],
         dims=[96, 192, 384, 576],
         stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))

@@ -455,5 +486,6 @@ def mambaout_base_rw(pretrained=False, **kwargs):
         dims=(128, 256, 512, 768),
         stem_mid_norm=False,
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
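A hedged end-to-end check of the updated rw configs, not part of the commit. It assumes a timm build that includes this change; pretrained weights for these variants may not be published, so pretrained=False is used:

import torch
import timm

model = timm.create_model('mambaout_small_rw', pretrained=False, num_classes=10)
x = torch.randn(1, 3, 224, 224)
logits = model(x)                                                  # classification logits, shape (1, 10)
pooled = model.forward_head(model.forward_features(x), pre_logits=True)  # pre-classifier features
print(logits.shape, pooled.shape)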