Update metaformers.py

This commit is contained in:
Fredo Guan 2023-01-08 09:37:49 -08:00
parent 7aa3459caf
commit 944fd549c4

View File

@ -650,7 +650,7 @@ class MetaFormer(nn.Module):
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
cur = 0 cur = 0
for i in range(num_stage): for i in range(num_stage):
stage = nn.Sequential(OrderedDict[ stage = nn.Sequential(OrderedDict([
('downsample', downsample_layers[i]), ('downsample', downsample_layers[i]),
('blocks', nn.Sequential(*[MetaFormerBlock( ('blocks', nn.Sequential(*[MetaFormerBlock(
dim=dims[i], dim=dims[i],
@ -661,7 +661,7 @@ class MetaFormer(nn.Module):
layer_scale_init_value=layer_scale_init_values[i], layer_scale_init_value=layer_scale_init_values[i],
res_scale_init_value=res_scale_init_values[i] res_scale_init_value=res_scale_init_values[i]
) for j in range(depths[i])]) ) for j in range(depths[i])])
)] )])
) )
stages.append(stage) stages.append(stage)
cur += depths[i] cur += depths[i]