From ebd6b75f3b918e96b4a376397365959710565c03 Mon Sep 17 00:00:00 2001 From: RunningLeon <mnsheng@yeah.net> Date: Wed, 28 Jun 2023 16:15:01 +0800 Subject: [PATCH] Fix torch2onnx for pointpillars with multi-level outputs (#2210) * temp fix * fix * update --- .../voxel-detection/voxel-detection_static.py | 3 +- .../mmdet3d/deploy/voxel_detection_model.py | 15 +++++--- .../codebase/mmdet3d/models/mvx_two_stage.py | 35 ++++++++++++++++--- 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_static.py b/configs/mmdet3d/voxel-detection/voxel-detection_static.py index bba7e819f..639787ec6 100644 --- a/configs/mmdet3d/voxel-detection/voxel-detection_static.py +++ b/configs/mmdet3d/voxel-detection/voxel-detection_static.py @@ -3,4 +3,5 @@ codebase_config = dict( type='mmdet3d', task='VoxelDetection', model_type='end2end') onnx_config = dict( input_names=['voxels', 'num_points', 'coors'], - output_names=['cls_score', 'bbox_pred', 'dir_cls_pred']) + # need to change output_names for head with multi-level features + output_names=['cls_score0', 'bbox_pred0', 'dir_cls_pred0']) diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py index 949f2902b..982085650 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py +++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py @@ -90,7 +90,14 @@ class VoxelDetectionModel(BaseBackendModel): } outputs = self.wrapper(input_dict) - + num_level = len(outputs) // 3 + new_outputs = dict( + cls_score=[outputs[f'cls_score{i}'] for i in range(num_level)], + bbox_pred=[outputs[f'bbox_pred{i}'] for i in range(num_level)], + dir_cls_pred=[ + outputs[f'dir_cls_pred{i}'] for i in range(num_level) + ]) + outputs = new_outputs if data_samples is None: return outputs @@ -239,9 +246,9 @@ class VoxelDetectionModel(BaseBackendModel): if not hasattr(head, 'task_heads'): data_instances_3d = head.predict_by_feat( - cls_scores=[cls_score], - bbox_preds=[bbox_pred], - dir_cls_preds=[dir_cls_pred], + cls_scores=cls_score, + bbox_preds=bbox_pred, + dir_cls_preds=dir_cls_pred, batch_input_metas=batch_input_metas, cfg=cfg) diff --git a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py index 12df74ff5..b9d62e684 100644 --- a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py +++ b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py @@ -2,6 +2,7 @@ import torch from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import get_ir_config @FUNCTION_REWRITER.register_rewriter( @@ -52,10 +53,21 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs): inputs (list): input list comprises voxels, num_points and coors Returns: - bbox (Tensor): Decoded bbox after nms - scores (Tensor): bbox scores - labels (Tensor): bbox labels + tuple: A tuple of classification scores, bbox and direction + classification prediction. + + - cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, the channels number + is num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, the channels number + is num_base_priors * C. + - dir_cls_preds (list[Tensor|None]): Direction classification + predictions for all scale levels, each is a 4D-tensor, + the channels number is num_base_priors * 2. """ + ctx = FUNCTION_REWRITER.get_context() + deploy_cfg = ctx.cfg batch_inputs_dict = { 'voxels': { 'voxels': inputs[0], @@ -82,5 +94,18 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs): dir_scores = torch.cat(dir_scores, dim=1) return scores, bbox_preds, dir_scores else: - cls_score, bbox_pred, dir_cls_pred = outs[0][0], outs[1][0], outs[2][0] - return cls_score, bbox_pred, dir_cls_pred + preds = [] + expect_names = [] + for i in range(len(outs[0])): + preds += [outs[0][i], outs[1][i], outs[2][i]] + expect_names += [ + f'cls_score{i}', f'bbox_pred{i}', f'dir_cls_pred{i}' + ] + # check if output_names is set correctly. + onnx_cfg = get_ir_config(deploy_cfg) + output_names = onnx_cfg['output_names'] + if output_names != list(expect_names): + raise RuntimeError(f'`output_names` should be {expect_names} ' + f'but given {output_names}\n' + f'Deploy config:\n{deploy_cfg.pretty_text}') + return tuple(preds)