pull/748/head
weishengyu 2021-05-26 14:38:40 +08:00
parent 1e938f7256
commit c9d5694cf2
1 changed files with 3 additions and 7 deletions

View File

@ -221,26 +221,22 @@ class Stage(TheseusLayer):
self._num_modules = num_modules
self.stage_func_list = []
self.stage_func_list = nn.LayerList()
for i in range(num_modules):
if i == num_modules - 1 and not multi_scale_output:
stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1),
self.stage_func_list.append(
HighResolutionModule(
num_filters=num_filters,
has_se=has_se,
multi_scale_output=False,
name=name + '_' + str(i + 1)))
else:
stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1),
self.stage_func_list.append(
HighResolutionModule(
num_filters=num_filters,
has_se=has_se,
name=name + '_' + str(i + 1)))
self.stage_func_list.append(stage_func)
def forward(self, input, res_dict=None):
out = input
for idx in range(self._num_modules):