This commit is contained in:
Fredo Guan 2023-02-05 01:04:55 -08:00
parent 02fcc30eaa
commit ab6225b941

View File

@ -63,7 +63,8 @@ class Stem(nn.Module):
def forward(self, x):
x = self.conv(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.norm(x.permute(0, 2, 3, 1))
# [B, H, W, C]
return x
class Downsampling(nn.Module):
@ -89,8 +90,8 @@ class Downsampling(nn.Module):
)
def forward(self, x):
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x)
x = self.norm(x).permute(0, 3, 1, 2)
x = self.conv(x).permute(0, 2, 3, 1)
return x
@ -396,7 +397,8 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity()
def forward(self, x):
x = x.permute(0, 2, 3, 1)
print(x.shape)
#x = x.permute(0, 2, 3, 1)
x = self.res_scale1(x) + \
self.layer_scale1(
self.drop_path1(
@ -409,7 +411,8 @@ class MetaFormerBlock(nn.Module):
self.mlp(self.norm2(x))
)
)
x = x.permute(0, 3, 1, 2)
#x = x.permute(0, 3, 1, 2)
return x
class MetaFormer(nn.Module):