[Feat]: Support simcc from mmpose (#1187)
* add rewriting for simcc * add simccdecode for sdk * remove debug lines * fix cpp lint * move simcc decode to sdk * add simcc sdk config * update docs and regression yaml * update ymlpull/1278/head
parent
c4285aed4a
commit
0efc9e3c6d
|
@ -0,0 +1,3 @@
|
|||
_base_ = ['./pose-detection_static.py', '../_base_/backends/ncnn.py']
|
||||
|
||||
onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])
|
|
@ -0,0 +1,13 @@
|
|||
_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
|
||||
|
||||
onnx_config = dict(
|
||||
input_shape=[192, 256],
|
||||
output_names=['simcc_x', 'simcc_y'],
|
||||
dynamic_axes={
|
||||
'input': {
|
||||
0: 'batch',
|
||||
},
|
||||
'output': {
|
||||
0: 'batch'
|
||||
}
|
||||
})
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = ['./pose-detection_static.py', '../_base_/backends/sdk.py']
|
||||
|
||||
codebase_config = dict(model_type='sdk')
|
||||
onnx_config = dict(output_names=['simcc_x', 'simcc_y'])
|
||||
|
||||
backend_config = dict(pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='PackPoseInputs')
|
||||
])
|
||||
|
||||
ext_info = dict(image_size=[192, 256], padding=1.25)
|
|
@ -0,0 +1,24 @@
|
|||
_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt.py']
|
||||
|
||||
onnx_config = dict(
|
||||
input_shape=[192, 256],
|
||||
output_names=['simcc_x', 'simcc_y'],
|
||||
dynamic_axes={
|
||||
'input': {
|
||||
0: 'batch',
|
||||
},
|
||||
'output': {
|
||||
0: 'batch'
|
||||
}
|
||||
})
|
||||
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 192],
|
||||
opt_shape=[2, 3, 256, 192],
|
||||
max_shape=[4, 3, 256, 192])))
|
||||
])
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <cctype>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
|
||||
#include "mmdeploy/core/device.h"
|
||||
#include "mmdeploy/core/registry.h"
|
||||
#include "mmdeploy/core/serialization.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "mmdeploy/core/value.h"
|
||||
#include "mmdeploy/experimental/module_adapter.h"
|
||||
#include "mmpose.h"
|
||||
#include "opencv_utils.h"
|
||||
|
||||
namespace mmdeploy::mmpose {
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
template <class F>
|
||||
struct _LoopBody : public cv::ParallelLoopBody {
|
||||
F f_;
|
||||
_LoopBody(F f) : f_(std::move(f)) {}
|
||||
void operator()(const cv::Range& range) const override { f_(range); }
|
||||
};
|
||||
|
||||
class SimCCLabelDecode : public MMPose {
|
||||
public:
|
||||
explicit SimCCLabelDecode(const Value& config) : MMPose(config) {
|
||||
if (config.contains("params")) {
|
||||
auto& params = config["params"];
|
||||
flip_test_ = params.value("flip_test", flip_test_);
|
||||
simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_);
|
||||
if (params.contains("input_size")) {
|
||||
from_value(params["input_size"], input_size_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Result<Value> operator()(const Value& _data, const Value& _prob) {
|
||||
MMDEPLOY_DEBUG("preprocess_result: {}", _data);
|
||||
MMDEPLOY_DEBUG("inference_result: {}", _prob);
|
||||
|
||||
Device cpu_device{"cpu"};
|
||||
OUTCOME_TRY(auto simcc_x,
|
||||
MakeAvailableOnDevice(_prob["simcc_x"].get<Tensor>(), cpu_device, stream()));
|
||||
OUTCOME_TRY(auto simcc_y,
|
||||
MakeAvailableOnDevice(_prob["simcc_y"].get<Tensor>(), cpu_device, stream()));
|
||||
OUTCOME_TRY(stream().Wait());
|
||||
if (!(simcc_x.shape().size() == 3 && simcc_x.data_type() == DataType::kFLOAT)) {
|
||||
MMDEPLOY_ERROR("unsupported `simcc_x` tensor, shape: {}, dtype: {}", simcc_x.shape(),
|
||||
(int)simcc_x.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
auto& img_metas = _data["img_metas"];
|
||||
|
||||
Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
|
||||
Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
|
||||
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
|
||||
|
||||
std::vector<float> center;
|
||||
std::vector<float> scale;
|
||||
from_value(img_metas["center"], center);
|
||||
from_value(img_metas["scale"], scale);
|
||||
PoseDetectorOutput output;
|
||||
|
||||
float* keypoints_data = keypoints.data<float>();
|
||||
float* scores_data = scores.data<float>();
|
||||
float scale_value = 200, x = -1, y = -1, s = 0;
|
||||
for (int i = 0; i < simcc_x.shape(1); i++) {
|
||||
x = *(keypoints_data + 0) / simcc_split_ratio_;
|
||||
y = *(keypoints_data + 1) / simcc_split_ratio_;
|
||||
x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5;
|
||||
y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;
|
||||
s = *(scores_data + 0);
|
||||
output.key_points.push_back({{x, y}, s});
|
||||
keypoints_data += 2;
|
||||
scores_data += 1;
|
||||
}
|
||||
return to_value(output);
|
||||
}
|
||||
|
||||
void get_simcc_maximum(const Tensor& simcc_x, const Tensor& simcc_y, Tensor& keypoints,
|
||||
Tensor& scores) {
|
||||
int K = simcc_x.shape(1);
|
||||
int N_x = simcc_x.shape(2);
|
||||
int N_y = simcc_y.shape(2);
|
||||
|
||||
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
|
||||
for (int i = r.start; i < r.end; i++) {
|
||||
float* data_x = const_cast<float*>(simcc_x.data<float>()) + i * N_x;
|
||||
float* data_y = const_cast<float*>(simcc_y.data<float>()) + i * N_y;
|
||||
cv::Mat mat_x = cv::Mat(N_x, 1, CV_32FC1, data_x);
|
||||
cv::Mat mat_y = cv::Mat(N_y, 1, CV_32FC1, data_y);
|
||||
double min_val_x, max_val_x, min_val_y, max_val_y;
|
||||
cv::Point min_loc_x, max_loc_x, min_loc_y, max_loc_y;
|
||||
cv::minMaxLoc(mat_x, &min_val_x, &max_val_x, &min_loc_x, &max_loc_x);
|
||||
cv::minMaxLoc(mat_y, &min_val_y, &max_val_y, &min_loc_y, &max_loc_y);
|
||||
float s = max_val_x > max_val_y ? max_val_y : max_val_x;
|
||||
float x = s > 0 ? max_loc_x.y : -1.0;
|
||||
float y = s > 0 ? max_loc_y.y : -1.0;
|
||||
float* keypoints_data = keypoints.data<float>() + i * 2;
|
||||
float* scores_data = scores.data<float>() + i;
|
||||
*(scores_data) = s;
|
||||
*(keypoints_data + 0) = x;
|
||||
*(keypoints_data + 1) = y;
|
||||
}
|
||||
}});
|
||||
}
|
||||
|
||||
private:
|
||||
bool flip_test_{false};
|
||||
bool shift_heatmap_{false};
|
||||
float simcc_split_ratio_{2.0};
|
||||
std::vector<int> input_size_{192, 256};
|
||||
};
|
||||
|
||||
REGISTER_CODEBASE_COMPONENT(MMPose, SimCCLabelDecode);
|
||||
|
||||
} // namespace mmdeploy::mmpose
|
|
@ -1710,6 +1710,27 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
|
|||
<td align="center">-</td>
|
||||
<td align="center">0.774</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmpose/blob/1.x/configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py">SimCC</a></td>
|
||||
<td align="center" rowspan="2">Pose Detection</td>
|
||||
<td align="center" rowspan="2">COCO</td>
|
||||
<td align="center">AP</td>
|
||||
<td align="center">0.607</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">0.608</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">AR</td>
|
||||
<td align="center">0.668</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">0.672</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
|
|
@ -74,6 +74,7 @@ The table below lists the models that are guaranteed to be exportable to other b
|
|||
| MSPN | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
|
||||
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
|
||||
| Hourglass | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hourglass-eccv-2016) |
|
||||
| SimCC | MMPose | N | Y | Y | Y | N | N | N | N | [config](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) |
|
||||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
# MMPose Deployment
|
||||
|
||||
- [Installation](#installation)
|
||||
- [Install mmpose](#install-mmpose)
|
||||
- [Install mmdeploy](#install-mmdeploy)
|
||||
- [Convert model](#convert-model)
|
||||
- [Model specification](#model-specification)
|
||||
- [Model inference](#model-inference)
|
||||
- [Backend model inference](#backend-model-inference)
|
||||
- [SDK model inference](#sdk-model-inference)
|
||||
- [Supported models](#supported-models)
|
||||
- [MMPose Deployment](#mmpose-deployment)
|
||||
- [Installation](#installation)
|
||||
- [Install mmpose](#install-mmpose)
|
||||
- [Install mmdeploy](#install-mmdeploy)
|
||||
- [Convert model](#convert-model)
|
||||
- [Model specification](#model-specification)
|
||||
- [Model inference](#model-inference)
|
||||
- [Backend model inference](#backend-model-inference)
|
||||
- [SDK model inference](#sdk-model-inference)
|
||||
- [Supported models](#supported-models)
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
|
@ -151,8 +152,10 @@ TODO
|
|||
|
||||
## Supported models
|
||||
|
||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
|
||||
| :---------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
|
||||
| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
|
||||
| :----------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
|
||||
| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [Hourglass](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [SimCC](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | N |
|
||||
|
|
|
@ -1686,6 +1686,27 @@ GPU: ncnn, TensorRT, PPLNN
|
|||
<td align="center">-</td>
|
||||
<td align="center">0.774</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmpose/blob/1.x/configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py">SimCC</a></td>
|
||||
<td align="center" rowspan="2">Pose Detection</td>
|
||||
<td align="center" rowspan="2">COCO</td>
|
||||
<td align="center">AP</td>
|
||||
<td align="center">0.607</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">0.608</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">AR</td>
|
||||
<td align="center">0.668</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">0.672</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
|
|
@ -74,6 +74,7 @@
|
|||
| MSPN | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
|
||||
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
|
||||
| Hourglass | MMPose | N | Y | Y | Y | N | Y | N | N | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hourglass-eccv-2016) |
|
||||
| SimCC | MMPose | N | Y | Y | Y | N | N | N | N | [config](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) |
|
||||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
# MMPose 模型部署
|
||||
|
||||
- [安装](#安装)
|
||||
- [安装 mmcls](#安装-mmpose)
|
||||
- [安装 mmdeploy](#安装-mmdeploy)
|
||||
- [模型转换](#模型转换)
|
||||
- [模型规范](#模型规范)
|
||||
- [模型推理](#模型推理)
|
||||
- [后端模型推理](#后端模型推理)
|
||||
- [SDK 模型推理](#sdk-模型推理)
|
||||
- [模型支持列表](#模型支持列表)
|
||||
- [MMPose 模型部署](#mmpose-模型部署)
|
||||
- [安装](#安装)
|
||||
- [安装 mmpose](#安装-mmpose)
|
||||
- [安装 mmdeploy](#安装-mmdeploy)
|
||||
- [模型转换](#模型转换)
|
||||
- [模型规范](#模型规范)
|
||||
- [模型推理](#模型推理)
|
||||
- [后端模型推理](#后端模型推理)
|
||||
- [SDK 模型推理](#sdk-模型推理)
|
||||
- [模型支持列表](#模型支持列表)
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
|
@ -155,8 +156,10 @@ task_processor.visualize(
|
|||
|
||||
## 模型支持列表
|
||||
|
||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
|
||||
| :---------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
|
||||
| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO |
|
||||
| :----------------------------------------------------------------------------------------------------- | :------------ | :----------: | :------: | :--: | :---: | :------: |
|
||||
| [HRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#hrnet-cvpr-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [MSPN](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#mspn-arxiv-2019) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [LiteHRNet](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/backbones.html#litehrnet-cvpr-2021) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [Hourglass](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
|
||||
| [SimCC](https://mmpose.readthedocs.io/en/1.x/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | N |
|
||||
|
|
|
@ -333,18 +333,25 @@ class PoseDetection(BaseTask):
|
|||
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get the postprocess information for SDK."""
|
||||
codec = self.model_cfg.codec
|
||||
if isinstance(codec, (list, tuple)):
|
||||
codec = codec[-1]
|
||||
component = 'UNKNOWN'
|
||||
params = copy.deepcopy(self.model_cfg.model.test_cfg)
|
||||
params.update(codec)
|
||||
if self.model_cfg.model.type == 'TopdownPoseEstimator':
|
||||
head_type = self.model_cfg.model.head.type
|
||||
if head_type == 'HeatmapHead':
|
||||
component = 'TopdownHeatmapSimpleHeadDecode'
|
||||
if codec.type == 'MSRAHeatmap':
|
||||
params['post_process'] = 'default'
|
||||
component = 'TopdownHeatmapSimpleHeadDecode'
|
||||
elif head_type == 'MSPNHead':
|
||||
elif codec.type == 'UDPHeatmap':
|
||||
params['post_process'] = 'default'
|
||||
params['use_udp'] = True
|
||||
elif codec.type == 'MegviiHeatmap':
|
||||
params['post_process'] = 'megvii'
|
||||
params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
|
||||
component = 'TopdownHeatmapMSMUHeadDecode'
|
||||
elif codec.type == 'SimCCLabel':
|
||||
component = 'SimCCLabelDecode'
|
||||
else:
|
||||
raise RuntimeError(f'Unsupported head type: {head_type}')
|
||||
raise RuntimeError(f'Unsupported codecs type: {codec.type}')
|
||||
postprocess = dict(params=params, type=component)
|
||||
return postprocess
|
||||
|
|
|
@ -101,24 +101,14 @@ class End2EndModel(BaseBackendModel):
|
|||
inputs = inputs.contiguous().to(self.device)
|
||||
batch_outputs = self.wrapper({self.input_name: inputs})
|
||||
batch_outputs = self.wrapper.output_to_list(batch_outputs)
|
||||
batch_heatmaps = batch_outputs[0]
|
||||
# flip test
|
||||
test_cfg = self.model_cfg.model.test_cfg
|
||||
if test_cfg.get('flip_test', False):
|
||||
from mmpose.models.utils.tta import flip_heatmaps
|
||||
batch_inputs_flip = inputs.flip(-1).contiguous()
|
||||
batch_outputs_flip = self.wrapper(
|
||||
{self.input_name: batch_inputs_flip})
|
||||
batch_heatmaps_flip = self.wrapper.output_to_list(
|
||||
batch_outputs_flip)[0]
|
||||
flip_indices = data_samples[0].metainfo['flip_indices']
|
||||
batch_heatmaps_flip = flip_heatmaps(
|
||||
batch_heatmaps_flip,
|
||||
flip_mode=test_cfg.get('flip_mode', 'heatmap'),
|
||||
flip_indices=flip_indices,
|
||||
shift_heatmap=test_cfg.get('shift_heatmap', False))
|
||||
batch_heatmaps = (batch_heatmaps + batch_heatmaps_flip) * 0.5
|
||||
preds = self.head.decode(batch_heatmaps)
|
||||
codec = self.model_cfg.codec
|
||||
if isinstance(codec, (list, tuple)):
|
||||
codec = codec[-1]
|
||||
if codec.type == 'SimCCLabel':
|
||||
batch_pred_x, batch_pred_y = batch_outputs
|
||||
preds = self.head.decode((batch_pred_x, batch_pred_y))
|
||||
else:
|
||||
preds = self.head.decode(batch_outputs[0])
|
||||
results = self.pack_result(preds, data_samples)
|
||||
return results
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import heatmap_head, mspn_head, regression_head
|
||||
from . import heatmap_head, mspn_head, regression_head, simcc_head
|
||||
|
||||
__all__ = ['heatmap_head', 'mspn_head', 'regression_head']
|
||||
__all__ = ['heatmap_head', 'mspn_head', 'regression_head', 'simcc_head']
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.heatmap_heads.SimCCHead.predict')
|
||||
def simcc_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
|
||||
"""Rewrite `predict` of HeatmapHead for default backend.
|
||||
|
||||
1. skip decoding and return output tensor directly.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): Input features.
|
||||
batch_data_samples (list[SampleList]): Data samples contain
|
||||
image meta information.
|
||||
test_cfg (ConfigType): test config.
|
||||
|
||||
Returns:
|
||||
output_heatmap (torch.Tensor): Output heatmaps.
|
||||
"""
|
||||
simcc_x, simcc_y = self.forward(feats)
|
||||
return simcc_x, simcc_y
|
|
@ -5,6 +5,10 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
from mmdeploy.utils import Backend
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.flatten', backend=Backend.NCNN.value)
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.flatten', backend=Backend.NCNN.value)
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.flatten', backend=Backend.COREML.value)
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
|
|
|
@ -124,3 +124,18 @@ models:
|
|||
# - *pipeline_trt_static_fp32_256x256
|
||||
# - *pipeline_ncnn_static_fp32_256x256
|
||||
# - *pipeline_openvino_static_fp32_256x256
|
||||
|
||||
# TODO: no mobilenetv2_coco.yml in latest mmpose, enable this later
|
||||
# - name: SimCC
|
||||
# metafile: configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
|
||||
# model_configs:
|
||||
# - configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py
|
||||
# pipelines:
|
||||
# - convert_image: *convert_image
|
||||
# deploy_config: configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
|
||||
# - convert_image: *convert_image
|
||||
# deploy_config: configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
|
||||
# backend_test: *default_backend_test
|
||||
# sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
|
||||
# - convert_image: *convert_image
|
||||
# deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py
|
||||
|
|
Loading…
Reference in New Issue