diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 361936d4..9aef43a5 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -437,8 +437,9 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - B, C, H, W = x.shape - x = x.view(B, H, W, C) + #B, C, H, W = x.shape + #x = x.view(B, H, W, C) + x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -451,7 +452,8 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - x = x.view(B, C, H, W) + #x = x.view(B, C, H, W) + x = x.permute(0, 3, 1, 2) return x class MetaFormer(nn.Module): @@ -630,11 +632,12 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = self.global_pool(x) - x = x.squeeze() - x = self.norm(x) + #x = self.global_pool(x) + #x = x.squeeze() + #x = self.norm(x) # (B, H, W, C) -> (B, C) - x = self.head(x) + #x = self.head(x) + x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x): @@ -655,6 +658,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): + ''' k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') @@ -664,6 +668,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') + ''' k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed')