Fix hourglass from mmpose (#1277)

* update mmpose rewritings

* update yml

* update docstring for mmpose
pull/1296/head
RunningLeon 2022-11-03 15:15:33 +08:00 committed by GitHub
parent 034ba67556
commit 13290614f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 66 additions and 168 deletions

View File

@ -21,7 +21,7 @@ def process_model_config(
imgs: Union[Sequence[str], Sequence[np.ndarray]],
input_shape: Optional[Sequence[int]] = None,
):
"""Process the model config.
"""Process the model config for sdk model.
Args:
model_cfg (mmengine.Config): The model config.
@ -38,7 +38,7 @@ def process_model_config(
data_preprocessor = cfg.model.data_preprocessor
codec = cfg.codec
if isinstance(codec, list):
codec = codec[0]
codec = codec[-1]
input_size = codec['input_size'] if input_shape is None else input_shape
test_pipeline[0] = dict(type='LoadImageFromFile')
for i in reversed(range(len(test_pipeline))):
@ -119,10 +119,12 @@ class MMPose(MMCodebase):
@classmethod
def register_deploy_modules(cls):
"""register rewritings."""
import mmdeploy.codebase.mmpose.models # noqa: F401
@classmethod
def register_all_modules(cls):
"""register all modules from mmpose."""
from mmpose.utils.setup_env import register_all_modules
cls.register_deploy_modules()
@ -148,6 +150,14 @@ class PoseDetection(BaseTask):
model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module:
"""build pytorch model from model config and checkpoint
Args:
model_checkpoint (str|None): Input model checkpoint file.
cfg_options (dict|None): Optional config arguments.
Returns:
nn.Module: An initialized pytorch model.
"""
from mmpose.apis import init_model
from mmpose.utils import register_all_modules
register_all_modules()
@ -163,11 +173,10 @@ class PoseDetection(BaseTask):
def build_backend_model(self,
model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module:
"""Initialize backend model.
"""build backend model.
Args:
model_files (Sequence[str]): Input model files. Default is None.
Returns:
nn.Module: An initialized backend model.
"""
@ -194,6 +203,8 @@ class PoseDetection(BaseTask):
``np.ndarray``.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Defaults to ``None``.
data_preprocessor (BaseDataPreprocessor | None): Input data pre-
processor. Default is ``None``.
Returns:
tuple: (data, inputs), meta information for the input image

View File

@ -33,8 +33,10 @@ class End2EndModel(BaseBackendModel):
device (str): A string represents device type.
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
Config object.
deploy_cfg (str | mmengine.Config): Model config file or loaded Config
model_cfg (str | mmengine.Config): Model config file or loaded Config
object.
data_preprocessor (dict | nn.Module | None): Input data pre-
processor. Default is ``None``.
"""
def __init__(self,
@ -59,7 +61,8 @@ class End2EndModel(BaseBackendModel):
# create head for decoding heatmap
self.head = builder.build_head(model_cfg.model.head)
def _init_wrapper(self, backend, backend_files, device, **kwargs):
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
device: str, **kwargs):
"""Initialize backend wrapper.
Args:
@ -90,8 +93,6 @@ class End2EndModel(BaseBackendModel):
format.
data_samples (List[BaseDataElement]): A list of meta info for
image(s).
*args: Other arguments.
**kwargs: Other key-pair arguments.
Returns:
list: A list contains predictions.
@ -223,7 +224,8 @@ def build_pose_detection_model(
deploy_cfg (str | mmengine.Config): Input deployment config file or
Config object.
device (str): Device to input model.
data_preprocessor (Config | BaseDataPreprocessor | None): Input data
pre-processor. Default is ``None``.
Returns:
BaseBackendModel: Pose model for a configured backend.
"""

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import heatmap_head, mspn_head, regression_head, simcc_head
from . import mspn_head
__all__ = ['heatmap_head', 'mspn_head', 'regression_head', 'simcc_head']
__all__ = ['mspn_head']

View File

@ -1,23 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.heads.heatmap_heads.HeatmapHead.predict')
def heatmap_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
"""Rewrite `predict` of HeatmapHead for default backend.
1. skip heatmaps decoding and return heatmaps directly.
Args:
feats (tuple[Tensor]): Input features.
batch_data_samples (list[SampleList]): Data samples contain
image meta information.
test_cfg (ConfigType): test config.
Returns:
output_heatmap (torch.Tensor): Output heatmaps.
"""
batch_heatmaps = self.forward(feats)
return batch_heatmaps

View File

@ -4,21 +4,20 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.heads.heatmap_heads.MSPNHead.predict')
def mspn_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
"""Rewrite `predict` of HeatmapHead for default backend.
'mmpose.models.heads.heatmap_heads.CPMHead.forward')
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.heads.heatmap_heads.MSPNHead.forward')
def mspn_head__forward(ctx, self, feats):
"""Rewrite `forward` of MSPNHead and CPMHead for default backend.
1. skip heatmaps decoding and return heatmaps directly.
1. return last stage heatmaps directly.
Args:
feats (tuple[Tensor]): Input features.
batch_data_samples (list[SampleList]): Data samples contain
image meta information.
test_cfg (ConfigType): test config.
Returns:
output_heatmap (torch.Tensor): Output heatmaps.
"""
msmu_batch_heatmaps = self.forward(feats)
msmu_batch_heatmaps = ctx.origin_func(self, feats)
batch_heatmaps = msmu_batch_heatmaps[-1]
return batch_heatmaps

View File

@ -1,28 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.heads.regression_heads.regression_head'
'.RegressionHead.predict')
def regression_head__predict(ctx,
self,
feats,
batch_data_samples,
test_cfg=None):
"""Rewrite `predict` of RegressionHead for default backend.
1. skip heatmaps decoding and return heatmaps directly.
Args:
feats (tuple[Tensor]): Input features.
batch_data_samples (list[SampleList]): Data samples contain
image meta information.
test_cfg (ConfigType): test config.
Returns:
output_heatmap (torch.Tensor): Output heatmaps.
"""
batch_heatmaps = self.forward(feats)
return batch_heatmaps

View File

@ -1,22 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.heads.heatmap_heads.SimCCHead.predict')
def simcc_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
"""Rewrite `predict` of HeatmapHead for default backend.
1. skip decoding and return output tensor directly.
Args:
feats (tuple[Tensor]): Input features.
batch_data_samples (list[SampleList]): Data samples contain
image meta information.
test_cfg (ConfigType): test config.
Returns:
output_heatmap (torch.Tensor): Output heatmaps.
"""
simcc_x, simcc_y = self.forward(feats)
return simcc_x, simcc_y

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import base, topdown
from . import base
__all__ = ['base', 'topdown']
__all__ = ['base']

View File

@ -4,33 +4,18 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.pose_estimators.base.BasePoseEstimator.forward')
def base_pose_estimator__forward(ctx,
self,
inputs,
data_samples=None,
mode='predict',
**kwargs):
"""Rewrite `forward_test` of TopDown for default backend.'.
def base_pose_estimator__forward(ctx, self, inputs, *args, **kwargs):
"""Rewrite `forward` of TopDown for default backend.'.
1. only support mode='predict'.
2. create data_samples if necessary
1.directly call _forward of subclass.
Args:
ctx (ContextCaller): The context with additional information.
self (BasePoseEstimator): The instance of the class Object
BasePoseEstimator.
inputs (torch.Tensor[NxCxHxW]): Input images.
data_samples (SampleList | None): Data samples contain
image meta information.
Returns:
torch.Tensor: The predicted heatmaps.
"""
if data_samples is None:
from mmpose.structures import PoseDataSample
_, c, h, w = [int(_) for _ in inputs.shape]
metainfo = dict(
img_shape=(h, w, c),
input_size=(w, h),
heatmap_size=self.cfg.codec.heatmap_size)
data_sample = PoseDataSample(metainfo=metainfo)
data_samples = [data_sample]
return self.predict(inputs, data_samples)
return self._forward(inputs)

View File

@ -1,24 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.pose_estimators.topdown.TopdownPoseEstimator.predict')
def topdown_pose_estimator__predict(ctx, self, inputs, data_samples, **kwargs):
"""Rewrite `predict` of TopdownPoseEstimator for default backend.'.
1. skip flip_test
2. avoid call `add_pred_to_datasample`
Args:
inputs (torch.Tensor[NxCxHxW]): Input images.
data_samples (SampleList | None): Data samples contain
image meta information.
Returns:
torch.Tensor: The predicted heatmaps.
"""
assert self.with_head, ('The model must have head to perform prediction.')
feats = self.extract_feat(inputs)
preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
return preds

View File

@ -114,28 +114,26 @@ models:
- *pipeline_ts_fp32
- *pipeline_pplnn_static_fp32
# TODO: no hourglass_coco.yml in latest mmpose, enable this later
# - name: Hourglass
# metafile: configs/body_2d_keypoint/topdown_heatmap/coco/hourglass_coco.yml
# model_configs:
# - configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hourglass52_8xb32-210e_coco-256x256.py
# pipelines:
# - *pipeline_ort_static_fp32
# - *pipeline_trt_static_fp32_256x256
# - *pipeline_ncnn_static_fp32_256x256
# - *pipeline_openvino_static_fp32_256x256
- name: Hourglass
metafile: configs/body_2d_keypoint/topdown_heatmap/coco/hourglass_coco.yml
model_configs:
- configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hourglass52_8xb32-210e_coco-256x256.py
pipelines:
- *pipeline_ort_static_fp32
- *pipeline_trt_static_fp32_256x256
- *pipeline_ncnn_static_fp32_256x256
- *pipeline_openvino_static_fp32_256x256
# TODO: no mobilenetv2_coco.yml in latest mmpose, enable this later
# - name: SimCC
# metafile: configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
# model_configs:
# - configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py
# pipelines:
# - convert_image: *convert_image
# deploy_config: configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
# - convert_image: *convert_image
# deploy_config: configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
# backend_test: *default_backend_test
# sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
# - convert_image: *convert_image
# deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py
- name: SimCC
metafile: configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
model_configs:
- configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py
pipelines:
- convert_image: *convert_image
deploy_config: configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
- convert_image: *convert_image
deploy_config: configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
backend_test: *default_backend_test
sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
- convert_image: *convert_image
deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py

View File

@ -30,13 +30,13 @@ def get_heatmap_head():
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_heatmaphead_predict(backend_type: Backend):
def test_heatmaphead_forward(backend_type: Backend):
check_backend(backend_type, True)
model = get_heatmap_head()
model.cpu().eval()
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
feats = [torch.rand(1, 2, 32, 48)]
wrapped_model = WrapModel(model, 'predict', batch_data_samples=None)
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'feats': feats}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
@ -59,13 +59,13 @@ def get_msmu_head():
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_msmuhead_predict(backend_type: Backend):
def test_msmuhead_forward(backend_type: Backend):
check_backend(backend_type, True)
model = get_msmu_head()
model.cpu().eval()
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
feats = [[torch.rand(1, 16, 32, 48)]]
wrapped_model = WrapModel(model, 'predict', batch_data_samples=None)
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'feats': feats}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,