Fix hourglass from mmpose (#1277)
* update mmpose rewritings * update yml * update docstring for mmposepull/1296/head
parent
034ba67556
commit
13290614f6
|
@ -21,7 +21,7 @@ def process_model_config(
|
|||
imgs: Union[Sequence[str], Sequence[np.ndarray]],
|
||||
input_shape: Optional[Sequence[int]] = None,
|
||||
):
|
||||
"""Process the model config.
|
||||
"""Process the model config for sdk model.
|
||||
|
||||
Args:
|
||||
model_cfg (mmengine.Config): The model config.
|
||||
|
@ -38,7 +38,7 @@ def process_model_config(
|
|||
data_preprocessor = cfg.model.data_preprocessor
|
||||
codec = cfg.codec
|
||||
if isinstance(codec, list):
|
||||
codec = codec[0]
|
||||
codec = codec[-1]
|
||||
input_size = codec['input_size'] if input_shape is None else input_shape
|
||||
test_pipeline[0] = dict(type='LoadImageFromFile')
|
||||
for i in reversed(range(len(test_pipeline))):
|
||||
|
@ -119,10 +119,12 @@ class MMPose(MMCodebase):
|
|||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
"""register rewritings."""
|
||||
import mmdeploy.codebase.mmpose.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
"""register all modules from mmpose."""
|
||||
from mmpose.utils.setup_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
|
@ -148,6 +150,14 @@ class PoseDetection(BaseTask):
|
|||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""build pytorch model from model config and checkpoint
|
||||
Args:
|
||||
model_checkpoint (str|None): Input model checkpoint file.
|
||||
cfg_options (dict|None): Optional config arguments.
|
||||
|
||||
Returns:
|
||||
nn.Module: An initialized pytorch model.
|
||||
"""
|
||||
from mmpose.apis import init_model
|
||||
from mmpose.utils import register_all_modules
|
||||
register_all_modules()
|
||||
|
@ -163,11 +173,10 @@ class PoseDetection(BaseTask):
|
|||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
"""build backend model.
|
||||
|
||||
Args:
|
||||
model_files (Sequence[str]): Input model files. Default is None.
|
||||
|
||||
Returns:
|
||||
nn.Module: An initialized backend model.
|
||||
"""
|
||||
|
@ -194,6 +203,8 @@ class PoseDetection(BaseTask):
|
|||
``np.ndarray``.
|
||||
input_shape (list[int]): A list of two integer in (width, height)
|
||||
format specifying input shape. Defaults to ``None``.
|
||||
data_preprocessor (BaseDataPreprocessor | None): Input data pre-
|
||||
processor. Default is ``None``.
|
||||
|
||||
Returns:
|
||||
tuple: (data, inputs), meta information for the input image
|
||||
|
|
|
@ -33,8 +33,10 @@ class End2EndModel(BaseBackendModel):
|
|||
device (str): A string represents device type.
|
||||
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
|
||||
Config object.
|
||||
deploy_cfg (str | mmengine.Config): Model config file or loaded Config
|
||||
model_cfg (str | mmengine.Config): Model config file or loaded Config
|
||||
object.
|
||||
data_preprocessor (dict | nn.Module | None): Input data pre-
|
||||
processor. Default is ``None``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -59,7 +61,8 @@ class End2EndModel(BaseBackendModel):
|
|||
# create head for decoding heatmap
|
||||
self.head = builder.build_head(model_cfg.model.head)
|
||||
|
||||
def _init_wrapper(self, backend, backend_files, device, **kwargs):
|
||||
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
|
||||
device: str, **kwargs):
|
||||
"""Initialize backend wrapper.
|
||||
|
||||
Args:
|
||||
|
@ -90,8 +93,6 @@ class End2EndModel(BaseBackendModel):
|
|||
format.
|
||||
data_samples (List[BaseDataElement]): A list of meta info for
|
||||
image(s).
|
||||
*args: Other arguments.
|
||||
**kwargs: Other key-pair arguments.
|
||||
|
||||
Returns:
|
||||
list: A list contains predictions.
|
||||
|
@ -223,7 +224,8 @@ def build_pose_detection_model(
|
|||
deploy_cfg (str | mmengine.Config): Input deployment config file or
|
||||
Config object.
|
||||
device (str): Device to input model.
|
||||
|
||||
data_preprocessor (Config | BaseDataPreprocessor | None): Input data
|
||||
pre-processor. Default is ``None``.
|
||||
Returns:
|
||||
BaseBackendModel: Pose model for a configured backend.
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import heatmap_head, mspn_head, regression_head, simcc_head
|
||||
from . import mspn_head
|
||||
|
||||
__all__ = ['heatmap_head', 'mspn_head', 'regression_head', 'simcc_head']
|
||||
__all__ = ['mspn_head']
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.heatmap_heads.HeatmapHead.predict')
|
||||
def heatmap_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
|
||||
"""Rewrite `predict` of HeatmapHead for default backend.
|
||||
|
||||
1. skip heatmaps decoding and return heatmaps directly.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): Input features.
|
||||
batch_data_samples (list[SampleList]): Data samples contain
|
||||
image meta information.
|
||||
test_cfg (ConfigType): test config.
|
||||
|
||||
Returns:
|
||||
output_heatmap (torch.Tensor): Output heatmaps.
|
||||
"""
|
||||
batch_heatmaps = self.forward(feats)
|
||||
return batch_heatmaps
|
|
@ -4,21 +4,20 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.heatmap_heads.MSPNHead.predict')
|
||||
def mspn_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
|
||||
"""Rewrite `predict` of HeatmapHead for default backend.
|
||||
'mmpose.models.heads.heatmap_heads.CPMHead.forward')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.heatmap_heads.MSPNHead.forward')
|
||||
def mspn_head__forward(ctx, self, feats):
|
||||
"""Rewrite `forward` of MSPNHead and CPMHead for default backend.
|
||||
|
||||
1. skip heatmaps decoding and return heatmaps directly.
|
||||
1. return last stage heatmaps directly.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): Input features.
|
||||
batch_data_samples (list[SampleList]): Data samples contain
|
||||
image meta information.
|
||||
test_cfg (ConfigType): test config.
|
||||
|
||||
Returns:
|
||||
output_heatmap (torch.Tensor): Output heatmaps.
|
||||
"""
|
||||
msmu_batch_heatmaps = self.forward(feats)
|
||||
msmu_batch_heatmaps = ctx.origin_func(self, feats)
|
||||
batch_heatmaps = msmu_batch_heatmaps[-1]
|
||||
return batch_heatmaps
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.regression_heads.regression_head'
|
||||
'.RegressionHead.predict')
|
||||
def regression_head__predict(ctx,
|
||||
self,
|
||||
feats,
|
||||
batch_data_samples,
|
||||
test_cfg=None):
|
||||
"""Rewrite `predict` of RegressionHead for default backend.
|
||||
|
||||
1. skip heatmaps decoding and return heatmaps directly.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): Input features.
|
||||
batch_data_samples (list[SampleList]): Data samples contain
|
||||
image meta information.
|
||||
test_cfg (ConfigType): test config.
|
||||
|
||||
Returns:
|
||||
output_heatmap (torch.Tensor): Output heatmaps.
|
||||
"""
|
||||
batch_heatmaps = self.forward(feats)
|
||||
return batch_heatmaps
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.heatmap_heads.SimCCHead.predict')
|
||||
def simcc_head__predict(ctx, self, feats, batch_data_samples, test_cfg=None):
|
||||
"""Rewrite `predict` of HeatmapHead for default backend.
|
||||
|
||||
1. skip decoding and return output tensor directly.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): Input features.
|
||||
batch_data_samples (list[SampleList]): Data samples contain
|
||||
image meta information.
|
||||
test_cfg (ConfigType): test config.
|
||||
|
||||
Returns:
|
||||
output_heatmap (torch.Tensor): Output heatmaps.
|
||||
"""
|
||||
simcc_x, simcc_y = self.forward(feats)
|
||||
return simcc_x, simcc_y
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import base, topdown
|
||||
from . import base
|
||||
|
||||
__all__ = ['base', 'topdown']
|
||||
__all__ = ['base']
|
||||
|
|
|
@ -4,33 +4,18 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.pose_estimators.base.BasePoseEstimator.forward')
|
||||
def base_pose_estimator__forward(ctx,
|
||||
self,
|
||||
inputs,
|
||||
data_samples=None,
|
||||
mode='predict',
|
||||
**kwargs):
|
||||
"""Rewrite `forward_test` of TopDown for default backend.'.
|
||||
def base_pose_estimator__forward(ctx, self, inputs, *args, **kwargs):
|
||||
"""Rewrite `forward` of TopDown for default backend.'.
|
||||
|
||||
1. only support mode='predict'.
|
||||
2. create data_samples if necessary
|
||||
1.directly call _forward of subclass.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self (BasePoseEstimator): The instance of the class Object
|
||||
BasePoseEstimator.
|
||||
inputs (torch.Tensor[NxCxHxW]): Input images.
|
||||
data_samples (SampleList | None): Data samples contain
|
||||
image meta information.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The predicted heatmaps.
|
||||
"""
|
||||
if data_samples is None:
|
||||
from mmpose.structures import PoseDataSample
|
||||
_, c, h, w = [int(_) for _ in inputs.shape]
|
||||
metainfo = dict(
|
||||
img_shape=(h, w, c),
|
||||
input_size=(w, h),
|
||||
heatmap_size=self.cfg.codec.heatmap_size)
|
||||
data_sample = PoseDataSample(metainfo=metainfo)
|
||||
data_samples = [data_sample]
|
||||
|
||||
return self.predict(inputs, data_samples)
|
||||
return self._forward(inputs)
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.pose_estimators.topdown.TopdownPoseEstimator.predict')
|
||||
def topdown_pose_estimator__predict(ctx, self, inputs, data_samples, **kwargs):
|
||||
"""Rewrite `predict` of TopdownPoseEstimator for default backend.'.
|
||||
|
||||
1. skip flip_test
|
||||
2. avoid call `add_pred_to_datasample`
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor[NxCxHxW]): Input images.
|
||||
data_samples (SampleList | None): Data samples contain
|
||||
image meta information.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The predicted heatmaps.
|
||||
"""
|
||||
assert self.with_head, ('The model must have head to perform prediction.')
|
||||
feats = self.extract_feat(inputs)
|
||||
preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
|
||||
return preds
|
|
@ -114,28 +114,26 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
- *pipeline_pplnn_static_fp32
|
||||
|
||||
# TODO: no hourglass_coco.yml in latest mmpose, enable this later
|
||||
# - name: Hourglass
|
||||
# metafile: configs/body_2d_keypoint/topdown_heatmap/coco/hourglass_coco.yml
|
||||
# model_configs:
|
||||
# - configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hourglass52_8xb32-210e_coco-256x256.py
|
||||
# pipelines:
|
||||
# - *pipeline_ort_static_fp32
|
||||
# - *pipeline_trt_static_fp32_256x256
|
||||
# - *pipeline_ncnn_static_fp32_256x256
|
||||
# - *pipeline_openvino_static_fp32_256x256
|
||||
- name: Hourglass
|
||||
metafile: configs/body_2d_keypoint/topdown_heatmap/coco/hourglass_coco.yml
|
||||
model_configs:
|
||||
- configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hourglass52_8xb32-210e_coco-256x256.py
|
||||
pipelines:
|
||||
- *pipeline_ort_static_fp32
|
||||
- *pipeline_trt_static_fp32_256x256
|
||||
- *pipeline_ncnn_static_fp32_256x256
|
||||
- *pipeline_openvino_static_fp32_256x256
|
||||
|
||||
# TODO: no mobilenetv2_coco.yml in latest mmpose, enable this later
|
||||
# - name: SimCC
|
||||
# metafile: configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
|
||||
# model_configs:
|
||||
# - configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py
|
||||
# pipelines:
|
||||
# - convert_image: *convert_image
|
||||
# deploy_config: configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
|
||||
# - convert_image: *convert_image
|
||||
# deploy_config: configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
|
||||
# backend_test: *default_backend_test
|
||||
# sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
|
||||
# - convert_image: *convert_image
|
||||
# deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py
|
||||
- name: SimCC
|
||||
metafile: configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
|
||||
model_configs:
|
||||
- configs/body_2d_keypoint/simcc/coco/simcc_mobilenetv2_wo-deconv-8xb64-210e_coco-256x192.py
|
||||
pipelines:
|
||||
- convert_image: *convert_image
|
||||
deploy_config: configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
|
||||
- convert_image: *convert_image
|
||||
deploy_config: configs/mmpose/pose-detection_simcc_tensorrt_dynamic-256x192.py
|
||||
backend_test: *default_backend_test
|
||||
sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
|
||||
- convert_image: *convert_image
|
||||
deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py
|
||||
|
|
|
@ -30,13 +30,13 @@ def get_heatmap_head():
|
|||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_heatmaphead_predict(backend_type: Backend):
|
||||
def test_heatmaphead_forward(backend_type: Backend):
|
||||
check_backend(backend_type, True)
|
||||
model = get_heatmap_head()
|
||||
model.cpu().eval()
|
||||
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
|
||||
feats = [torch.rand(1, 2, 32, 48)]
|
||||
wrapped_model = WrapModel(model, 'predict', batch_data_samples=None)
|
||||
wrapped_model = WrapModel(model, 'forward')
|
||||
rewrite_inputs = {'feats': feats}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
|
@ -59,13 +59,13 @@ def get_msmu_head():
|
|||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_msmuhead_predict(backend_type: Backend):
|
||||
def test_msmuhead_forward(backend_type: Backend):
|
||||
check_backend(backend_type, True)
|
||||
model = get_msmu_head()
|
||||
model.cpu().eval()
|
||||
deploy_cfg = generate_mmpose_deploy_config(backend_type.value)
|
||||
feats = [[torch.rand(1, 16, 32, 48)]]
|
||||
wrapped_model = WrapModel(model, 'predict', batch_data_samples=None)
|
||||
wrapped_model = WrapModel(model, 'forward')
|
||||
rewrite_inputs = {'feats': feats}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
|
|
Loading…
Reference in New Issue