mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add support for 4D output (#150)
This commit is contained in:
parent
a2738fd9be
commit
08f30ea497
@ -265,6 +265,8 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
seg_logit = self.inference(img, img_meta, rescale)
|
seg_logit = self.inference(img, img_meta, rescale)
|
||||||
seg_pred = seg_logit.argmax(dim=1)
|
seg_pred = seg_logit.argmax(dim=1)
|
||||||
if torch.onnx.is_in_onnx_export():
|
if torch.onnx.is_in_onnx_export():
|
||||||
|
# our inference backend only support 4D output
|
||||||
|
seg_pred = seg_pred.unsqueeze(0)
|
||||||
return seg_pred
|
return seg_pred
|
||||||
seg_pred = seg_pred.cpu().numpy()
|
seg_pred = seg_pred.cpu().numpy()
|
||||||
# unravel batch dim
|
# unravel batch dim
|
||||||
|
Loading…
x
Reference in New Issue
Block a user