Stem/Downsample rework

This commit is contained in:
Fredo Guan 2023-02-04 23:37:30 -08:00
parent 26a8e481a5
commit 366aae9304

View File

@ -40,28 +40,57 @@ from ._registry import register_model
__all__ = ['MetaFormer'] __all__ = ['MetaFormer']
class Stem(nn.Module):
"""
Stem implemented by a layer of convolution.
Conv2d params constant across all models.
"""
def __init__(self,
in_channels,
out_channels,
norm_layer=None,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=7,
stride=4,
padding=2
)
self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()
def forward(self, x):
x = self.conv(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class Downsampling(nn.Module): class Downsampling(nn.Module):
""" """
Downsampling implemented by a layer of convolution. Downsampling implemented by a layer of convolution.
""" """
def __init__(self, in_channels, out_channels, def __init__(self,
kernel_size, stride=1, padding=0, in_channels,
pre_norm=None, post_norm=None, pre_permute=False): out_channels,
kernel_size,
stride=1,
padding=0,
norm_layer=None,
):
super().__init__() super().__init__()
self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
self.pre_permute = pre_permute self.conv = nn.Conv2d(
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, in_channels,
stride=stride, padding=padding) out_channels,
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() kernel_size=kernel_size,
stride=stride,
padding=padding
)
def forward(self, x): def forward(self, x):
if self.pre_permute: x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
# if take [B, H, W, C] as input, permute it to [B, C, H, W]
x = x.permute(0, 3, 1, 2)
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x) x = self.conv(x)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x return x
@ -462,13 +491,10 @@ class MetaFormer(nn.Module):
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.stem = Downsampling( self.stem = Stem(
in_chans, in_chans,
dims[0], dims[0],
kernel_size=7, norm_layer=downsample_norm
stride=4,
padding=2,
post_norm=downsample_norm
) )
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
@ -481,8 +507,7 @@ class MetaFormer(nn.Module):
kernel_size=3, kernel_size=3,
stride=2, stride=2,
padding=1, padding=1,
pre_norm=downsample_norm, norm_layer=downsample_norm,
pre_permute=False
)), )),
('blocks', nn.Sequential(*[MetaFormerBlock( ('blocks', nn.Sequential(*[MetaFormerBlock(
dim=dims[i], dim=dims[i],