Merge pull request #2334 from MengzhangLI/fix_knet_dev1.x

[Fix] Fix KNet IterativeDecodeHead bug in dev-1.x branch
This commit is contained in:
Miao Zheng 2022-11-22 23:07:50 +08:00 committed by GitHub
commit b19e54cb73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -413,6 +413,9 @@ class IterativeDecodeHead(BaseDecodeHead):
def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
**kwargs):
# ``IterativeDecodeHead`` would skip initialization of
# ``BaseDecodeHead`` which would be called when building
# ``self.kernel_generate_head``.
super(BaseDecodeHead, self).__init__(**kwargs)
assert num_stages == len(kernel_update_head)
self.num_stages = num_stages
@ -422,6 +425,7 @@ class IterativeDecodeHead(BaseDecodeHead):
self.num_classes = self.kernel_generate_head.num_classes
self.input_transform = self.kernel_generate_head.input_transform
self.ignore_index = self.kernel_generate_head.ignore_index
self.out_channels = self.num_classes
for head_cfg in kernel_update_head:
self.kernel_update_head.append(MODELS.build(head_cfg))