Support deploy of YoloX-Pose (#2184)

* dev_mmpose

* tide

* fix lint

* del redundant task and model

* fix

* test ut

* test ut

* upload configs

* fix

* remove debug

* fix lint

* use mmcv.ops.nms

* fix lint

* remove loop

* debug

* test modified ut

* fix lint

* fix return type

* fix

* fix rescale

* fix

* fix pack_result

* update batch inference

* fix nms and pytorch show_box

* fix lint

* modify ut

* add docstring

* modify nms

* fix

* add openvino config

* update docs

* fix test_mmpose

---------

Co-authored-by: RunningLeon <mnsheng@yeah.net>
huayuan4396 2023-06-28 19:17:36 +08:00 committed by GitHub
parent a664f061ff
commit e19f6fa08d
15 changed files with 455 additions and 21 deletions


@@ -52,6 +52,10 @@ jobs:
run: |
git clone -b dev --depth 1 https://github.com/open-mmlab/mmyolo.git /home/runner/work/mmyolo
python -m pip install -v -e /home/runner/work/mmyolo
- name: Install mmpose
run: |
git clone --depth 1 https://github.com/open-mmlab/mmpose.git /home/runner/work/mmpose
python -m pip install -v -e /home/runner/work/mmpose
- name: Build and install
run: |
rm -rf .eggs && python -m pip install -e .


@@ -0,0 +1,25 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
onnx_config = dict(
output_names=['dets', 'keypoints'],
dynamic_axes={
'input': {
0: 'batch',
},
'dets': {
0: 'batch',
},
'keypoints': {
0: 'batch'
}
})
codebase_config = dict(
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))
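Note: the ONNX Runtime config above can be exercised end to end with mmdeploy's Python conversion API. A minimal sketch, assuming mmdeploy's torch2onnx signature; the mmpose model config and checkpoint names are placeholders, not files added by this commit.

# Minimal conversion sketch; model_cfg and model_checkpoint are placeholders.
from mmdeploy.apis import torch2onnx

torch2onnx(
    img='demo.jpg',  # any test image
    work_dir='work_dir/yolox-pose',
    save_file='end2end.onnx',
    deploy_cfg='configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py',
    model_cfg='yolox-pose_s_8xb32-300e_coco.py',  # placeholder
    model_checkpoint='yolox-pose_s.pth',  # placeholder
    device='cpu')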


@@ -0,0 +1,27 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/openvino.py']
onnx_config = dict(
output_names=['dets', 'keypoints'],
dynamic_axes={
'input': {
0: 'batch',
},
'dets': {
0: 'batch',
},
'keypoints': {
0: 'batch'
}
})
backend_config = dict(
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])
codebase_config = dict(
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))


@@ -0,0 +1,35 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt.py']
onnx_config = dict(
output_names=['dets', 'keypoints'],
dynamic_axes={
'input': {
0: 'batch',
},
'dets': {
0: 'batch',
},
'keypoints': {
0: 'batch'
}
})
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 640, 640],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
codebase_config = dict(
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))
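Note: this TensorRT profile pins min/opt/max shapes to [1, 3, 640, 640], so the built engine only accepts batch size 1 even though the ONNX graph declares a dynamic batch axis. A hedged sketch of widening the profile for batched inference; the config filename and batch values are illustrative assumptions.

from mmengine import Config

# Widen the TensorRT optimization profile to allow batch sizes up to 4.
cfg = Config.fromfile('pose-detection_yolox-pose_tensorrt_dynamic.py')  # placeholder name
shapes = cfg.backend_config.model_inputs[0].input_shapes.input
shapes.opt_shape = [2, 3, 640, 640]
shapes.max_shape = [4, 3, 640, 640]
cfg.dump('pose-detection_yolox-pose_tensorrt_batched.py')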


@@ -50,7 +50,7 @@ __launch_bounds__(nthds_per_cta) __global__
bboxOffset) *
5;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = bboxId / 5;
nmsedIndex[i] = bboxId / 5 - bboxOffset;
}
// clipped bbox xmin
nmsedDets[i * 6] =
@@ -74,7 +74,7 @@ __launch_bounds__(nthds_per_cta) __global__
bboxOffset) *
4;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = bboxId / 4;
nmsedIndex[i] = bboxId / 4 - bboxOffset;
}
// clipped bbox xmin
nmsedDets[i * 5] =
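Note on the kernel change: subtracting bboxOffset makes nmsedIndex batch-local, i.e. an offset into that image's own candidate list rather than into the flattened (batch * num_boxes) buffer. The consumer gathers per image with advanced indexing, as this small torch sketch illustrates (shapes are illustrative only):

import torch

batch, num_boxes, num_kpts = 2, 5, 17
pred_kpts = torch.randn(batch, num_boxes, num_kpts, 3)
nms_indices = torch.tensor([[3, 1], [4, 0]])  # batch-local indices
batch_inds = torch.arange(batch).view(-1, 1)
picked = pred_kpts[batch_inds, nms_indices]   # gathers per image
assert picked.shape == (batch, 2, num_kpts, 3)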


@@ -160,3 +160,4 @@ TODO
| [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |


@@ -164,3 +164,4 @@ task_processor.visualize(
| [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |


@@ -120,6 +120,9 @@ class MMPose(MMCodebase):
@classmethod
def register_deploy_modules(cls):
"""register rewritings."""
import mmdeploy.codebase.mmdet.models
import mmdeploy.codebase.mmdet.ops
import mmdeploy.codebase.mmdet.structures
import mmdeploy.codebase.mmpose.models # noqa: F401
@classmethod
@@ -202,9 +205,11 @@ class PoseDetection(BaseTask):
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
img_path = imgs
else:
raise AssertionError('imgs must be strings or numpy arrays')
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
img_path = imgs
img_data = [mmcv.imread(img) for img in imgs]
imgs = img_data
person_results = []
@@ -220,7 +225,7 @@
TRANSFORMS.build(c) for c in cfg.test_dataloader.dataset.pipeline
]
test_pipeline = Compose(test_pipeline)
if input_shape is not None:
if input_shape is not None and hasattr(cfg, 'codec'):
if isinstance(cfg.codec, dict):
codec = cfg.codec
elif isinstance(cfg.codec, list):
@@ -243,9 +248,15 @@
bbox_score = np.array([bbox[4] if len(bbox) == 5 else 1
]) # shape (1,)
data = {
'img': imgs[i],
'bbox_score': bbox_score,
'bbox': bbox[None], # shape (1, 4)
'img': imgs[i],
'bbox_score': bbox_score,
'bbox': [] if hasattr(cfg.model, 'bbox_head')
and cfg.model.bbox_head.type == 'YOLOXPoseHead' else bbox[None],
'img_path': img_path[i]
}
data.update(meta_data)
data = test_pipeline(data)
@@ -288,11 +299,17 @@
if isinstance(image, str):
image = mmcv.imread(image, channel_order='rgb')
draw_bbox = result.pred_instances.bboxes is not None
if draw_bbox and isinstance(result.pred_instances.bboxes,
torch.Tensor):
result.pred_instances.bboxes = result.pred_instances.bboxes.cpu(
).numpy()
visualizer.add_datasample(
name,
image,
data_sample=result,
draw_gt=False,
draw_bbox=draw_bbox,
show=show_result,
out_file=output_file)
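Note: with these changes, backend inference and visualization can be driven through the task processor. A hedged usage sketch; the model config, image, and backend file paths are placeholders, and argument names follow mmdeploy's documented task-processor API.

import torch
from mmengine import Config
from mmdeploy.apis.utils import build_task_processor

deploy_cfg = Config.fromfile(
    'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
model_cfg = Config.fromfile('yolox-pose_s_8xb32-300e_coco.py')  # placeholder
task_processor = build_task_processor(model_cfg, deploy_cfg, device='cpu')
model = task_processor.build_backend_model(['work_dir/yolox-pose/end2end.onnx'])
model_inputs, _ = task_processor.create_input('demo.jpg', input_shape=(640, 640))
with torch.no_grad():
    results = model.test_step(model_inputs)
task_processor.visualize(
    image='demo.jpg',
    model=model,
    result=results[0],
    output_file='vis.jpg',
    show_result=False)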


@@ -54,7 +54,8 @@ class End2EndModel(BaseBackendModel):
device=device,
**kwargs)
# create head for decoding heatmap
self.head = builder.build_head(model_cfg.model.head)
self.head = builder.build_head(model_cfg.model.head) if hasattr(
model_cfg.model, 'head') else None
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
device: str, **kwargs):
@@ -97,6 +98,9 @@ class End2EndModel(BaseBackendModel):
inputs = inputs.contiguous().to(self.device)
batch_outputs = self.wrapper({self.input_name: inputs})
batch_outputs = self.wrapper.output_to_list(batch_outputs)
if self.model_cfg.model.type == 'YOLODetector':
return self.pack_yolox_pose_result(batch_outputs, data_samples)
codec = self.model_cfg.codec
if isinstance(codec, (list, tuple)):
codec = codec[-1]
@@ -158,6 +162,48 @@
return data_samples
def pack_yolox_pose_result(self, preds: List[torch.Tensor],
data_samples: List[BaseDataElement]):
"""Pack yolox-pose prediction results to mmpose format
Args:
preds (List[Tensor]): Prediction of bboxes and key-points.
data_samples (List[BaseDataElement]): A list of meta info for
image(s).
Returns:
data_samples (List[BaseDataElement])
updated data_samples with predictions.
"""
assert preds[0].shape[0] == len(data_samples)
batched_dets, batched_kpts = preds
for data_sample_idx, data_sample in enumerate(data_samples):
bboxes = batched_dets[data_sample_idx, :, :4]
bbox_scores = batched_dets[data_sample_idx, :, 4]
keypoints = batched_kpts[data_sample_idx, :, :, :2]
keypoint_scores = batched_kpts[data_sample_idx, :, :, 2]
# filter zero or negative scores
inds = bbox_scores > 0.0
bboxes = bboxes[inds, :]
bbox_scores = bbox_scores[inds]
keypoints = keypoints[inds, :]
keypoint_scores = keypoint_scores[inds]
pred_instances = InstanceData()
# rescale
scale_factor = data_sample.scale_factor
scale_factor = keypoints.new_tensor(scale_factor)
keypoints /= scale_factor.reshape(1, 1, 2)
bboxes /= scale_factor.repeat(1, 2)
pred_instances.bboxes = bboxes.cpu().numpy()
pred_instances.bbox_scores = bbox_scores
# the precision test requires keypoints to be np.ndarray
pred_instances.keypoints = keypoints.cpu().numpy()
pred_instances.keypoint_scores = keypoint_scores
pred_instances.labels = torch.zeros(bboxes.shape[0])
data_sample.pred_instances = pred_instances
return data_samples
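The shape contract of pack_yolox_pose_result, as a runnable sketch with dummy tensors; a batch of 1 and 17 COCO keypoints are assumed for illustration.

import torch
from mmengine.structures import InstanceData

dets = torch.rand(1, 100, 5)      # (batch, num_instances, 5): x1, y1, x2, y2, score
kpts = torch.rand(1, 100, 17, 3)  # (batch, num_instances, num_keypoints, 3): x, y, score
keep = dets[0, :, 4] > 0.0        # mirrors the zero-score filter above
pred = InstanceData()
pred.bboxes = dets[0, keep, :4].numpy()
pred.keypoints = kpts[0, keep, :, :2].numpy()
pred.keypoint_scores = kpts[0, keep, :, 2]
assert len(pred.bboxes) == len(pred.keypoints)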
@__BACKEND_MODEL.register_module('sdk')
class SDKEnd2EndModel(End2EndModel):
@@ -236,8 +282,13 @@ def build_pose_detection_model(
if isinstance(data_preprocessor, dict):
dp = data_preprocessor.copy()
dp_type = dp.pop('type')
assert dp_type == 'PoseDataPreprocessor'
data_preprocessor = PoseDataPreprocessor(**dp)
if dp_type == 'mmdet.DetDataPreprocessor':
from mmdet.models.data_preprocessors import DetDataPreprocessor
data_preprocessor = DetDataPreprocessor(**dp)
else:
assert dp_type == 'PoseDataPreprocessor'
data_preprocessor = PoseDataPreprocessor(**dp)
backend_pose_model = __BACKEND_MODEL.build(
dict(
type=model_type,


@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mspn_head
from . import mspn_head, yolox_pose_head # noqa: F401,F403
__all__ = ['mspn_head']
__all__ = ['mspn_head', 'yolox_pose_head']


@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
from mmengine.config import ConfigDict
from torch import Tensor
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms
from mmdeploy.utils import Backend
from mmdeploy.utils.config_utils import get_backend_config
@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
'YOLOXPoseHead.predict')
def predict(self,
x: Tuple[Tensor],
batch_data_samples=None,
rescale: bool = True):
"""Get predictions and transform to bbox and keypoints results.
Args:
x (Tuple[Tensor]): The input tensor from upstream network.
batch_data_samples: Batch image meta info. Defaults to None.
rescale: If True, return boxes in original image space.
Defaults to False.
Returns:
Tuple[Tensor]: Predict bbox and keypoint results.
- dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
has shape (batch_size, num_instances, 5), the last dimension 5
arrange as (x1, y1, x2, y2, score).
- pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
Tensor, has shape (batch_size, num_instances, num_keypoints, 5),
the last dimension 3 arrange as (x, y, score).
"""
outs = self(x)
predictions = self.predict_by_feat(
*outs, batch_img_metas=batch_data_samples, rescale=rescale)
return predictions
@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
'YOLOXPoseHead.predict_by_feat')
def yolox_pose_head__predict_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
kpt_preds: Optional[List[Tensor]] = None,
vis_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> Tuple[Tensor]:
"""Transform a batch of output features extracted by the head into bbox and
keypoint results.
In addition to the base class method, keypoint predictions are also
calculated in this method.
Args:
cls_scores (List[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (List[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
objectnesses (Optional[List[Tensor]]): Score factors for
all scale levels, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
kpt_preds (Optional[List[Tensor]]): Keypoints for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_keypoints * 2, H, W)
vis_preds (Optional[List[Tensor]]): Keypoint visibility scores for
all scale levels, each is a 4D-tensor, has shape
(batch_size, num_keypoints, H, W)
batch_img_metas (Optional[List[dict]]): Batch image meta
info. Defaults to None.
cfg (Optional[ConfigDict]): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in the original image space.
Defaults to True.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
Tuple[Tensor]: Predicted bbox and keypoint results.
- dets (Tensor): Predicted bboxes and scores, a 3D tensor of shape
(batch_size, num_instances, 5), with the last dimension arranged
as (x1, y1, x2, y2, score).
- pred_kpts (Tensor): Predicted keypoints and scores, a 4D tensor of
shape (batch_size, num_instances, num_keypoints, 3), with the last
dimension arranged as (x, y, score).
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
dtype = cls_scores[0].dtype
device = cls_scores[0].device
bbox_decoder = self.bbox_coder.decode
assert len(cls_scores) == len(bbox_preds)
cfg = self.test_cfg if cfg is None else cfg
num_imgs = cls_scores[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
self.mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=dtype, device=device)
flatten_priors = torch.cat(self.mlvl_priors)
mlvl_strides = [
flatten_priors.new_full(
(featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
stride)
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
for cls_score in cls_scores
]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
if objectnesses is not None:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
scores = cls_scores
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
flatten_stride)
# deal with keypoints
priors = torch.cat(self.mlvl_priors)
strides = [
priors.new_full((featmap_size.numel() * self.num_base_priors, ),
stride)
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
strides = torch.cat(strides)
kpt_preds = torch.cat([
kpt_pred.permute(0, 2, 3, 1).reshape(
num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds
],
dim=1)
flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)
vis_preds = torch.cat([
vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_keypoints,
1) for vis_pred in vis_preds
],
dim=1).sigmoid()
pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
backend_config = get_backend_config(deploy_cfg)
if backend_config.type == Backend.TENSORRT.value:
# pad
bboxes = torch.cat(
[bboxes,
bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
dim=1)
scores = torch.cat(
[scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1)
pred_kpts = torch.cat([
pred_kpts,
pred_kpts.new_zeros((pred_kpts.shape[0], 1, pred_kpts.shape[2],
pred_kpts.shape[3]))
],
dim=1)
# nms
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.get('pre_top_k', -1)
keep_top_k = post_params.get('keep_top_k', -1)
# do nms
_, _, nms_indices = multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=True)
batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1)
dets = torch.cat([bboxes, scores], dim=2)
dets = dets[batch_inds, nms_indices, ...]
pred_kpts = pred_kpts[batch_inds, nms_indices, ...]
return dets, pred_kpts
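Given that the deploy configs above name the graph outputs 'dets' and 'keypoints', the exported model can be consumed directly with onnxruntime. A hedged sketch; the ONNX file path is a placeholder and 17 keypoints are assumed.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    'work_dir/yolox-pose/end2end.onnx',  # placeholder path
    providers=['CPUExecutionProvider'])
img = np.random.rand(1, 3, 640, 640).astype(np.float32)
dets, keypoints = sess.run(['dets', 'keypoints'], {'input': img})
# dets:      (batch, num_dets, 5)  as (x1, y1, x2, y2, score)
# keypoints: (batch, num_dets, num_keypoints, 3) as (x, y, score)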


@@ -186,7 +186,9 @@ def _select_nms_index(scores: torch.Tensor,
boxes: torch.Tensor,
nms_index: torch.Tensor,
batch_size: int,
keep_top_k: int = -1):
keep_top_k: int = -1,
pre_inds: torch.Tensor = None,
output_index: bool = False):
"""Transform NMS output.
Args:
@@ -197,6 +199,10 @@ def _select_nms_index(scores: torch.Tensor,
batch_size (int): Batch size of the input image.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
pre_inds (Tensor): The pre-topk indices of boxes before nms.
Defaults to None.
output_index (bool): Whether to return indices of the original bboxes.
Defaults to False.
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
@@ -230,7 +236,13 @@
1)
batched_labels = torch.cat((batched_labels, batched_labels.new_zeros(
(N, 1))), 1)
if output_index and pre_inds is not None:
# batch all
pre_inds = pre_inds[batch_inds, box_inds]
pre_inds = pre_inds.unsqueeze(0).repeat(batch_size, 1)
pre_inds = pre_inds.where((batch_inds == batch_template.unsqueeze(1)),
pre_inds.new_zeros(1))
pre_inds = torch.cat((pre_inds, pre_inds.new_zeros((N, 1))), 1)
# sort
is_use_topk = keep_top_k > 0 and \
(torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
@@ -243,7 +255,11 @@
device=topk_inds.device).view(-1, 1)
batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
if output_index:
if pre_inds is not None:
topk_inds = pre_inds[topk_batch_inds, topk_inds, ...]
topk_inds = topk_inds[:, :-1]
return batched_dets, batched_labels, topk_inds
# slice and recover the tensor
return batched_dets, batched_labels
@@ -263,7 +279,6 @@ def _multiclass_nms(boxes: Tensor,
shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
4).
"""
assert not output_index, 'output_index is not supported on this backend.'
if version.parse(torch.__version__) < version.parse('1.13.0'):
max_output_boxes_per_class = torch.LongTensor(
[max_output_boxes_per_class])
@@ -274,7 +289,8 @@
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(-1, 1).long()
batch_inds = torch.arange(
batch_size, device=scores.device).view(-1, 1).long()
boxes = boxes[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
@@ -283,10 +299,14 @@
max_output_boxes_per_class,
iou_threshold, score_threshold)
dets, labels = _select_nms_index(
scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
return dets, labels
return _select_nms_index(
scores,
boxes,
selected_indices,
batch_size,
keep_top_k=keep_top_k,
pre_inds=topk_inds,
output_index=output_index)
def _multiclass_nms_single(boxes: Tensor,
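Note on the index bookkeeping above: when pre_top_k filtering runs before NMS, the indices NMS returns refer to the filtered tensor, and pre_inds (the top-k indices) map them back to positions in the original candidate list. A minimal torch illustration:

import torch

scores = torch.tensor([[0.1, 0.9, 0.3, 0.8]])  # (batch=1, 4 candidates)
_, topk_inds = scores.topk(2)                  # pre_top_k=2 keeps boxes 1 and 3
nms_keep = torch.tensor([[1]])                 # NMS kept filtered box 1
batch_inds = torch.arange(1).view(-1, 1)
original = topk_inds[batch_inds, nms_keep]     # maps back to original index 3
assert original.item() == 3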


@@ -43,6 +43,7 @@ class Codebase(AdvancedEnum):
MMROTATE = 'mmrotate'
MMACTION = 'mmaction'
MMRAZOR = 'mmrazor'
MMYOLO = 'mmyolo'
class IR(AdvancedEnum):


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine
import pytest
import torch
@@ -188,3 +189,45 @@ def test_scale_forward(backend_type: Backend):
deploy_cfg=deploy_cfg,
run_with_backend=False)
torch_assert_close(rewrite_outputs, model_outputs)
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_yolox_pose_head(backend_type: Backend):
try:
from models import yolox_pose_head # noqa: F401,F403
except ImportError:
pytest.skip(
'mmpose/projects/yolox-pose is not installed.',
allow_module_level=True)
deploy_cfg = mmengine.Config.fromfile(
'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
check_backend(backend_type, True)
model = yolox_pose_head.YOLOXPoseHead(
head_module=dict(
type='YOLOXPoseHeadModule',
num_classes=1,
in_channels=256,
feat_channels=256,
widen_factor=0.5,
stacked_convs=2,
num_keypoints=17,
featmap_strides=(8, 16, 32),
use_depthwise=False,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
))
model.cpu().eval()
model_inputs = [
torch.randn(2, 128, 80, 80),
torch.randn(2, 128, 40, 40),
torch.randn(2, 128, 20, 20)
]
pytorch_output = model(model_inputs)
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'inputs': model_inputs}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
run_with_backend=False,
deploy_cfg=deploy_cfg)
torch_assert_close(rewrite_outputs, pytorch_output)


@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine
import numpy
import torch
from mmengine.structures import InstanceData, PixelData
@@ -16,7 +17,7 @@ def generate_datasample(img_size, heatmap_size=(64, 48)):
input_size=(h, w),
heatmap_size=heatmap_size)
pred_instances = InstanceData()
pred_instances.bboxes = torch.rand((1, 4)).numpy()
pred_instances.bboxes = numpy.array([[0.0, 0.0, 1.0, 1.0]])
pred_instances.bbox_scales = torch.ones(1, 2).numpy()
pred_instances.bbox_scores = torch.ones(1).numpy()
pred_instances.bbox_centers = torch.ones(1, 2).numpy()