diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index de5433f6..fbd93ea0 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -441,20 +441,13 @@ class MetaFormerBlock(nn.Module): x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( - self.token_mixer( - self.norm1( - x.permute(0, 3, 1, 2) - ).permute(0, 2, 3, 1) - ) + self.token_mixer(self.norm1(x)) ) ) x = self.res_scale2(x) + \ self.layer_scale2( self.drop_path2( - self.mlp(self.norm2( - x.permute(0, 3, 1, 2) - )#.permute(0, 2, 3, 1) - ) + self.mlp(self.norm2(x)) ) ) #x = x.view(B, C, H, W) @@ -915,10 +908,10 @@ def poolformerv1_s24(pretrained=False, **kwargs): dims=[64, 128, 320, 512], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-5, res_scale_init_values=None, **kwargs) @@ -931,10 +924,10 @@ def poolformerv1_s36(pretrained=False, **kwargs): dims=[64, 128, 320, 512], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-6, res_scale_init_values=None, **kwargs) @@ -947,10 +940,10 @@ def poolformerv1_m36(pretrained=False, **kwargs): dims=[96, 192, 384, 768], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-6, res_scale_init_values=None, **kwargs) @@ -963,10 +956,10 @@ def poolformerv1_m48(pretrained=False, **kwargs): dims=[96, 192, 384, 768], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-6, res_scale_init_values=None, **kwargs)