diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 4c15a308..377bcd14 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -40,28 +40,57 @@ from ._registry import register_model __all__ = ['MetaFormer'] + +class Stem(nn.Module): + """ + Stem implemented by a layer of convolution. + Conv2d params constant across all models. + """ + def __init__(self, + in_channels, + out_channels, + norm_layer=None, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=4, + padding=2 + ) + self.norm = norm_layer(out_channels) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. """ - def __init__(self, in_channels, out_channels, - kernel_size, stride=1, padding=0, - pre_norm=None, post_norm=None, pre_permute=False): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + norm_layer=None, + ): super().__init__() - self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() - self.pre_permute = pre_permute - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, - stride=stride, padding=padding) - self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() + self.norm = norm_layer(in_channels) if norm_layer else nn.Identity() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) def forward(self, x): - if self.pre_permute: - # if take [B, H, W, C] as input, permute it to [B, C, H, W] - x = x.permute(0, 3, 1, 2) - x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.conv(x) - x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x @@ -462,13 +491,10 @@ class MetaFormer(nn.Module): dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - self.stem = Downsampling( + self.stem = Stem( in_chans, dims[0], - kernel_size=7, - stride=4, - padding=2, - post_norm=downsample_norm + norm_layer=downsample_norm ) stages = nn.ModuleList() # each stage consists of multiple metaformer blocks @@ -481,8 +507,7 @@ class MetaFormer(nn.Module): kernel_size=3, stride=2, padding=1, - pre_norm=downsample_norm, - pre_permute=False + norm_layer=downsample_norm, )), ('blocks', nn.Sequential(*[MetaFormerBlock( dim=dims[i],