65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import mmengine
|
|
import torch
|
|
from mmengine.structures import InstanceData, PixelData
|
|
|
|
from mmdeploy.apis import build_task_processor
|
|
from mmdeploy.utils import IR, Backend, Codebase, Task, load_config
|
|
|
|
|
|
def generate_datasample(img_size, heatmap_size=(64, 48)):
    """Build a ``PoseDataSample`` filled with placeholder data.

    Args:
        img_size: Sequence whose first two items are unpacked as
            ``(h, w)``; they populate ``img_shape``, ``crop_size`` and
            ``input_size`` in the sample's metainfo.
        heatmap_size: Value stored under the ``heatmap_size`` metainfo
            key. Defaults to ``(64, 48)``. The generated heatmap tensor
            does not follow this argument: its shape is always
            ``(17, 64, 48)``.

    Returns:
        A ``PoseDataSample`` whose ``pred_instances`` and
        ``gt_instances`` are the same ``InstanceData`` object, and whose
        ``gt_fields`` holds a random ``heatmaps`` tensor.
    """
    # Imported lazily so that mmpose is only needed when this is called.
    from mmpose.structures import PoseDataSample

    height, width = img_size[:2]
    sample_meta = {
        'img_shape': (height, width, 3),
        'crop_size': (height, width),
        'input_size': (height, width),
        'heatmap_size': heatmap_size,
    }

    # Insertion order is kept on purpose: it fixes the order in which the
    # random generator is consumed and in which the fields are assigned.
    # The bbox fields are numpy arrays; the keypoint fields stay tensors.
    instance_fields = {
        'bboxes': torch.rand((1, 4)).numpy(),
        'bbox_scales': torch.ones(1, 2).numpy(),
        'bbox_scores': torch.ones(1).numpy(),
        'bbox_centers': torch.ones(1, 2).numpy(),
        'keypoints': torch.rand((1, 17, 2)),
        'keypoints_visible': torch.rand((1, 17, 1)),
    }
    instances = InstanceData()
    for field_name, field_value in instance_fields.items():
        setattr(instances, field_name, field_value)

    pixel_fields = PixelData()
    pixel_fields.heatmaps = torch.rand((17, 64, 48))

    sample = PoseDataSample(metainfo=sample_meta)
    # Predictions and ground truth deliberately share one object.
    sample.pred_instances = instances
    sample.gt_instances = instances
    sample.gt_fields = pixel_fields
    return sample
|
|
|
|
|
|
def generate_mmpose_deploy_config(backend=Backend.ONNXRUNTIME.value,
                                  cfg_options=None):
    """Create a deploy config for the mmpose pose-detection task.

    Args:
        backend: Value stored as ``backend_config.type``. Defaults to
            ``Backend.ONNXRUNTIME.value``.
        cfg_options: Optional mapping passed to the config's ``update``
            method after construction. Skipped when ``None``.

    Returns:
        An ``mmengine.Config`` with ``backend_config``,
        ``codebase_config`` and ``onnx_config`` sections.
    """
    codebase_section = {
        'type': Codebase.MMPOSE.value,
        'task': Task.POSE_DETECTION.value,
    }
    onnx_section = {
        'type': IR.ONNX.value,
        'export_params': True,
        'keep_initializers_as_inputs': False,
        'opset_version': 11,
        'input_shape': None,
        'input_names': ['input'],
        'output_names': ['output'],
    }
    config = mmengine.Config({
        'backend_config': {'type': backend},
        'codebase_config': codebase_section,
        'onnx_config': onnx_section,
    })

    if cfg_options is None:
        return config

    # Caller-supplied overrides are applied on top of the defaults.
    config.update(cfg_options)
    return config
|
|
|
|
|
|
def generate_mmpose_task_processor(model_cfg=None, deploy_cfg=None):
    """Build a task processor on the ``'cpu'`` device.

    Args:
        model_cfg: Model config handed to ``load_config``. When ``None``,
            the path ``tests/test_codebase/test_mmpose/data/model.py`` is
            used instead.
        deploy_cfg: Deploy config handed to ``load_config``. When
            ``None``, the result of ``generate_mmpose_deploy_config()``
            is used instead.

    Returns:
        The object returned by ``build_task_processor``.
    """
    # NOTE(review): the fallback path is relative; presumably it is meant
    # to be resolved from the repository root -- confirm against callers.
    model_source = ('tests/test_codebase/test_mmpose/data/model.py'
                    if model_cfg is None else model_cfg)
    deploy_source = (generate_mmpose_deploy_config()
                     if deploy_cfg is None else deploy_cfg)

    loaded_model_cfg, loaded_deploy_cfg = load_config(model_source,
                                                      deploy_source)
    return build_task_processor(loaded_model_cfg, loaded_deploy_cfg, 'cpu')
|