add support for 4D output (#150)
parent
7baed6513e
commit
588a2c036a
|
@ -265,6 +265,8 @@ class EncoderDecoder(BaseSegmentor):
|
|||
seg_logit = self.inference(img, img_meta, rescale)
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
# our inference backend only support 4D output
|
||||
seg_pred = seg_pred.unsqueeze(0)
|
||||
return seg_pred
|
||||
seg_pred = seg_pred.cpu().numpy()
|
||||
# unravel batch dim
|
||||
|
|
Loading…
Reference in New Issue