mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Stem/Downsample rework
This commit is contained in:
parent
26a8e481a5
commit
366aae9304
@ -40,28 +40,57 @@ from ._registry import register_model
|
||||
|
||||
__all__ = ['MetaFormer']
|
||||
|
||||
|
||||
class Stem(nn.Module):
    """
    Stem implemented by a single convolution followed by an optional norm.

    The convolution parameters (kernel_size=7, stride=4, padding=2) are
    fixed here and are not configurable by the caller.

    Args:
        in_channels: number of channels of the input passed to the conv.
        out_channels: number of channels produced by the conv.
        norm_layer: optional callable invoked as ``norm_layer(out_channels)``
            to build the normalization module; when falsy, ``nn.Identity``
            is used instead.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            norm_layer=None,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=4,
            padding=2,
        )
        self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()

    def forward(self, x):
        """Apply the conv, then the norm over the channel axis."""
        x = self.conv(x)
        # The norm is applied with channels moved to the last axis, then the
        # tensor is permuted back to the layout produced by the conv.
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
|
||||
|
||||
class Downsampling(nn.Module):
    """
    Downsampling implemented by an optional norm followed by a convolution.

    Args:
        in_channels: number of channels of the input passed to the conv.
        out_channels: number of channels produced by the conv.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
        norm_layer: optional callable invoked as ``norm_layer(in_channels)``
            to build the normalization module applied BEFORE the conv; when
            falsy, ``nn.Identity`` is used instead.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            norm_layer=None,
    ):
        super().__init__()
        # The norm sees the input, so it is sized by in_channels.
        self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Normalize over the channel axis, then apply the conv."""
        # The norm is applied with channels moved to the last axis, then the
        # tensor is permuted back before being handed to the conv.
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x = self.conv(x)
        return x
|
||||
|
||||
|
||||
@ -462,13 +491,10 @@ class MetaFormer(nn.Module):
|
||||
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
||||
|
||||
|
||||
self.stem = Downsampling(
|
||||
self.stem = Stem(
|
||||
in_chans,
|
||||
dims[0],
|
||||
kernel_size=7,
|
||||
stride=4,
|
||||
padding=2,
|
||||
post_norm=downsample_norm
|
||||
norm_layer=downsample_norm
|
||||
)
|
||||
|
||||
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||
@ -481,8 +507,7 @@ class MetaFormer(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
pre_norm=downsample_norm,
|
||||
pre_permute=False
|
||||
norm_layer=downsample_norm,
|
||||
)),
|
||||
('blocks', nn.Sequential(*[MetaFormerBlock(
|
||||
dim=dims[i],
|
||||
|
Loading…
x
Reference in New Issue
Block a user