Stem/Downsample rework

This commit is contained in:
Fredo Guan 2023-02-04 23:37:30 -08:00
parent 26a8e481a5
commit 366aae9304

View File

@ -40,28 +40,57 @@ from ._registry import register_model
__all__ = ['MetaFormer']
class Stem(nn.Module):
    """
    Input stem: a single strided convolution followed by an optional norm.

    The convolution geometry (7x7 kernel, stride 4, padding 2) is fixed inside
    the class instead of being exposed as constructor arguments.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: channel count produced by the convolution.
        norm_layer: optional callable invoked as ``norm_layer(out_channels)``
            to build the normalization module; a falsy value selects
            ``nn.Identity``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            norm_layer=None,
    ):
        super().__init__()
        # Submodules are registered in the order conv -> norm.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=4, padding=2)
        if norm_layer:
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Convolve ``x``, then normalize with the channel axis moved last."""
        features = self.conv(x)
        # Move dim 1 (channels) to the end for the norm, then restore the layout.
        channels_last = features.permute(0, 2, 3, 1)
        normalized = self.norm(channels_last)
        return normalized.permute(0, 3, 1, 2)
class Downsampling(nn.Module):
"""
Downsampling implemented by a layer of convolution.
"""
# NOTE(review): this span appears to interleave two revisions of the class --
# no +/- diff markers or indentation survived. The `pre_norm` / `post_norm` /
# `pre_permute` lines look like the earlier revision and the `norm_layer`
# lines like the later one. Confirm against the repository before relying on
# either set of lines.
def __init__(self, in_channels, out_channels,
kernel_size, stride=1, padding=0,
pre_norm=None, post_norm=None, pre_permute=False):
# Second signature: same leading parameters, with a single `norm_layer`
# in place of `pre_norm` / `post_norm` / `pre_permute`.
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
norm_layer=None,
):
super().__init__()
# `pre_norm` is built for `in_channels`, i.e. sized for the convolution input;
# a falsy value falls back to nn.Identity.
self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
self.pre_permute = pre_permute
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
stride=stride, padding=padding)
# `post_norm` is built for `out_channels`, i.e. sized for the convolution output.
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
# `norm` is also built for `in_channels`, like `pre_norm` above.
self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
def forward(self, x):
if self.pre_permute:
# if take [B, H, W, C] as input, permute it to [B, C, H, W]
x = x.permute(0, 3, 1, 2)
# Each norm call below moves dim 1 to the last position, applies the norm,
# and then restores the original dim order.
# presumably the norm layers normalize over the trailing (channel) dim -- TODO confirm
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
@ -462,13 +491,10 @@ class MetaFormer(nn.Module):
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.stem = Downsampling(
self.stem = Stem(
in_chans,
dims[0],
kernel_size=7,
stride=4,
padding=2,
post_norm=downsample_norm
norm_layer=downsample_norm
)
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
@ -481,8 +507,7 @@ class MetaFormer(nn.Module):
kernel_size=3,
stride=2,
padding=1,
pre_norm=downsample_norm,
pre_permute=False
norm_layer=downsample_norm,
)),
('blocks', nn.Sequential(*[MetaFormerBlock(
dim=dims[i],