diff --git a/mmseg/models/decode_heads/knet_head.py b/mmseg/models/decode_heads/knet_head.py index f73daccb6..78a270277 100644 --- a/mmseg/models/decode_heads/knet_head.py +++ b/mmseg/models/decode_heads/knet_head.py @@ -411,6 +411,9 @@ class IterativeDecodeHead(BaseDecodeHead): def __init__(self, num_stages, kernel_generate_head, kernel_update_head, **kwargs): + # ``IterativeDecodeHead`` would skip initialization of + # ``BaseDecodeHead`` which would be called when building + # ``self.kernel_generate_head``. super(BaseDecodeHead, self).__init__(**kwargs) assert num_stages == len(kernel_update_head) self.num_stages = num_stages @@ -420,6 +423,7 @@ class IterativeDecodeHead(BaseDecodeHead): self.num_classes = self.kernel_generate_head.num_classes self.input_transform = self.kernel_generate_head.input_transform self.ignore_index = self.kernel_generate_head.ignore_index + self.out_channels = self.num_classes for head_cfg in kernel_update_head: self.kernel_update_head.append(build_head(head_cfg))