dbg
parent
1e938f7256
commit
c9d5694cf2
|
@ -221,26 +221,22 @@ class Stage(TheseusLayer):
|
|||
|
||||
self._num_modules = num_modules
|
||||
|
||||
self.stage_func_list = []
|
||||
self.stage_func_list = nn.LayerList()
|
||||
for i in range(num_modules):
|
||||
if i == num_modules - 1 and not multi_scale_output:
|
||||
stage_func = self.add_sublayer(
|
||||
"stage_{}_{}".format(name, i + 1),
|
||||
self.stage_func_list.append(
|
||||
HighResolutionModule(
|
||||
num_filters=num_filters,
|
||||
has_se=has_se,
|
||||
multi_scale_output=False,
|
||||
name=name + '_' + str(i + 1)))
|
||||
else:
|
||||
stage_func = self.add_sublayer(
|
||||
"stage_{}_{}".format(name, i + 1),
|
||||
self.stage_func_list.append(
|
||||
HighResolutionModule(
|
||||
num_filters=num_filters,
|
||||
has_se=has_se,
|
||||
name=name + '_' + str(i + 1)))
|
||||
|
||||
self.stage_func_list.append(stage_func)
|
||||
|
||||
def forward(self, input, res_dict=None):
|
||||
out = input
|
||||
for idx in range(self._num_modules):
|
||||
|
|
Loading…
Reference in New Issue