From ab6225b9414f534815958036f6d5a392038d7ab2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 5 Feb 2023 01:04:55 -0800 Subject: [PATCH] try NHWC --- timm/models/metaformers.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 6fd03b35..14e54a1d 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -63,7 +63,8 @@ class Stem(nn.Module): def forward(self, x): x = self.conv(x) - x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.norm(x.permute(0, 2, 3, 1)) + # [B, H, W, C] return x class Downsampling(nn.Module): @@ -89,8 +90,8 @@ class Downsampling(nn.Module): ) def forward(self, x): - x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - x = self.conv(x) + x = self.norm(x).permute(0, 3, 1, 2) + x = self.conv(x).permute(0, 2, 3, 1) return x @@ -396,7 +397,8 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - x = x.permute(0, 2, 3, 1) + print(x.shape) + #x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -409,7 +411,8 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - x = x.permute(0, 3, 1, 2) + #x = x.permute(0, 3, 1, 2) + return x class MetaFormer(nn.Module):