Update metaformers.py
parent
01f671ed08
commit
c7e1819ca5
|
@ -588,7 +588,7 @@ class MetaFormer(nn.Module):
|
|||
if not isinstance(res_scale_init_values, (list, tuple)):
|
||||
res_scale_init_values = [res_scale_init_values] * num_stage
|
||||
|
||||
self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
|
||||
cur = 0
|
||||
for i in range(num_stage):
|
||||
stage = nn.Sequential(
|
||||
|
@ -603,8 +603,11 @@ class MetaFormer(nn.Module):
|
|||
)
|
||||
self.stages.append(stage)
|
||||
cur += depths[i]
|
||||
|
||||
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.norm = output_norm(dims[-1])
|
||||
|
||||
|
||||
|
||||
if head_dropout > 0.0:
|
||||
self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)
|
||||
|
|
Loading…
Reference in New Issue