[Fix] Input previous results for the last cascade_decode_head (#1450)
* [Fix] Input previous results for the latter cascade_decode_head * minorspull/1477/head
parent
3f797072d8
commit
23ae1ebab6
|
@ -75,8 +75,12 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
||||||
|
|
||||||
for i in range(1, self.num_stages):
|
for i in range(1, self.num_stages):
|
||||||
# forward test again, maybe unnecessary for most methods.
|
# forward test again, maybe unnecessary for most methods.
|
||||||
prev_outputs = self.decode_head[i - 1].forward_test(
|
if i == 1:
|
||||||
x, img_metas, self.test_cfg)
|
prev_outputs = self.decode_head[0].forward_test(
|
||||||
|
x, img_metas, self.test_cfg)
|
||||||
|
else:
|
||||||
|
prev_outputs = self.decode_head[i - 1].forward_test(
|
||||||
|
x, prev_outputs, img_metas, self.test_cfg)
|
||||||
loss_decode = self.decode_head[i].forward_train(
|
loss_decode = self.decode_head[i].forward_train(
|
||||||
x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
|
x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
|
||||||
losses.update(add_prefix(loss_decode, f'decode_{i}'))
|
losses.update(add_prefix(loss_decode, f'decode_{i}'))
|
||||||
|
|
Loading…
Reference in New Issue