From ebd6b75f3b918e96b4a376397365959710565c03 Mon Sep 17 00:00:00 2001
From: RunningLeon <mnsheng@yeah.net>
Date: Wed, 28 Jun 2023 16:15:01 +0800
Subject: [PATCH] Fix torch2onnx for pointpillars with multi-level outputs
 (#2210)

* temp fix

* fix

* update
---
 .../voxel-detection/voxel-detection_static.py |  3 +-
 .../mmdet3d/deploy/voxel_detection_model.py   | 15 +++++---
 .../codebase/mmdet3d/models/mvx_two_stage.py  | 35 ++++++++++++++++---
 3 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_static.py b/configs/mmdet3d/voxel-detection/voxel-detection_static.py
index bba7e819f..639787ec6 100644
--- a/configs/mmdet3d/voxel-detection/voxel-detection_static.py
+++ b/configs/mmdet3d/voxel-detection/voxel-detection_static.py
@@ -3,4 +3,5 @@ codebase_config = dict(
     type='mmdet3d', task='VoxelDetection', model_type='end2end')
 onnx_config = dict(
     input_names=['voxels', 'num_points', 'coors'],
-    output_names=['cls_score', 'bbox_pred', 'dir_cls_pred'])
+    # for multi-level heads, append a name triple per level (suffix 1, 2, ...)
+    output_names=['cls_score0', 'bbox_pred0', 'dir_cls_pred0'])
diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
index 949f2902b..982085650 100644
--- a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
+++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
@@ -90,7 +90,14 @@ class VoxelDetectionModel(BaseBackendModel):
         }
 
         outputs = self.wrapper(input_dict)
-
+        num_level = len(outputs) // 3
+        new_outputs = dict(
+            cls_score=[outputs[f'cls_score{i}'] for i in range(num_level)],
+            bbox_pred=[outputs[f'bbox_pred{i}'] for i in range(num_level)],
+            dir_cls_pred=[
+                outputs[f'dir_cls_pred{i}'] for i in range(num_level)
+            ])
+        outputs = new_outputs
         if data_samples is None:
             return outputs
 
@@ -239,9 +246,9 @@ class VoxelDetectionModel(BaseBackendModel):
 
         if not hasattr(head, 'task_heads'):
             data_instances_3d = head.predict_by_feat(
-                cls_scores=[cls_score],
-                bbox_preds=[bbox_pred],
-                dir_cls_preds=[dir_cls_pred],
+                cls_scores=cls_score,
+                bbox_preds=bbox_pred,
+                dir_cls_preds=dir_cls_pred,
                 batch_input_metas=batch_input_metas,
                 cfg=cfg)
 
diff --git a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py
index 12df74ff5..b9d62e684 100644
--- a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py
+++ b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py
@@ -2,6 +2,7 @@
 import torch
 
 from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import get_ir_config
 
 
 @FUNCTION_REWRITER.register_rewriter(
@@ -52,10 +53,21 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
         inputs (list): input list comprises voxels, num_points and coors
 
     Returns:
-        bbox (Tensor): Decoded bbox after nms
-        scores (Tensor): bbox scores
-        labels (Tensor): bbox labels
+        tuple: A tuple of classification scores, bbox and direction
+            classification prediction.
+
+            - cls_scores (list[Tensor]): Classification scores for all
+                scale levels, each is a 4D-tensor, the channels number
+                is num_base_priors * num_classes.
+            - bbox_preds (list[Tensor]): Box energies / deltas for all
+                scale levels, each is a 4D-tensor, the channels number
+                is num_base_priors * C.
+            - dir_cls_preds (list[Tensor|None]): Direction classification
+                predictions for all scale levels, each is a 4D-tensor,
+                the channels number is num_base_priors * 2.
     """
+    ctx = FUNCTION_REWRITER.get_context()
+    deploy_cfg = ctx.cfg
     batch_inputs_dict = {
         'voxels': {
             'voxels': inputs[0],
@@ -82,5 +94,18 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
         dir_scores = torch.cat(dir_scores, dim=1)
         return scores, bbox_preds, dir_scores
     else:
-        cls_score, bbox_pred, dir_cls_pred = outs[0][0], outs[1][0], outs[2][0]
-        return cls_score, bbox_pred, dir_cls_pred
+        preds = []
+        expect_names = []
+        for i in range(len(outs[0])):
+            preds += [outs[0][i], outs[1][i], outs[2][i]]
+            expect_names += [
+                f'cls_score{i}', f'bbox_pred{i}', f'dir_cls_pred{i}'
+            ]
+        # check if output_names is set correctly.
+        onnx_cfg = get_ir_config(deploy_cfg)
+        output_names = onnx_cfg['output_names']
+        if output_names != list(expect_names):
+            raise RuntimeError(f'`output_names` should be {expect_names} '
+                               f'but given {output_names}\n'
+                               f'Deploy config:\n{deploy_cfg.pretty_text}')
+        return tuple(preds)