mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Stem/Downsample rework
This commit is contained in:
parent
26a8e481a5
commit
366aae9304
@ -40,28 +40,57 @@ from ._registry import register_model
|
|||||||
|
|
||||||
__all__ = ['MetaFormer']
|
__all__ = ['MetaFormer']
|
||||||
|
|
||||||
|
|
||||||
|
class Stem(nn.Module):
|
||||||
|
"""
|
||||||
|
Stem implemented by a layer of convolution.
|
||||||
|
Conv2d params constant across all models.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
norm_layer=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=4,
|
||||||
|
padding=2
|
||||||
|
)
|
||||||
|
self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
return x
|
||||||
|
|
||||||
class Downsampling(nn.Module):
|
class Downsampling(nn.Module):
|
||||||
"""
|
"""
|
||||||
Downsampling implemented by a layer of convolution.
|
Downsampling implemented by a layer of convolution.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_channels, out_channels,
|
def __init__(self,
|
||||||
kernel_size, stride=1, padding=0,
|
in_channels,
|
||||||
pre_norm=None, post_norm=None, pre_permute=False):
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
norm_layer=None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
|
self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
|
||||||
self.pre_permute = pre_permute
|
self.conv = nn.Conv2d(
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
|
in_channels,
|
||||||
stride=stride, padding=padding)
|
out_channels,
|
||||||
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.pre_permute:
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
# if take [B, H, W, C] as input, permute it to [B, C, H, W]
|
|
||||||
x = x.permute(0, 3, 1, 2)
|
|
||||||
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -462,13 +491,10 @@ class MetaFormer(nn.Module):
|
|||||||
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
||||||
|
|
||||||
|
|
||||||
self.stem = Downsampling(
|
self.stem = Stem(
|
||||||
in_chans,
|
in_chans,
|
||||||
dims[0],
|
dims[0],
|
||||||
kernel_size=7,
|
norm_layer=downsample_norm
|
||||||
stride=4,
|
|
||||||
padding=2,
|
|
||||||
post_norm=downsample_norm
|
|
||||||
)
|
)
|
||||||
|
|
||||||
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||||
@ -481,8 +507,7 @@ class MetaFormer(nn.Module):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
padding=1,
|
padding=1,
|
||||||
pre_norm=downsample_norm,
|
norm_layer=downsample_norm,
|
||||||
pre_permute=False
|
|
||||||
)),
|
)),
|
||||||
('blocks', nn.Sequential(*[MetaFormerBlock(
|
('blocks', nn.Sequential(*[MetaFormerBlock(
|
||||||
dim=dims[i],
|
dim=dims[i],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user