diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 9adf65bd0..2284906e3 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -265,6 +265,8 @@ class EncoderDecoder(BaseSegmentor): seg_logit = self.inference(img, img_meta, rescale) seg_pred = seg_logit.argmax(dim=1) if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) return seg_pred seg_pred = seg_pred.cpu().numpy() # unravel batch dim