diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index 9adf65bd0..2284906e3 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -265,6 +265,8 @@ class EncoderDecoder(BaseSegmentor):
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
         if torch.onnx.is_in_onnx_export():
+            # our inference backend only support 4D output
+            seg_pred = seg_pred.unsqueeze(0)
             return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim