support mmpose2.0 for sdk (#1080)

* add mmpose sdk
* --amend
* remove debug draw code
* fix docstring
* update

parent 8bae5671be
commit 2e459dba6d
@@ -3,14 +3,9 @@ _base_ = ['./pose-detection_static.py', '../_base_/backends/sdk.py']
 codebase_config = dict(model_type='sdk')
 
 backend_config = dict(pipeline=[
-    dict(type='LoadImageFromFile', channel_order='bgr'),
-    dict(
-        type='PackPoseInputs',
-        keys=['img'],
-        meta_keys=[
-            'id', 'img_id', 'img_path', 'ori_shape', 'img_shape', 'input_size',
-            'flip_indices', 'category'
-        ])
+    dict(type='LoadImageFromFile'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='PackPoseInputs')
 ])
 
 ext_info = dict(image_size=[192, 256], padding=1.25)
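
For context, the rewritten config above reduces the SDK pipeline to three transforms and keeps ext_info for the SDK's bbox handling. A minimal sketch of reading the resulting values back with mmengine (the inline config string mirrors the file above and is only illustrative; mmengine must be installed):

from mmengine import Config

cfg = Config.fromstring(
    """
codebase_config = dict(model_type='sdk')
backend_config = dict(pipeline=[
    dict(type='LoadImageFromFile'),
    dict(type='GetBBoxCenterScale'),
    dict(type='PackPoseInputs')
])
ext_info = dict(image_size=[192, 256], padding=1.25)
""",
    file_format='.py')

print(cfg.codebase_config.model_type)  # -> 'sdk'
print(cfg.ext_info.image_size)         # -> [192, 256]
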
@@ -36,41 +36,41 @@ def process_model_config(
     cfg = copy.deepcopy(model_cfg)
     test_pipeline = cfg.test_dataloader.dataset.pipeline
     data_preprocessor = cfg.model.data_preprocessor
-    sdk_pipeline = []
-    channel_order = 'rgb'
-    if input_shape is None:
-        codec = cfg.codec
-        if isinstance(codec, (list, tuple)):
-            codec = codec[0]
-        input_shape = np.array(codec['input_size'])
+    codec = cfg.codec
+    if isinstance(codec, list):
+        codec = codec[0]
+    input_size = codec['input_size'] if input_shape is None else input_shape
+    test_pipeline[0] = dict(type='LoadImageFromFile')
+    for i in reversed(range(len(test_pipeline))):
+        trans = test_pipeline[i]
+        if trans['type'] == 'PackPoseInputs':
+            test_pipeline.pop(i)
+        elif trans['type'] == 'GetBBoxCenterScale':
+            trans['type'] = 'TopDownGetBboxCenterScale'
+            trans['padding'] = 1.25  # default argument
+            trans['image_size'] = input_size
+        elif trans['type'] == 'TopdownAffine':
+            trans['type'] = 'TopDownAffine'
+            trans['image_size'] = input_size
+            trans.pop('input_size')
 
-    idx = 0
-    while idx < len(test_pipeline):
-        trans = test_pipeline[idx]
-        if trans.type == 'ToTensor':
-            assert idx + 1 < len(test_pipeline) and \
-                test_pipeline[idx + 1].type == 'NormalizeTensor'
-            trans = test_pipeline[idx + 1]
-            trans.type = 'Normalize'
-            trans['to_rgb'] = (channel_order == 'rgb')
-            trans['mean'] = [x * 255 for x in trans['mean']]
-            trans['std'] = [x * 255 for x in trans['std']]
-            sdk_pipeline.append(trans)
-            sdk_pipeline.append({'type': 'ImageToTensor', 'keys': ['img']})
-            idx = idx + 2
-            continue
+    test_pipeline.append(
+        dict(
+            type='Normalize',
+            mean=data_preprocessor.mean,
+            std=data_preprocessor.std,
+            to_rgb=data_preprocessor.bgr_to_rgb))
+    test_pipeline.append(dict(type='ImageToTensor', keys=['img']))
+    test_pipeline.append(
+        dict(
+            type='Collect',
+            keys=['img'],
+            meta_keys=[
+                'img_shape', 'pad_shape', 'ori_shape', 'img_norm_cfg',
+                'scale_factor', 'bbox_score', 'center', 'scale'
+            ]))
 
-        if trans.type == 'LoadImage':
-            if not data_preprocessor.bgr_to_rgb:
-                channel_order = 'bgr'
-        if trans.type == 'TopDownAffine':
-            trans['image_size'] = input_shape
-        if trans.type == 'TopDownGetBboxCenterScale':
-            trans['image_size'] = input_shape
-
-        sdk_pipeline.append(trans)
-        idx = idx + 1
-    cfg.test_dataloader.dataset.pipeline = sdk_pipeline
+    cfg.test_dataloader.dataset.pipeline = test_pipeline
     return cfg
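
The rewrite above maps mmpose 2.0 transform names onto the names the SDK expects and drops PackPoseInputs, which the SDK handles itself. A self-contained sketch of that renaming loop on a toy pipeline (plain dicts stand in for mmengine ConfigDicts):

# input_size would normally come from codec['input_size'] or the deploy config
input_size = (192, 256)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='GetBBoxCenterScale', padding=1.25),
    dict(type='TopdownAffine', input_size=input_size),
    dict(type='PackPoseInputs'),
]

for i in reversed(range(len(test_pipeline))):
    trans = test_pipeline[i]
    if trans['type'] == 'PackPoseInputs':
        test_pipeline.pop(i)  # packing is done by the SDK, not the pipeline
    elif trans['type'] == 'GetBBoxCenterScale':
        trans['type'] = 'TopDownGetBboxCenterScale'
        trans['padding'] = 1.25  # default argument
        trans['image_size'] = input_size
    elif trans['type'] == 'TopdownAffine':
        trans['type'] = 'TopDownAffine'
        trans['image_size'] = input_size
        trans.pop('input_size')

# test_pipeline is now LoadImageFromFile ->
# TopDownGetBboxCenterScale -> TopDownAffine
print([t['type'] for t in test_pipeline])
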
@@ -307,18 +307,25 @@ class PoseDetection(BaseTask):
         Return:
             dict: Composed of the preprocess information.
         """
-        # TODO: make it work with sdk
         input_shape = get_input_shape(self.deploy_cfg)
         model_cfg = process_model_config(self.model_cfg, [''], input_shape)
         preprocess = model_cfg.test_dataloader.dataset.pipeline
-        preprocess[0].type = 'LoadImageFromFile'
         return preprocess
 
     def get_postprocess(self, *args, **kwargs) -> Dict:
         """Get the postprocess information for SDK."""
-        postprocess = {'type': 'UNKNOWN'}
-        if self.model_cfg.model.type == 'TopDown':
-            postprocess[
-                'type'] = self.model_cfg.model.keypoint_head.type + 'Decode'
-            postprocess.update(self.model_cfg.model.test_cfg)
+        component = 'UNKNOWN'
+        params = copy.deepcopy(self.model_cfg.model.test_cfg)
+        if self.model_cfg.model.type == 'TopdownPoseEstimator':
+            head_type = self.model_cfg.model.head.type
+            if head_type == 'HeatmapHead':
+                params['post_process'] = 'default'
+                component = 'TopdownHeatmapSimpleHeadDecode'
+            elif head_type == 'MSPNHead':
+                params['post_process'] = 'megvii'
+                params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
+                component = 'TopdownHeatmapMSMUHeadDecode'
+            else:
+                raise RuntimeError(f'Unsupported head type: {head_type}')
+        postprocess = dict(params=params, type=component)
         return postprocess
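
get_postprocess above effectively dispatches on the head type: HeatmapHead maps to the 'default' post-processing with the TopdownHeatmapSimpleHeadDecode component, MSPNHead to 'megvii' with TopdownHeatmapMSMUHeadDecode, and anything else raises. A table-driven sketch of the same dispatch (the table form is illustrative, not how the original is written):

HEAD_TO_COMPONENT = {
    'HeatmapHead': ('default', 'TopdownHeatmapSimpleHeadDecode'),
    'MSPNHead': ('megvii', 'TopdownHeatmapMSMUHeadDecode'),
}

def decode_component(head_type: str) -> dict:
    if head_type not in HEAD_TO_COMPONENT:
        raise RuntimeError(f'Unsupported head type: {head_type}')
    post_process, component = HEAD_TO_COMPONENT[head_type]
    return dict(params=dict(post_process=post_process), type=component)

print(decode_component('HeatmapHead'))
# {'params': {'post_process': 'default'}, 'type': 'TopdownHeatmapSimpleHeadDecode'}
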
@@ -3,13 +3,12 @@ from itertools import zip_longest
 from typing import List, Optional, Sequence, Union
 
-import mmengine
 import numpy as np
 import torch
 import torch.nn as nn
 from mmengine import Config
 from mmengine.model import BaseDataPreprocessor
 from mmengine.registry import Registry
-from mmengine.structures import BaseDataElement
+from mmengine.structures import BaseDataElement, InstanceData
 
 from mmdeploy.codebase.base import BaseBackendModel
 from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
@@ -81,7 +80,7 @@ class End2EndModel(BaseBackendModel):
 
     def forward(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]],
+                data_samples: List[BaseDataElement],
                 mode: str = 'predict',
                 **kwargs):
         """Run forward inference.
@@ -89,7 +88,7 @@ class End2EndModel(BaseBackendModel):
         Args:
             inputs (torch.Tensor): Input image(s) in [N x C x H x W]
                 format.
-            data_samples (Sequence[Sequence[dict]]): A list of meta info for
+            data_samples (List[BaseDataElement]): A list of meta info for
                 image(s).
             *args: Other arguments.
             **kwargs: Other key-pair arguments.
@@ -119,11 +118,25 @@ class End2EndModel(BaseBackendModel):
                 flip_indices=flip_indices,
                 shift_heatmap=test_cfg.get('shift_heatmap', False))
             batch_heatmaps = (batch_heatmaps + batch_heatmaps_flip) * 0.5
-        results = self.pack_result(batch_heatmaps, data_samples)
+        preds = self.head.decode(batch_heatmaps)
+        results = self.pack_result(preds, data_samples)
         return results
 
-    def pack_result(self, heatmaps, data_samples):
-        preds = self.head.decode(heatmaps)
+    def pack_result(self,
+                    preds: Sequence[InstanceData],
+                    data_samples: List[BaseDataElement],
+                    convert_coordinate: bool = True):
+        """Pack pred results to mmpose format
+        Args:
+            preds (Sequence[InstanceData]): Prediction of keypoints.
+            data_samples (List[BaseDataElement]): A list of meta info for
+                image(s).
+            convert_coordinate (bool): Whether to convert keypoints
+                coordinates to original image space. Default is True.
+        Returns:
+            data_samples (List[BaseDataElement]):
+                updated data_samples with predictions.
+        """
         if isinstance(preds, tuple):
             batch_pred_instances, batch_pred_fields = preds
         else:
@@ -137,16 +150,16 @@ class End2EndModel(BaseBackendModel):
                 batch_pred_instances, batch_pred_fields, data_samples):
 
             gt_instances = data_sample.gt_instances
 
             # convert keypoint coordinates from input space to image space
-            bbox_centers = gt_instances.bbox_centers
-            bbox_scales = gt_instances.bbox_scales
-            input_size = data_sample.metainfo['input_size']
-
-            pred_instances.keypoints = pred_instances.keypoints / input_size \
-                * bbox_scales + bbox_centers - 0.5 * bbox_scales
+            if convert_coordinate:
+                input_size = data_sample.metainfo['input_size']
+                bbox_centers = gt_instances.bbox_centers
+                bbox_scales = gt_instances.bbox_scales
+                keypoints = pred_instances.keypoints
+                keypoints = keypoints / input_size * bbox_scales
+                keypoints += bbox_centers - 0.5 * bbox_scales
+                pred_instances.keypoints = keypoints
 
             # add bbox information into pred_instances
             pred_instances.bboxes = gt_instances.bboxes
             pred_instances.bbox_scores = gt_instances.bbox_scores
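
The conversion guarded by convert_coordinate rescales keypoints from the network input frame to the original image: divide by the input size, multiply by the (padded) bbox size, then shift by the bbox's top-left corner (center minus half the scale). A worked example with made-up numbers:

import numpy as np

input_size = np.array([192., 256.])       # network input (w, h)
bbox_centers = np.array([[320., 240.]])   # bbox center in the original image
bbox_scales = np.array([[240., 320.]])    # padded bbox size (w, h)

keypoints = np.array([[[96., 128.]]])     # a keypoint at the input center

keypoints = keypoints / input_size * bbox_scales
keypoints += bbox_centers - 0.5 * bbox_scales
print(keypoints)  # [[[320. 240.]]] -- the keypoint lands on the bbox center
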
@@ -160,78 +173,47 @@ class End2EndModel(BaseBackendModel):
 
 @__BACKEND_MODEL.register_module('sdk')
 class SDKEnd2EndModel(End2EndModel):
-    """SDK inference class, converts SDK output to mmcls format."""
+    """SDK inference class, converts SDK output to mmpose format."""
 
     def __init__(self, *args, **kwargs):
+        kwargs['data_preprocessor'] = None
        super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
-        self.ext_info = self.deploy_cfg.ext_info
 
-    def _xywh2cs(self, x, y, w, h, padding=1.25):
-        """This encodes bbox(x,y,w,h) into (center, scale)
-        Args:
-            x, y, w, h (float): left, top, width and height
-            padding (float): bounding box padding factor
-        Returns:
-            center (np.ndarray[float32](2,)): center of the bbox (x, y).
-            scale (np.ndarray[float32](2,)): scale of the bbox w & h.
-        """
-        anno_size = self.ext_info.image_size
-        aspect_ratio = anno_size[0] / anno_size[1]
-        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
-
-        if w > aspect_ratio * h:
-            h = w * 1.0 / aspect_ratio
-        elif w < aspect_ratio * h:
-            w = h * aspect_ratio
-
-        # pixel std is 200.0
-        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
-        # padding to include proper amount of context
-        scale = scale * padding
-
-        return center, scale
-
-    def _xywh2xyxy(self, x, y, w, h):
-        """convert xywh to x1 y1 x2 y2."""
-        return x, y, x + w - 1, y + h - 1
-
-    def forward(self, inputs: List[torch.Tensor], *args, **kwargs) -> list:
+    def forward(self,
+                inputs: List[torch.Tensor],
+                data_samples: List[BaseDataElement],
+                mode: str = 'predict',
+                **kwargs) -> list:
         """Run forward inference.
 
         Args:
             inputs (List[torch.Tensor]): A list contains input image(s)
-                in [N x C x H x W] format.
-            *args: Other arguments.
+                in [H x W x C] format.
+            data_samples (List[BaseDataElement]):
+                Data samples of image metas.
+            mode (str): test mode, only support 'predict'
             **kwargs: Other key-pair arguments.
 
         Returns:
             list: A list contains predictions.
         """
-        image_paths = []
-        boxes = np.zeros(shape=(inputs.shape[0], 6))
-        bbox_ids = []
-        sdk_boxes = []
-        for i, img_meta in enumerate(kwargs['img_metas']):
-            center, scale = self._xywh2cs(*img_meta['bbox'])
-            boxes[i, :2] = center
-            boxes[i, 2:4] = scale
-            boxes[i, 4] = np.prod(scale * 200.0)
-            boxes[i, 5] = img_meta[
-                'bbox_score'] if 'bbox_score' in img_meta else 1.0
-            sdk_boxes.append(self._xywh2xyxy(*img_meta['bbox']))
-            image_paths.append(img_meta['image_file'])
-            bbox_ids.append(img_meta['bbox_id'])
-
-        pred = self.wrapper.handle(
-            [inputs[0].contiguous().detach().cpu().numpy()], sdk_boxes)
-
-        result = dict(
-            preds=pred,
-            boxes=boxes,
-            image_paths=image_paths,
-            bbox_ids=bbox_ids)
-        return result
+        pred_results = []
+        for input_img, sample in zip(inputs, data_samples):
+            bboxes = sample.gt_instances.bboxes
+            # inputs are c,h,w, sdk requested h,w,c
+            input_img = input_img.permute(1, 2, 0)
+            input_img = input_img.contiguous().detach().cpu().numpy()
+            keypoints = self.wrapper.handle(input_img, bboxes.tolist())
+            pred = InstanceData(
+                keypoints=keypoints[..., :2],
+                keypoint_scores=keypoints[..., 2])
+            pred_results.append(pred)
+
+        results = self.pack_result(
+            pred_results, data_samples, convert_coordinate=False)
+        return results
 
 
 def build_pose_detection_model(
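
The new SDK forward permutes each tensor from C,H,W to the H,W,C layout the SDK handle expects, then wraps the returned keypoint array in InstanceData. A small sketch of those two steps in isolation (torch, numpy and mmengine installed; the keypoint array shape (num_bboxes, num_keypoints, 3) is assumed from the slicing in the diff):

import numpy as np
import torch
from mmengine.structures import InstanceData

input_img = torch.rand(3, 256, 192)            # C,H,W tensor from the loader
input_img = input_img.permute(1, 2, 0)         # -> H,W,C for the SDK
input_img = input_img.contiguous().detach().cpu().numpy()

keypoints = np.random.rand(1, 17, 3)           # stand-in for the SDK output
pred = InstanceData(
    keypoints=keypoints[..., :2],              # x, y coordinates
    keypoint_scores=keypoints[..., 2])         # per-keypoint confidence
print(pred.keypoints.shape, pred.keypoint_scores.shape)  # (1, 17, 2) (1, 17)
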
@@ -65,16 +65,17 @@ class End2EndModel(BaseBackendModel):
 
     def forward(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]] = None,
-                mode: str = 'predict'):
+                data_samples: List[BaseDataElement],
+                mode: str = 'predict',
+                **kwargs):
         """Run forward inference.
 
         Args:
-            img (Sequence[torch.Tensor]): A list contains input image(s)
+            inputs (torch.Tensor): Input image tensor
                 in [N x C x H x W] format.
-            img_metas (Sequence[Sequence[dict]]): A list of meta info for
+            data_samples (List[BaseDataElement]): A list of meta info for
                 image(s).
-            *args: Other arguments.
+            mode (str): forward mode, only support 'predict'.
             **kwargs: Other key-pair arguments.
 
         Returns:
@@ -90,7 +91,8 @@ class End2EndModel(BaseBackendModel):
                 inputs})[self.output_names[0]]
         return self.pack_result(batch_outputs, data_samples)
 
-    def pack_result(self, batch_outputs, data_samples):
+    def pack_result(self, batch_outputs: torch.Tensor,
+                    data_samples: List[BaseDataElement]):
         predictions = []
         for seg_pred, data_sample in zip(batch_outputs, data_samples):
             # resize seg_pred to original image shape
@@ -74,7 +74,7 @@ models:
   - name: HRNET
     metafile: configs/body_2d_keypoint/topdown_heatmap/coco/hrnet_coco.yml
     model_configs:
-      - configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py
+      - configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py
     pipelines:
       - *pipeline_ort_static_fp32
       - *pipeline_trt_static_fp16