From d90ed530dc34fad8f80fb33618400920d7d86dec Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 7 Jan 2023 23:27:07 -0800
Subject: [PATCH] Update metaformers.py

---
 timm/models/metaformers.py | 43 ++++++++++++++++++++++++++------------
 1 file changed, 30 insertions(+), 13 deletions(-)

diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index ff8bc0d5..a90b25be 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -469,14 +469,19 @@ class MetaFormerBlock(nn.Module):
     Implementation of one MetaFormer block.
     """
     def __init__(self, dim,
-                 token_mixer=nn.Identity, mlp=Mlp,
+                 token_mixer=nn.Identity,
+                 mlp=Mlp,
                  norm_layer=nn.LayerNorm,
                  drop=0., drop_path=0.,
-                 layer_scale_init_value=None, res_scale_init_value=None
+                 layer_scale_init_value=None,
+                 res_scale_init_value=None,
+                 downsample = nn.Identity()
                  ):
 
         super().__init__()
-
+
+        self.downsample = nn.Identity()
+
         self.norm1 = norm_layer(dim)
         self.token_mixer = token_mixer(dim=dim, drop=drop)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -494,6 +499,7 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
 
     def forward(self, x):
+        x = self.downsample(x)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -600,18 +606,18 @@ class MetaFormer(nn.Module):
         stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
         cur = 0
         for i in range(num_stage):
-            stage = nn.Sequential(
-                downsample_layers[i],
-                *[MetaFormerBlock(
-                    dim=dims[i],
-                    token_mixer=token_mixers[i],
-                    mlp=mlps[i],
-                    norm_layer=norm_layers[i],
-                    drop_path=dp_rates[cur + j],
-                    layer_scale_init_value=layer_scale_init_values[i],
-                    res_scale_init_value=res_scale_init_values[i],
+            stage = nn.Sequential(*[MetaFormerBlock(
+                dim=dims[i],
+                token_mixer=token_mixers[i],
+                mlp=mlps[i],
+                norm_layer=norm_layers[i],
+                drop_path=dp_rates[cur + j],
+                layer_scale_init_value=layer_scale_init_values[i],
+                res_scale_init_value=res_scale_init_values[i],
+                downsample = downsample_layers[i]
                 ) for j in range(depths[i])]
             )
+
             stages.append(stage)
             cur += depths[i]
 
@@ -649,6 +655,17 @@ class MetaFormer(nn.Module):
         x = self.head(x)
         return x
 
+def checkpoint_filter_fn(state_dict, model):
+
+    import re
+    out_dict = {}
+    for k, v in state_dict.items():
+
+        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_metaformer(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
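
For reference, a minimal standalone sketch of the key remapping that the new checkpoint_filter_fn performs, so pretrained weights saved under the old top-level downsample_layers.N.* layout are renamed for the relocated modules. The regex is taken verbatim from the patch; the sample state-dict keys below are hypothetical and only illustrate the substitution.

import re

# Hypothetical keys in the style of an old-format checkpoint,
# where downsampling lived in a top-level downsample_layers list.
old_keys = [
    'downsample_layers.0.conv.weight',
    'downsample_layers.3.norm.bias',
    'stages.0.0.norm1.weight',  # unrelated keys are left untouched
]

# Same substitution as checkpoint_filter_fn:
# 'downsample_layers.N...' -> 'stages.N.downsample...'
for k in old_keys:
    new_k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
    print(f'{k} -> {new_k}')

# Output:
# downsample_layers.0.conv.weight -> stages.0.downsample.conv.weight
# downsample_layers.3.norm.bias -> stages.3.downsample.norm.bias
# stages.0.0.norm1.weight -> stages.0.0.norm1.weight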