Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Update metaformers.py
This commit is contained in:
Parent: 926d886527
Commit: 7f149f31d4
@@ -521,14 +521,11 @@ class MetaFormerBlock(nn.Module):
|
|||||||
norm_layer=nn.LayerNorm,
|
norm_layer=nn.LayerNorm,
|
||||||
drop=0., drop_path=0.,
|
drop=0., drop_path=0.,
|
||||||
layer_scale_init_value=None,
|
layer_scale_init_value=None,
|
||||||
res_scale_init_value=None,
|
res_scale_init_value=None
|
||||||
downsample = nn.Identity()
|
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.downsample = downsample
|
|
||||||
|
|
||||||
self.norm1 = norm_layer(dim)
|
self.norm1 = norm_layer(dim)
|
||||||
self.token_mixer = token_mixer(dim=dim, drop=drop)
|
self.token_mixer = token_mixer(dim=dim, drop=drop)
|
||||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
@@ -546,7 +543,6 @@ class MetaFormerBlock(nn.Module):
|
|||||||
if res_scale_init_value else nn.Identity()
|
if res_scale_init_value else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.downsample(x)
|
|
||||||
x = self.res_scale1(x) + \
|
x = self.res_scale1(x) + \
|
||||||
self.layer_scale1(
|
self.layer_scale1(
|
||||||
self.drop_path1(
|
self.drop_path1(
|
||||||
@@ -653,18 +649,19 @@ class MetaFormer(nn.Module):
|
|||||||
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||||
cur = 0
|
cur = 0
|
||||||
for i in range(num_stage):
|
for i in range(num_stage):
|
||||||
stage = nn.Sequential(*[MetaFormerBlock(
|
stage = nn.Sequential(OrderedDict[
|
||||||
dim=dims[i],
|
('downsample', downsample_layers[i]),
|
||||||
token_mixer=token_mixers[i],
|
('blocks', nn.Sequential(*[MetaFormerBlock(
|
||||||
mlp=mlps[i],
|
dim=dims[i],
|
||||||
norm_layer=norm_layers[i],
|
token_mixer=token_mixers[i],
|
||||||
drop_path=dp_rates[cur + j],
|
mlp=mlps[i],
|
||||||
layer_scale_init_value=layer_scale_init_values[i],
|
norm_layer=norm_layers[i],
|
||||||
res_scale_init_value=res_scale_init_values[i],
|
drop_path=dp_rates[cur + j],
|
||||||
downsample = downsample_layers[i]
|
layer_scale_init_value=layer_scale_init_values[i],
|
||||||
) for j in range(depths[i])]
|
res_scale_init_value=res_scale_init_values[i]
|
||||||
|
) for j in range(depths[i])])
|
||||||
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
stages.append(stage)
|
stages.append(stage)
|
||||||
cur += depths[i]
|
cur += depths[i]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user