diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 7632c69a..a5fffab9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -216,7 +216,7 @@ cfgs_v2 = generate_default_cfgs({ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', num_classes=21841), }) -''' + class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. @@ -255,15 +255,15 @@ class Downsampling(nn.Module): self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() def forward(self, x): - #print(x.shape) + print(x.shape) x = self.pre_norm(x) - #print(x.shape) + print(x.shape) x = self.conv(x) - #print(x.shape) + print(x.shape) x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - #print(x.shape) + print(x.shape) return x - +''' class Scale(nn.Module): """ Scale vector by element multiplications. @@ -612,8 +612,9 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - B, C, H, W = x.shape - x = x.view(B, H, W, C) + #B, C, H, W = x.shape + #x = x.view(B, H, W, C) + x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -626,7 +627,8 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - x = x.view(B, C, H, W) + #x = x.view(B, C, H, W) + x = x.permute(0, 3, 1, 2) return x class MetaFormer(nn.Module):