Update metaformers.py

This commit is contained in:
Fredo Guan 2023-01-17 11:05:30 -08:00
parent 473403d905
commit 13876ada4c

View File

@ -216,7 +216,7 @@ cfgs_v2 = generate_default_cfgs({
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
num_classes=21841),
})
'''
class Downsampling(nn.Module):
"""
Downsampling implemented by a layer of convolution.
@ -255,15 +255,15 @@ class Downsampling(nn.Module):
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
def forward(self, x):
#print(x.shape)
print(x.shape)
x = self.pre_norm(x)
#print(x.shape)
print(x.shape)
x = self.conv(x)
#print(x.shape)
print(x.shape)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#print(x.shape)
print(x.shape)
return x
'''
class Scale(nn.Module):
"""
Scale vector by element multiplications.
@ -612,8 +612,9 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
x = x.view(B, H, W, C)
#B, C, H, W = x.shape
#x = x.view(B, H, W, C)
x = x.permute(0, 2, 3, 1)
x = self.res_scale1(x) + \
self.layer_scale1(
self.drop_path1(
@ -626,7 +627,8 @@ class MetaFormerBlock(nn.Module):
self.mlp(self.norm2(x))
)
)
x = x.view(B, C, H, W)
#x = x.view(B, C, H, W)
x = x.permute(0, 3, 1, 2)
return x
class MetaFormer(nn.Module):