add support for 4D output (#150)

pull/1801/head
robin Han 2020-09-23 17:01:20 +08:00 committed by GitHub
parent 7baed6513e
commit 588a2c036a
1 changed files with 2 additions and 0 deletions

View File

@ -265,6 +265,8 @@ class EncoderDecoder(BaseSegmentor):
seg_logit = self.inference(img, img_meta, rescale)
seg_pred = seg_logit.argmax(dim=1)
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
seg_pred = seg_pred.unsqueeze(0)
return seg_pred
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim