diff --git a/configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py b/configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py
new file mode 100644
index 000000000..ebbc0a5fa
--- /dev/null
+++ b/configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py
@@ -0,0 +1,3 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/ncnn.py']
+
+onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])
diff --git a/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py b/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
new file mode 100644
index 000000000..0bb991394
--- /dev/null
+++ b/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
@@ -0,0 +1,13 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
+
+onnx_config = dict(
+    input_shape=[192, 256],
+    output_names=['simcc_x', 'simcc_y'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'output': {
+            0: 'batch'
+        }
+    })
diff --git a/configs/mmpose/pose-detection_simcc_sdk_static-256x192.py b/configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
new file mode 100644
index 000000000..460d8456a
--- /dev/null
+++ b/configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
@@ -0,0 +1,12 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/sdk.py']
+
+codebase_config = dict(model_type='sdk')
+onnx_config = dict(output_names=['simcc_x', 'simcc_y'])
+
+backend_config = dict(pipeline=[
+    dict(type='LoadImageFromFile'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='PackPoseInputs')
+])
+
+ext_info = dict(image_size=[192, 256], padding=1.25)
diff --git a/configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py b/configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
new file mode 100644
index 000000000..5d0544ff9
--- /dev/null
+++ b/configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
@@ -0,0 +1,24 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt.py']
+
+onnx_config = dict(
+    input_shape=[192, 256],
+    output_names=['simcc_x', 'simcc_y'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'output': {
+            0: 'batch'
+        }
+    })
+
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 256, 192],
+                    opt_shape=[2, 3, 256, 192],
+                    max_shape=[4, 3, 256, 192])))
+    ])
diff --git a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp
new file mode 100644
index 000000000..ce38c056d
--- /dev/null
+++ b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp
@@ -0,0 +1,123 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include <cctype>
+#include <opencv2/imgproc.hpp>
+
+#include "mmdeploy/core/device.h"
+#include "mmdeploy/core/registry.h"
+#include "mmdeploy/core/serialization.h"
+#include "mmdeploy/core/tensor.h"
+#include "mmdeploy/core/utils/device_utils.h"
+#include "mmdeploy/core/utils/formatter.h"
+#include "mmdeploy/core/value.h"
+#include "mmdeploy/experimental/module_adapter.h"
+#include "mmpose.h"
+#include "opencv_utils.h"
+
+namespace mmdeploy::mmpose {
+
+using std::string;
+using std::vector;
+
+template <class F>
+struct _LoopBody : public cv::ParallelLoopBody {
+  F f_;
+  _LoopBody(F f) : f_(std::move(f)) {}
+  void operator()(const cv::Range& range) const override { f_(range); }
+};
+
+class SimCCLabelDecode : public MMPose {
+ public:
+  explicit SimCCLabelDecode(const Value& config) : MMPose(config) {
+    if (config.contains("params")) {
+      auto& params = config["params"];
+      flip_test_ = params.value("flip_test", flip_test_);
+      simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_);
+      if (params.contains("input_size")) {
+        from_value(params["input_size"], input_size_);
+      }
+    }
+  }
+
+  Result<Value> operator()(const Value& _data, const Value& _prob) {
+    MMDEPLOY_DEBUG("preprocess_result: {}", _data);
+    MMDEPLOY_DEBUG("inference_result: {}", _prob);
+
+    Device cpu_device{"cpu"};
+    OUTCOME_TRY(auto simcc_x,
+                MakeAvailableOnDevice(_prob["simcc_x"].get<Tensor>(), cpu_device, stream()));
+    OUTCOME_TRY(auto simcc_y,
+                MakeAvailableOnDevice(_prob["simcc_y"].get<Tensor>(), cpu_device, stream()));
+    OUTCOME_TRY(stream().Wait());
+    if (!(simcc_x.shape().size() == 3 && simcc_x.data_type() == DataType::kFLOAT)) {
+      MMDEPLOY_ERROR("unsupported `simcc_x` tensor, shape: {}, dtype: {}", simcc_x.shape(),
+                     (int)simcc_x.data_type());
+      return Status(eNotSupported);
+    }
+
+    auto& img_metas = _data["img_metas"];
+
+    Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
+    Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
+    get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
+
+    std::vector<float> center;
+    std::vector<float> scale;
+    from_value(img_metas["center"], center);
+    from_value(img_metas["scale"], scale);
+    PoseDetectorOutput output;
+
+    float* keypoints_data = keypoints.data<float>();
+    float* scores_data = scores.data<float>();
+    float scale_value = 200, x = -1, y = -1, s = 0;
+    for (int i = 0; i < simcc_x.shape(1); i++) {
+      x = *(keypoints_data + 0) / simcc_split_ratio_;
+      y = *(keypoints_data + 1) / simcc_split_ratio_;
+      x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5;
+      y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;
+      s = *(scores_data + 0);
+      output.key_points.push_back({{x, y}, s});
+      keypoints_data += 2;
+      scores_data += 1;
+    }
+    return to_value(output);
+  }
+
+  void get_simcc_maximum(const Tensor& simcc_x, const Tensor& simcc_y, Tensor& keypoints,
+                         Tensor& scores) {
+    int K = simcc_x.shape(1);
+    int N_x = simcc_x.shape(2);
+    int N_y = simcc_y.shape(2);
+
+    cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
+                        for (int i = r.start; i < r.end; i++) {
+                          float* data_x = const_cast<float*>(simcc_x.data<float>()) + i * N_x;
+                          float* data_y = const_cast<float*>(simcc_y.data<float>()) + i * N_y;
+                          cv::Mat mat_x = cv::Mat(N_x, 1, CV_32FC1, data_x);
+                          cv::Mat mat_y = cv::Mat(N_y, 1, CV_32FC1, data_y);
+                          double min_val_x, max_val_x, min_val_y, max_val_y;
+                          cv::Point min_loc_x, max_loc_x, min_loc_y, max_loc_y;
+                          cv::minMaxLoc(mat_x, &min_val_x, &max_val_x, &min_loc_x, &max_loc_x);
+                          cv::minMaxLoc(mat_y, &min_val_y, &max_val_y, &min_loc_y, &max_loc_y);
+                          float s = max_val_x > max_val_y ? max_val_y : max_val_x;
+                          float x = s > 0 ? max_loc_x.y : -1.0;
+                          float y = s > 0 ? max_loc_y.y : -1.0;
+                          float* keypoints_data = keypoints.data<float>() + i * 2;
+                          float* scores_data = scores.data<float>() + i;
+                          *(scores_data) = s;
+                          *(keypoints_data + 0) = x;
+                          *(keypoints_data + 1) = y;
+                        }
+                      }});
+  }
+
+ private:
+  bool flip_test_{false};
+  bool shift_heatmap_{false};
+  float simcc_split_ratio_{2.0};
+  std::vector<int> input_size_{192, 256};
+};
+
+REGISTER_CODEBASE_COMPONENT(MMPose, SimCCLabelDecode);
+
+}  // namespace mmdeploy::mmpose
diff --git a/docs/en/03-benchmark/benchmark.md b/docs/en/03-benchmark/benchmark.md
index 56e57d3e8..744e4dbeb 100644
--- a/docs/en/03-benchmark/benchmark.md
+++ b/docs/en/03-benchmark/benchmark.md
@@ -1710,6 +1710,27 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
     <td align="center">-</td>
     <td align="center">0.774</td>
   </tr>
+  <tr>
+    <td align="center" rowspan="2">SimCC</td>
+    <td align="center" rowspan="2">Pose Detection</td>
+    <td align="center" rowspan="2">COCO</td>
+    <td align="center">AP</td>
+    <td align="center">0.607</td>
+    <td align="center">-</td>
+    <td align="center">0.608</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+  </tr>
+  <tr>
+    <td align="center">AR</td>
+    <td align="center">0.668</td>
+    <td align="center">-</td>
+    <td align="center">0.672</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+  </tr>
diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md
index e9e9f374a..9c42ef271 100644
--- a/docs/en/03-benchmark/supported_models.md
+++ b/docs/en/03-benchmark/supported_models.md
@@ -74,6 +74,7 @@ The table below lists the models that are guaranteed to be exportable to other b
 | MSPN | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
 | LiteHRNet | MMPose | N | Y | Y | N | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
 | Hourglass | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hourglass-eccv-2016) |
+| SimCC | MMPose | N | Y | Y | Y | N | N | N | N | [config](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) |
 | PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
 | CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
 | RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
diff --git a/docs/en/04-supported-codebases/mmpose.md b/docs/en/04-supported-codebases/mmpose.md
index 4c910c80e..690ce91f2 100644
--- a/docs/en/04-supported-codebases/mmpose.md
+++ b/docs/en/04-supported-codebases/mmpose.md
@@ -1,14 +1,15 @@
 # MMPose Deployment
 
-- [Installation](#installation)
-  - [Install mmpose](#install-mmpose)
-  - [Install mmdeploy](#install-mmdeploy)
-- [Convert model](#convert-model)
-- [Model specification](#model-specification)
-- [Model inference](#model-inference)
-  - [Backend model inference](#backend-model-inference)
-  - [SDK model inference](#sdk-model-inference)
-- [Supported models](#supported-models)
+- [MMPose Deployment](#mmpose-deployment)
+  - [Installation](#installation)
+    - [Install mmpose](#install-mmpose)
+    - [Install mmdeploy](#install-mmdeploy)
+  - [Convert model](#convert-model)
+  - [Model specification](#model-specification)
+  - [Model inference](#model-inference)
+    - [Backend model inference](#backend-model-inference)
+    - [SDK model inference](#sdk-model-inference)
+  - [Supported models](#supported-models)
 
 ______________________________________________________________________
 
@@ -151,8 +152,10 @@ TODO
 
 ## Supported models
 
-| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :---------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
-| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
-| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
-| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
+| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :----------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
+| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
+| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
+| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
+| [Hourglass](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
+| [SimCC](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | N |
diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md
index 9bf5fac13..63ce9f678 100644
--- a/docs/zh_cn/03-benchmark/benchmark.md
+++ b/docs/zh_cn/03-benchmark/benchmark.md
@@ -1686,6 +1686,27 @@ GPU: ncnn, TensorRT, PPLNN
     <td align="center">-</td>
     <td align="center">0.774</td>
   </tr>
+  <tr>
+    <td align="center" rowspan="2">SimCC</td>
+    <td align="center" rowspan="2">Pose Detection</td>
+    <td align="center" rowspan="2">COCO</td>
+    <td align="center">AP</td>
+    <td align="center">0.607</td>
+    <td align="center">-</td>
+    <td align="center">0.608</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+  </tr>
+  <tr>
+    <td align="center">AR</td>
+    <td align="center">0.668</td>
+    <td align="center">-</td>
+    <td align="center">0.672</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+    <td align="center">-</td>
+  </tr>
diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md
index 13e2c2b14..77e327d4c 100644
--- a/docs/zh_cn/03-benchmark/supported_models.md
+++ b/docs/zh_cn/03-benchmark/supported_models.md
@@ -74,6 +74,7 @@
 | MSPN | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
 | LiteHRNet | MMPose | N | Y | Y | N | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
 | Hourglass | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hourglass-eccv-2016) |
+| SimCC | MMPose | N | Y | Y | Y | N | N | N | N | [config](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) |
 | PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
 | CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
 | RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
diff --git a/docs/zh_cn/04-supported-codebases/mmpose.md b/docs/zh_cn/04-supported-codebases/mmpose.md
index 297b3783f..97c4933da 100644
--- a/docs/zh_cn/04-supported-codebases/mmpose.md
+++ b/docs/zh_cn/04-supported-codebases/mmpose.md
@@ -1,14 +1,15 @@
 # MMPose 模型部署
 
-- [安装](#安装)
-  - [安装 mmcls](#安装-mmpose)
-  - [安装 mmdeploy](#安装-mmdeploy)
-- [模型转换](#模型转换)
-- [模型规范](#模型规范)
-- [模型推理](#模型推理)
-  - [后端模型推理](#后端模型推理)
-  - [SDK 模型推理](#sdk-模型推理)
-- [模型支持列表](#模型支持列表)
+- [MMPose 模型部署](#mmpose-模型部署)
+  - [安装](#安装)
+    - [安装 mmpose](#安装-mmpose)
+    - [安装 mmdeploy](#安装-mmdeploy)
+  - [模型转换](#模型转换)
+  - [模型规范](#模型规范)
+  - [模型推理](#模型推理)
+    - [后端模型推理](#后端模型推理)
+    - [SDK 模型推理](#sdk-模型推理)
+  - [模型支持列表](#模型支持列表)
 
 ______________________________________________________________________
 
@@ -155,8 +156,10 @@ task_processor.visualize(
 
 ## 模型支持列表
 
-| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :---------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
-| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
-| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
-| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
+| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :----------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
+| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
+| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
+| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
+| [Hourglass](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
+| [SimCC](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | N |
diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py
index aeca09d0e..2f48274f1 100644
--- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py
+++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py
@@ -333,18 +333,25 @@ class PoseDetection(BaseTask):
 
     def get_postprocess(self, *args, **kwargs) -> Dict:
         """Get the postprocess information for SDK."""
+        codec = self.model_cfg.codec
+        if isinstance(codec, (list, tuple)):
+            codec = codec[-1]
         component = 'UNKNOWN'
         params = copy.deepcopy(self.model_cfg.model.test_cfg)
+        params.update(codec)
         if self.model_cfg.model.type == 'TopdownPoseEstimator':
-            head_type = self.model_cfg.model.head.type
-            if head_type == 'HeatmapHead':
+            component = 'TopdownHeatmapSimpleHeadDecode'
+            if codec.type == 'MSRAHeatmap':
                 params['post_process'] = 'default'
-                component = 'TopdownHeatmapSimpleHeadDecode'
-            elif head_type == 'MSPNHead':
+            elif codec.type == 'UDPHeatmap':
+                params['post_process'] = 'default'
+                params['use_udp'] = True
+            elif codec.type == 'MegviiHeatmap':
                 params['post_process'] = 'megvii'
                 params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
-                component = 'TopdownHeatmapMSMUHeadDecode'
+            elif codec.type == 'SimCCLabel':
+                component = 'SimCCLabelDecode'
             else:
-                raise RuntimeError(f'Unsupported head type: {head_type}')
+                raise RuntimeError(f'Unsupported codec type: {codec.type}')
         postprocess = dict(params=params, type=component)
         return postprocess
diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
index 1b71b250f..e61e4a627 100644
--- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
+++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
@@ -101,24 +101,14 @@ class End2EndModel(BaseBackendModel):
         inputs = inputs.contiguous().to(self.device)
         batch_outputs = self.wrapper({self.input_name: inputs})
         batch_outputs = self.wrapper.output_to_list(batch_outputs)
-        batch_heatmaps = batch_outputs[0]
-        # flip test
-        test_cfg = self.model_cfg.model.test_cfg
-        if test_cfg.get('flip_test', False):
-            from mmpose.models.utils.tta import flip_heatmaps
-            batch_inputs_flip = inputs.flip(-1).contiguous()
-            batch_outputs_flip = self.wrapper(
-                {self.input_name: batch_inputs_flip})
-            batch_heatmaps_flip = self.wrapper.output_to_list(
-                batch_outputs_flip)[0]
-            flip_indices = data_samples[0].metainfo['flip_indices']
-            batch_heatmaps_flip = flip_heatmaps(
-                batch_heatmaps_flip,
-                flip_mode=test_cfg.get('flip_mode', 'heatmap'),
-                flip_indices=flip_indices,
-                shift_heatmap=test_cfg.get('shift_heatmap', False))
-            batch_heatmaps = (batch_heatmaps + batch_heatmaps_flip) * 0.5
-        preds = self.head.decode(batch_heatmaps)
+        codec = self.model_cfg.codec
+        if isinstance(codec, (list, tuple)):
+            codec = codec[-1]
+        if codec.type == 'SimCCLabel':
+            batch_pred_x, batch_pred_y = batch_outputs
+            preds = self.head.decode((batch_pred_x, batch_pred_y))
+        else:
+            preds = self.head.decode(batch_outputs[0])
         results = self.pack_result(preds, data_samples)
         return results
diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py
index 85e614269..d2894e698 100644
--- a/mmdeploy/codebase/mmpose/models/heads/__init__.py
+++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import heatmap_head, mspn_head, regression_head
+from . import heatmap_head, mspn_head, regression_head, simcc_head
 
-__all__ = ['heatmap_head', 'mspn_head', 'regression_head']
+__all__ = ['heatmap_head', 'mspn_head', 'regression_head', 'simcc_head']
diff --git a/mmdeploy/codebase/mmpose/models/heads/simcc_head.py b/mmdeploy/codebase/mmpose/models/heads/simcc_head.py
new file mode 100644
index 000000000..fbdf9e132
--- /dev/null
+++ b/mmdeploy/codebase/mmpose/models/heads/simcc_head.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmpose.models.heads.heatmap_heads.SimCCHead.predict')
+def simcc_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
+    """Rewrite `predict` of SimCCHead for default backend.
+
+    1. skip decoding and return output tensor directly.
+
+    Args:
+        feats (tuple[Tensor]): Input features.
+        batch_data_samples (list[SampleList]): Data samples contain
+            image meta information.
+        test_cfg (ConfigType): test config.
+
+    Returns:
+        simcc_x, simcc_y (torch.Tensor): The predicted SimCC vectors.
+    """
+    simcc_x, simcc_y = self.forward(feats)
+    return simcc_x, simcc_y
diff --git a/mmdeploy/pytorch/functions/flatten.py b/mmdeploy/pytorch/functions/flatten.py
index b6165aea2..d8d40dd54 100644
--- a/mmdeploy/pytorch/functions/flatten.py
+++ b/mmdeploy/pytorch/functions/flatten.py
@@ -5,6 +5,10 @@ from mmdeploy.core import FUNCTION_REWRITER
 from mmdeploy.utils import Backend
 
 
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.Tensor.flatten', backend=Backend.NCNN.value)
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.flatten', backend=Backend.NCNN.value)
 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.Tensor.flatten', backend=Backend.COREML.value)
 @FUNCTION_REWRITER.register_rewriter(
diff --git a/tests/regression/mmpose.yml b/tests/regression/mmpose.yml
index 16ffb864e..c2c85a36e 100644
--- a/tests/regression/mmpose.yml
+++ b/tests/regression/mmpose.yml
@@ -124,3 +124,18 @@ models:
 #      - *pipeline_trt_static_fp32_256x256
 #      - *pipeline_ncnn_static_fp32_256x256
 #      - *pipeline_openvino_static_fp32_256x256
+
+# TODO: no mobilenetv2_coco.yml in latest mmpose, enable this later
+#  - name: SimCC
+#    metafile: configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
+#    model_configs:
+#      - configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py
+#    pipelines:
+#      - convert_image: *convert_image
+#        deploy_config: configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
+#      - convert_image: *convert_image
+#        deploy_config: configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
+#        backend_test: *default_backend_test
+#        sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
+#      - convert_image: *convert_image
+#        deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py