diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index 6fd03b35..14e54a1d 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -63,7 +63,8 @@ class Stem(nn.Module):
 
     def forward(self, x):
         x = self.conv(x)
-        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = self.norm(x.permute(0, 2, 3, 1))
+        # [B, H, W, C]
         return x
 
 class Downsampling(nn.Module):
@@ -89,8 +90,8 @@ class Downsampling(nn.Module):
         )
 
     def forward(self, x):
-        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        x = self.conv(x)
+        x = self.norm(x).permute(0, 3, 1, 2)
+        x = self.conv(x).permute(0, 2, 3, 1)
         return x
 
 
@@ -396,7 +397,8 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
 
     def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
+        print(x.shape)
+        #x = x.permute(0, 2, 3, 1)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -409,7 +411,8 @@ class MetaFormerBlock(nn.Module):
                     self.mlp(self.norm2(x))
                 )
             )
-        x = x.permute(0, 3, 1, 2)
+        #x = x.permute(0, 3, 1, 2)
+
         return x
 
 class MetaFormer(nn.Module):