diff --git a/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmseg/models/segmentors/cascade_encoder_decoder.py index 7f9f9006c..1913a22e2 100644 --- a/mmseg/models/segmentors/cascade_encoder_decoder.py +++ b/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -75,8 +75,12 @@ class CascadeEncoderDecoder(EncoderDecoder): for i in range(1, self.num_stages): # forward test again, maybe unnecessary for most methods. - prev_outputs = self.decode_head[i - 1].forward_test( - x, img_metas, self.test_cfg) + if i == 1: + prev_outputs = self.decode_head[0].forward_test( + x, img_metas, self.test_cfg) + else: + prev_outputs = self.decode_head[i - 1].forward_test( + x, prev_outputs, img_metas, self.test_cfg) loss_decode = self.decode_head[i].forward_train( x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) losses.update(add_prefix(loss_decode, f'decode_{i}'))