support mmpose2.0 for sdk (#1080)

* add mmpose sdk

* --amend

* remove debug draw code

* fix docstring

* update
RunningLeon 2022-10-08 15:49:44 +08:00 committed by GitHub
parent 8bae5671be
commit 2e459dba6d
5 changed files with 111 additions and 125 deletions


@@ -3,14 +3,9 @@ _base_ = ['./pose-detection_static.py', '../_base_/backends/sdk.py']
 codebase_config = dict(model_type='sdk')
 
 backend_config = dict(pipeline=[
-    dict(type='LoadImageFromFile', channel_order='bgr'),
-    dict(
-        type='PackPoseInputs',
-        keys=['img'],
-        meta_keys=[
-            'id', 'img_id', 'img_path', 'ori_shape', 'img_shape', 'input_size',
-            'flip_indices', 'category'
-        ])
+    dict(type='LoadImageFromFile'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='PackPoseInputs')
 ])
 
 ext_info = dict(image_size=[192, 256], padding=1.25)
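
The new SDK pipeline mirrors the mmpose 2.0 top-down test pipeline, and ext_info carries the annotation image size and bbox padding that the SDK uses to derive center/scale from a detection box. A minimal sketch of loading and inspecting such a deploy config with mmengine (the file path is an assumption, not part of this diff):

    from mmengine import Config

    # Load the merged deploy config; the path below is hypothetical.
    deploy_cfg = Config.fromfile('configs/mmpose/pose-detection_sdk_static.py')
    print(deploy_cfg.codebase_config.model_type)  # -> 'sdk'
    print(deploy_cfg.ext_info.image_size)         # -> [192, 256]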


@@ -36,41 +36,41 @@ def process_model_config(
     cfg = copy.deepcopy(model_cfg)
     test_pipeline = cfg.test_dataloader.dataset.pipeline
     data_preprocessor = cfg.model.data_preprocessor
-    sdk_pipeline = []
-    channel_order = 'rgb'
-    if input_shape is None:
-        codec = cfg.codec
-        if isinstance(codec, (list, tuple)):
-            codec = codec[0]
-        input_shape = np.array(codec['input_size'])
+    codec = cfg.codec
+    if isinstance(codec, list):
+        codec = codec[0]
+    input_size = codec['input_size'] if input_shape is None else input_shape
+    test_pipeline[0] = dict(type='LoadImageFromFile')
+    for i in reversed(range(len(test_pipeline))):
+        trans = test_pipeline[i]
+        if trans['type'] == 'PackPoseInputs':
+            test_pipeline.pop(i)
+        elif trans['type'] == 'GetBBoxCenterScale':
+            trans['type'] = 'TopDownGetBboxCenterScale'
+            trans['padding'] = 1.25  # default argument
+            trans['image_size'] = input_size
+        elif trans['type'] == 'TopdownAffine':
+            trans['type'] = 'TopDownAffine'
+            trans['image_size'] = input_size
+            trans.pop('input_size')
 
-    idx = 0
-    while idx < len(test_pipeline):
-        trans = test_pipeline[idx]
-        if trans.type == 'ToTensor':
-            assert idx + 1 < len(test_pipeline) and \
-                test_pipeline[idx + 1].type == 'NormalizeTensor'
-            trans = test_pipeline[idx + 1]
-            trans.type = 'Normalize'
-            trans['to_rgb'] = (channel_order == 'rgb')
-            trans['mean'] = [x * 255 for x in trans['mean']]
-            trans['std'] = [x * 255 for x in trans['std']]
-            sdk_pipeline.append(trans)
-            sdk_pipeline.append({'type': 'ImageToTensor', 'keys': ['img']})
-            idx = idx + 2
-            continue
+    test_pipeline.append(
+        dict(
+            type='Normalize',
+            mean=data_preprocessor.mean,
+            std=data_preprocessor.std,
+            to_rgb=data_preprocessor.bgr_to_rgb))
+    test_pipeline.append(dict(type='ImageToTensor', keys=['img']))
+    test_pipeline.append(
+        dict(
+            type='Collect',
+            keys=['img'],
+            meta_keys=[
+                'img_shape', 'pad_shape', 'ori_shape', 'img_norm_cfg',
+                'scale_factor', 'bbox_score', 'center', 'scale'
+            ]))
 
-        if trans.type == 'LoadImage':
-            if not data_preprocessor.bgr_to_rgb:
-                channel_order = 'bgr'
-        if trans.type == 'TopDownAffine':
-            trans['image_size'] = input_shape
-        if trans.type == 'TopDownGetBboxCenterScale':
-            trans['image_size'] = input_shape
-        sdk_pipeline.append(trans)
-        idx = idx + 1
-    cfg.test_dataloader.dataset.pipeline = sdk_pipeline
+    cfg.test_dataloader.dataset.pipeline = test_pipeline
     return cfg
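
The effect of the rewritten process_model_config is easiest to see on a concrete pipeline. Below is a standalone sketch on plain dicts (not the mmdeploy function itself; the mean/std are assumed ImageNet defaults and meta_keys are abridged): PackPoseInputs is dropped, the two top-down transforms are renamed to their SDK spellings and given an explicit image_size, and Normalize/ImageToTensor/Collect are appended.

    # Standalone sketch of the pipeline rewrite above (assumed values).
    input_size = [192, 256]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='GetBBoxCenterScale'),
        dict(type='TopdownAffine', input_size=(192, 256)),
        dict(type='PackPoseInputs'),
    ]
    for i in reversed(range(len(test_pipeline))):
        trans = test_pipeline[i]
        if trans['type'] == 'PackPoseInputs':
            test_pipeline.pop(i)  # the SDK packs inputs itself
        elif trans['type'] == 'GetBBoxCenterScale':
            trans.update(type='TopDownGetBboxCenterScale',
                         padding=1.25, image_size=input_size)
        elif trans['type'] == 'TopdownAffine':
            trans.pop('input_size')
            trans.update(type='TopDownAffine', image_size=input_size)
    test_pipeline.append(
        dict(type='Normalize', mean=[123.675, 116.28, 103.53],
             std=[58.395, 57.12, 57.375], to_rgb=True))
    test_pipeline.append(dict(type='ImageToTensor', keys=['img']))
    test_pipeline.append(
        dict(type='Collect', keys=['img'],
             meta_keys=['img_shape', 'ori_shape', 'center', 'scale']))
    print([t['type'] for t in test_pipeline])
    # ['LoadImageFromFile', 'TopDownGetBboxCenterScale', 'TopDownAffine',
    #  'Normalize', 'ImageToTensor', 'Collect']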
@@ -307,18 +307,25 @@ class PoseDetection(BaseTask):
         Return:
             dict: Composed of the preprocess information.
         """
-        # TODO: make it work with sdk
         input_shape = get_input_shape(self.deploy_cfg)
         model_cfg = process_model_config(self.model_cfg, [''], input_shape)
         preprocess = model_cfg.test_dataloader.dataset.pipeline
-        preprocess[0].type = 'LoadImageFromFile'
         return preprocess
 
     def get_postprocess(self, *args, **kwargs) -> Dict:
         """Get the postprocess information for SDK."""
-        postprocess = {'type': 'UNKNOWN'}
-        if self.model_cfg.model.type == 'TopDown':
-            postprocess[
-                'type'] = self.model_cfg.model.keypoint_head.type + 'Decode'
-            postprocess.update(self.model_cfg.model.test_cfg)
+        component = 'UNKNOWN'
+        params = copy.deepcopy(self.model_cfg.model.test_cfg)
+        if self.model_cfg.model.type == 'TopdownPoseEstimator':
+            head_type = self.model_cfg.model.head.type
+            if head_type == 'HeatmapHead':
+                params['post_process'] = 'default'
+                component = 'TopdownHeatmapSimpleHeadDecode'
+            elif head_type == 'MSPNHead':
+                params['post_process'] = 'megvii'
+                params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
+                component = 'TopdownHeatmapMSMUHeadDecode'
+            else:
+                raise RuntimeError(f'Unsupported head type: {head_type}')
+        postprocess = dict(params=params, type=component)
         return postprocess
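
For a plain HeatmapHead top-down model, the dict handed to the SDK therefore looks roughly like the sketch below (the test_cfg entries are illustrative; the real values come from the model config):

    # Illustrative result of get_postprocess() for a HeatmapHead model.
    postprocess = dict(
        type='TopdownHeatmapSimpleHeadDecode',  # SDK decode component
        params=dict(
            flip_test=True,           # copied from model.test_cfg (assumed)
            shift_heatmap=True,       # copied from model.test_cfg (assumed)
            post_process='default'))  # injected by the branch above
    print(postprocess['type'])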


@@ -3,13 +3,12 @@ from itertools import zip_longest
 from typing import List, Optional, Sequence, Union
 
 import mmengine
 import numpy as np
 import torch
-import torch.nn as nn
 from mmengine import Config
 from mmengine.model import BaseDataPreprocessor
 from mmengine.registry import Registry
-from mmengine.structures import BaseDataElement
+from mmengine.structures import BaseDataElement, InstanceData
 
 from mmdeploy.codebase.base import BaseBackendModel
 from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
@@ -81,7 +80,7 @@ class End2EndModel(BaseBackendModel):
     def forward(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]],
+                data_samples: List[BaseDataElement],
                 mode: str = 'predict',
                 **kwargs):
         """Run forward inference.
@@ -89,7 +88,7 @@ class End2EndModel(BaseBackendModel):
         Args:
             inputs (torch.Tensor): Input image(s) in [N x C x H x W]
                 format.
-            data_samples (Sequence[Sequence[dict]]): A list of meta info for
+            data_samples (List[BaseDataElement]): A list of meta info for
                 image(s).
             *args: Other arguments.
             **kwargs: Other key-pair arguments.
@@ -119,11 +118,25 @@ class End2EndModel(BaseBackendModel):
                 flip_indices=flip_indices,
                 shift_heatmap=test_cfg.get('shift_heatmap', False))
             batch_heatmaps = (batch_heatmaps + batch_heatmaps_flip) * 0.5
-        results = self.pack_result(batch_heatmaps, data_samples)
+        preds = self.head.decode(batch_heatmaps)
+        results = self.pack_result(preds, data_samples)
         return results
 
-    def pack_result(self, heatmaps, data_samples):
-        preds = self.head.decode(heatmaps)
+    def pack_result(self,
+                    preds: Sequence[InstanceData],
+                    data_samples: List[BaseDataElement],
+                    convert_coordinate: bool = True):
+        """Pack pred results to mmpose format
+        Args:
+            preds (Sequence[InstanceData]): Prediction of keypoints.
+            data_samples (List[BaseDataElement]): A list of meta info for
+                image(s).
+            convert_coordinate (bool): Whether to convert keypoints
+                coordinates to original image space. Default is True.
+        Returns:
+            data_samples (List[BaseDataElement]):
+                updated data_samples with predictions.
+        """
         if isinstance(preds, tuple):
             batch_pred_instances, batch_pred_fields = preds
         else:
@@ -137,16 +150,16 @@ class End2EndModel(BaseBackendModel):
                 batch_pred_instances, batch_pred_fields, data_samples):
             gt_instances = data_sample.gt_instances
 
             # convert keypoint coordinates from input space to image space
-            bbox_centers = gt_instances.bbox_centers
-            bbox_scales = gt_instances.bbox_scales
-            input_size = data_sample.metainfo['input_size']
-            keypoints = pred_instances.keypoints
-            keypoints = keypoints / input_size * bbox_scales
-            keypoints += bbox_centers - 0.5 * bbox_scales
-            pred_instances.keypoints = keypoints
+            if convert_coordinate:
+                input_size = data_sample.metainfo['input_size']
+                bbox_centers = gt_instances.bbox_centers
+                bbox_scales = gt_instances.bbox_scales
+                pred_instances.keypoints = pred_instances.keypoints / input_size \
+                    * bbox_scales + bbox_centers - 0.5 * bbox_scales
 
             # add bbox information into pred_instances
             pred_instances.bboxes = gt_instances.bboxes
             pred_instances.bbox_scores = gt_instances.bbox_scores
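
The convert_coordinate branch computes keypoints_img = keypoints_input / input_size * bbox_scales + bbox_centers - 0.5 * bbox_scales. A worked example with made-up numbers, showing that the centre of the network input maps back to the bbox centre:

    import numpy as np

    input_size = np.array([192., 256.])    # network input size (w, h)
    bbox_centers = np.array([320., 240.])  # bbox centre in the original image
    bbox_scales = np.array([288., 384.])   # padded bbox size (w, h)

    keypoints = np.array([[96., 128.]])    # prediction in input space
    keypoints = keypoints / input_size * bbox_scales \
        + bbox_centers - 0.5 * bbox_scales
    print(keypoints)                       # [[320. 240.]]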
@@ -160,78 +173,47 @@ class End2EndModel(BaseBackendModel):
 @__BACKEND_MODEL.register_module('sdk')
 class SDKEnd2EndModel(End2EndModel):
-    """SDK inference class, converts SDK output to mmcls format."""
+    """SDK inference class, converts SDK output to mmpose format."""
 
     def __init__(self, *args, **kwargs):
         kwargs['data_preprocessor'] = None
         super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
-        self.ext_info = self.deploy_cfg.ext_info
-
-    def _xywh2cs(self, x, y, w, h, padding=1.25):
-        """This encodes bbox(x,y,w,h) into (center, scale)
-
-        Args:
-            x, y, w, h (float): left, top, width and height
-            padding (float): bounding box padding factor
-
-        Returns:
-            center (np.ndarray[float32](2,)): center of the bbox (x, y).
-            scale (np.ndarray[float32](2,)): scale of the bbox w & h.
-        """
-        anno_size = self.ext_info.image_size
-        aspect_ratio = anno_size[0] / anno_size[1]
-        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
-
-        if w > aspect_ratio * h:
-            h = w * 1.0 / aspect_ratio
-        elif w < aspect_ratio * h:
-            w = h * aspect_ratio
-
-        # pixel std is 200.0
-        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
-        # padding to include proper amount of context
-        scale = scale * padding
-
-        return center, scale
-
-    def _xywh2xyxy(self, x, y, w, h):
-        """convert xywh to x1 y1 x2 y2."""
-        return x, y, x + w - 1, y + h - 1
-
-    def forward(self, inputs: List[torch.Tensor], *args, **kwargs) -> list:
+    def forward(self,
+                inputs: List[torch.Tensor],
+                data_samples: List[BaseDataElement],
+                mode: str = 'predict',
+                **kwargs) -> list:
         """Run forward inference.
 
         Args:
             inputs (List[torch.Tensor]): A list contains input image(s)
-                in [N x C x H x W] format.
-            *args: Other arguments.
+                in [H x W x C] format.
+            data_samples (List[BaseDataElement]):
+                Data samples of image metas.
+            mode (str): test mode, only support 'predict'
             **kwargs: Other key-pair arguments.
 
         Returns:
             list: A list contains predictions.
         """
-        image_paths = []
-        boxes = np.zeros(shape=(inputs.shape[0], 6))
-        bbox_ids = []
-        sdk_boxes = []
-        for i, img_meta in enumerate(kwargs['img_metas']):
-            center, scale = self._xywh2cs(*img_meta['bbox'])
-            boxes[i, :2] = center
-            boxes[i, 2:4] = scale
-            boxes[i, 4] = np.prod(scale * 200.0)
-            boxes[i, 5] = img_meta[
-                'bbox_score'] if 'bbox_score' in img_meta else 1.0
-            sdk_boxes.append(self._xywh2xyxy(*img_meta['bbox']))
-            image_paths.append(img_meta['image_file'])
-            bbox_ids.append(img_meta['bbox_id'])
-
-        pred = self.wrapper.handle(
-            [inputs[0].contiguous().detach().cpu().numpy()], sdk_boxes)
-
-        result = dict(
-            preds=pred,
-            boxes=boxes,
-            image_paths=image_paths,
-            bbox_ids=bbox_ids)
-        return result
+        pred_results = []
+        for input_img, sample in zip(inputs, data_samples):
+            bboxes = sample.gt_instances.bboxes
+            # inputs are c,h,w, sdk requested h,w,c
+            input_img = input_img.permute(1, 2, 0)
+            input_img = input_img.contiguous().detach().cpu().numpy()
+            keypoints = self.wrapper.handle(input_img, bboxes.tolist())
+            pred = InstanceData(
+                keypoints=keypoints[..., :2],
+                keypoint_scores=keypoints[..., 2])
+            pred_results.append(pred)
+
+        results = self.pack_result(
+            pred_results, data_samples, convert_coordinate=False)
+        return results
 
 
 def build_pose_detection_model(
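
The SDK branch no longer builds center/scale boxes itself: each image goes to the SDK wrapper in H x W x C layout together with its raw bboxes, and pack_result is reused with convert_coordinate=False because the SDK already returns keypoints in original image space. A small sketch of the layout conversion the loop relies on:

    import torch

    img_chw = torch.randn(3, 256, 192)  # C x H x W, as delivered upstream
    img_hwc = img_chw.permute(1, 2, 0)  # -> H x W x C for the SDK
    arr = img_hwc.contiguous().detach().cpu().numpy()
    print(arr.shape)                    # (256, 192, 3)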


@@ -65,16 +65,17 @@ class End2EndModel(BaseBackendModel):
     def forward(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]] = None,
-                mode: str = 'predict'):
+                data_samples: List[BaseDataElement],
+                mode: str = 'predict',
+                **kwargs):
         """Run forward inference.
 
         Args:
-            img (Sequence[torch.Tensor]): A list contains input image(s)
+            inputs (torch.Tensor): Input image tensor
                 in [N x C x H x W] format.
-            img_metas (Sequence[Sequence[dict]]): A list of meta info for
+            data_samples (List[BaseDataElement]): A list of meta info for
                 image(s).
-            *args: Other arguments.
+            mode (str): forward mode, only support 'predict'.
             **kwargs: Other key-pair arguments.
 
         Returns:
@@ -90,7 +91,8 @@ class End2EndModel(BaseBackendModel):
                             inputs})[self.output_names[0]]
         return self.pack_result(batch_outputs, data_samples)
 
-    def pack_result(self, batch_outputs, data_samples):
+    def pack_result(self, batch_outputs: torch.Tensor,
+                    data_samples: List[BaseDataElement]):
         predictions = []
         for seg_pred, data_sample in zip(batch_outputs, data_samples):
             # resize seg_pred to original image shape


@@ -74,7 +74,7 @@ models:
   - name: HRNET
     metafile: configs/body_2d_keypoint/topdown_heatmap/coco/hrnet_coco.yml
     model_configs:
-      - configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py
+      - configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py
     pipelines:
       - *pipeline_ort_static_fp32
       - *pipeline_trt_static_fp16