From 588a2c036a4254a1c06e02e61ae2bc93c33cd21c Mon Sep 17 00:00:00 2001
From: robin Han <drcut@users.noreply.github.com>
Date: Wed, 23 Sep 2020 17:01:20 +0800
Subject: [PATCH] add support for 4D output (#150)

---
 mmseg/models/segmentors/encoder_decoder.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index 9adf65bd0..2284906e3 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -265,6 +265,8 @@ class EncoderDecoder(BaseSegmentor):
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
         if torch.onnx.is_in_onnx_export():
+            # our inference backend only support 4D output
+            seg_pred = seg_pred.unsqueeze(0)
             return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim