Support deploy of YoloX-Pose (#2184)
* dev_mmpose
* tide
* fix lint
* del redundant task and model
* fix
* test ut
* test ut
* upload configs
* fix
* remove debug
* fix lint
* use mmcv.ops.nms
* fix lint
* remove loop
* debug
* test modified ut
* fix lint
* fix return type
* fix
* fix rescale
* fix
* fix pack_result
* update batch inference
* fix nms and pytorch show_box
* fix lint
* modify ut
* add docstring
* modify nms
* fix
* add openvino config
* update docs
* fix test_mmpose

---------

Co-authored-by: RunningLeon <mnsheng@yeah.net>
parent a664f061ff
commit e19f6fa08d
@@ -52,6 +52,10 @@ jobs:
         run: |
           git clone -b dev --depth 1 https://github.com/open-mmlab/mmyolo.git /home/runner/work/mmyolo
           python -m pip install -v -e /home/runner/work/mmyolo
+      - name: Install mmpose
+        run: |
+          git clone --depth 1 https://github.com/open-mmlab/mmpose.git /home/runner/work/mmpose
+          python -m pip install -v -e /home/runner/work/mmpose
       - name: Build and install
         run: |
           rm -rf .eggs && python -m pip install -e .
@@ -0,0 +1,25 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
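As a usage sketch, this deploy config can be fed to MMDeploy's conversion API. The snippet below is illustrative only: the sample image, work directory, mmpose project config, and checkpoint paths are hypothetical placeholders, not part of this commit; only the deploy config path is introduced by this PR.

# Hypothetical export sketch using mmdeploy's public conversion API.
from mmdeploy.apis import torch2onnx

torch2onnx(
    img='demo.jpg',                  # placeholder sample image
    work_dir='work_dir/yolox-pose',  # placeholder output directory
    save_file='end2end.onnx',
    deploy_cfg='configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py',
    model_cfg='yolox-pose_s_8xb32-300e_coco.py',  # placeholder project config
    model_checkpoint='yolox-pose_s.pth',          # placeholder checkpoint
    device='cpu')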
@@ -0,0 +1,27 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/openvino.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+backend_config = dict(
+    model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
@@ -0,0 +1,35 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 640, 640],
+                    opt_shape=[1, 3, 640, 640],
+                    max_shape=[1, 3, 640, 640])))
+    ])
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
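Note that min_shape, opt_shape, and max_shape are all 1x3x640x640, so the TensorRT engine is effectively static despite the dynamic ONNX axes: inputs must be brought to 640x640 before inference. A minimal preprocessing sketch, where the source size is an arbitrary assumption:

# Resize an arbitrary image tensor to the single shape the engine accepts.
import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 480, 512)  # placeholder source size
img = F.interpolate(img, size=(640, 640), mode='bilinear', align_corners=False)
assert img.shape == (1, 3, 640, 640)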
@@ -50,7 +50,7 @@ __launch_bounds__(nthds_per_cta) __global__
                        bboxOffset) *
                       5;
       if (nmsedIndex != nullptr) {
-        nmsedIndex[i] = bboxId / 5;
+        nmsedIndex[i] = bboxId / 5 - bboxOffset;
       }
       // clipped bbox xmin
       nmsedDets[i * 6] =
@@ -74,7 +74,7 @@ __launch_bounds__(nthds_per_cta) __global__
                        bboxOffset) *
                       4;
       if (nmsedIndex != nullptr) {
-        nmsedIndex[i] = bboxId / 4;
+        nmsedIndex[i] = bboxId / 4 - bboxOffset;
       }
       // clipped bbox xmin
       nmsedDets[i * 5] =
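The two kernel changes above subtract bboxOffset so that the index written to nmsedIndex is relative to the current image's own box list rather than to the flattened batch buffer (the two kernels use per-box strides of 5 and 4 floats). A Python sketch of the same arithmetic, with hypothetical sizes:

# Same index arithmetic as the kernel fix (hypothetical sizes, stride 5).
num_boxes_per_image = 100
batch_idx = 2
bbox_offset = batch_idx * num_boxes_per_image  # start of image 2's boxes
bbox_id = (bbox_offset + 7) * 5                # flattened float offset of box 7
global_index = bbox_id // 5                    # 207: index in the flattened buffer
local_index = bbox_id // 5 - bbox_offset       # 7: index within image 2
assert (global_index, local_index) == (207, 7)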
@@ -160,3 +160,4 @@ TODO
 | [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
+| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |
@@ -164,3 +164,4 @@ task_processor.visualize(
 | [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
+| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |
@@ -120,6 +120,9 @@ class MMPose(MMCodebase):
     @classmethod
     def register_deploy_modules(cls):
         """Register rewriters."""
+        import mmdeploy.codebase.mmdet.models
+        import mmdeploy.codebase.mmdet.ops
+        import mmdeploy.codebase.mmdet.structures
         import mmdeploy.codebase.mmpose.models  # noqa: F401

     @classmethod
@@ -202,9 +205,11 @@ class PoseDetection(BaseTask):
                 raise AssertionError('imgs must be strings or numpy arrays')
         elif isinstance(imgs, (np.ndarray, str)):
             imgs = [imgs]
+            img_path = [imgs]
         else:
             raise AssertionError('imgs must be strings or numpy arrays')
         if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
+            img_path = imgs
             img_data = [mmcv.imread(img) for img in imgs]
             imgs = img_data
         person_results = []
@@ -220,7 +225,7 @@ class PoseDetection(BaseTask):
             TRANSFORMS.build(c) for c in cfg.test_dataloader.dataset.pipeline
         ]
         test_pipeline = Compose(test_pipeline)
-        if input_shape is not None:
+        if input_shape is not None and hasattr(cfg, 'codec'):
             if isinstance(cfg.codec, dict):
                 codec = cfg.codec
             elif isinstance(cfg.codec, list):
@@ -243,9 +248,15 @@ class PoseDetection(BaseTask):
             bbox_score = np.array([bbox[4] if len(bbox) == 5 else 1
                                    ])  # shape (1,)
             data = {
-                'img': imgs[i],
-                'bbox_score': bbox_score,
-                'bbox': bbox[None],  # shape (1, 4)
+                'img':
+                imgs[i],
+                'bbox_score':
+                bbox_score,
+                'bbox': [] if hasattr(cfg.model, 'bbox_head')
+                and cfg.model.bbox_head.type == 'YOLOXPoseHead' else
+                bbox[None],
+                'img_path':
+                img_path[i]
             }
             data.update(meta_data)
             data = test_pipeline(data)
@@ -288,11 +299,17 @@ class PoseDetection(BaseTask):

         if isinstance(image, str):
             image = mmcv.imread(image, channel_order='rgb')
+        draw_bbox = result.pred_instances.bboxes is not None
+        if draw_bbox and isinstance(result.pred_instances.bboxes,
+                                    torch.Tensor):
+            result.pred_instances.bboxes = result.pred_instances.bboxes.cpu(
+            ).numpy()
         visualizer.add_datasample(
             name,
             image,
             data_sample=result,
             draw_gt=False,
+            draw_bbox=draw_bbox,
             show=show_result,
             out_file=output_file)

@@ -54,7 +54,8 @@ class End2EndModel(BaseBackendModel):
             device=device,
             **kwargs)
         # create head for decoding heatmap
-        self.head = builder.build_head(model_cfg.model.head)
+        self.head = builder.build_head(model_cfg.model.head) if hasattr(
+            model_cfg.model, 'head') else None

     def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
                       device: str, **kwargs):
@@ -97,6 +98,9 @@ class End2EndModel(BaseBackendModel):
         inputs = inputs.contiguous().to(self.device)
         batch_outputs = self.wrapper({self.input_name: inputs})
         batch_outputs = self.wrapper.output_to_list(batch_outputs)
+        if self.model_cfg.model.type == 'YOLODetector':
+            return self.pack_yolox_pose_result(batch_outputs, data_samples)
+
         codec = self.model_cfg.codec
         if isinstance(codec, (list, tuple)):
             codec = codec[-1]
@@ -158,6 +162,48 @@ class End2EndModel(BaseBackendModel):

         return data_samples

+    def pack_yolox_pose_result(self, preds: List[torch.Tensor],
+                               data_samples: List[BaseDataElement]):
+        """Pack yolox-pose prediction results to mmpose format.
+
+        Args:
+            preds (List[Tensor]): Prediction of bboxes and key-points.
+            data_samples (List[BaseDataElement]): A list of meta info for
+                image(s).
+        Returns:
+            data_samples (List[BaseDataElement]):
+                updated data_samples with predictions.
+        """
+        assert preds[0].shape[0] == len(data_samples)
+        batched_dets, batched_kpts = preds
+        for data_sample_idx, data_sample in enumerate(data_samples):
+            bboxes = batched_dets[data_sample_idx, :, :4]
+            bbox_scores = batched_dets[data_sample_idx, :, 4]
+            keypoints = batched_kpts[data_sample_idx, :, :, :2]
+            keypoint_scores = batched_kpts[data_sample_idx, :, :, 2]
+
+            # filter zero or negative scores
+            inds = bbox_scores > 0.0
+            bboxes = bboxes[inds, :]
+            bbox_scores = bbox_scores[inds]
+            keypoints = keypoints[inds, :]
+            keypoint_scores = keypoint_scores[inds]
+
+            pred_instances = InstanceData()
+            # rescale back to the original image space
+            scale_factor = data_sample.scale_factor
+            scale_factor = keypoints.new_tensor(scale_factor)
+            keypoints /= scale_factor.reshape(1, 1, 2)
+            bboxes /= scale_factor.repeat(1, 2)
+            pred_instances.bboxes = bboxes.cpu().numpy()
+            pred_instances.bbox_scores = bbox_scores
+            # the precision test requires keypoints to be np.ndarray
+            pred_instances.keypoints = keypoints.cpu().numpy()
+            pred_instances.keypoint_scores = keypoint_scores
+            pred_instances.labels = torch.zeros(bboxes.shape[0])
+
+            data_sample.pred_instances = pred_instances
+        return data_samples
+
+
 @__BACKEND_MODEL.register_module('sdk')
 class SDKEnd2EndModel(End2EndModel):
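A shape sketch of what pack_yolox_pose_result consumes, under the assumption (consistent with the head rewriter later in this diff) that the backend emits dets of shape (batch, num_instances, 5) and keypoints of shape (batch, num_instances, num_keypoints, 3). Sizes are made up for illustration:

# Dummy backend outputs; entries with non-positive scores are padding
# and are dropped by the `bbox_scores > 0.0` filter in the method above.
import torch

batched_dets = torch.zeros(2, 100, 5)      # (x1, y1, x2, y2, score)
batched_kpts = torch.zeros(2, 100, 17, 3)  # (x, y, visibility score)
batched_dets[0, 0] = torch.tensor([10., 20., 110., 220., 0.9])  # one real det

keep = batched_dets[0, :, 4] > 0.0
assert int(keep.sum()) == 1  # only the real detection survives the filter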
@@ -236,8 +282,13 @@ def build_pose_detection_model(
     if isinstance(data_preprocessor, dict):
         dp = data_preprocessor.copy()
         dp_type = dp.pop('type')
-        assert dp_type == 'PoseDataPreprocessor'
-        data_preprocessor = PoseDataPreprocessor(**dp)
+        if dp_type == 'mmdet.DetDataPreprocessor':
+            from mmdet.models.data_preprocessors import DetDataPreprocessor
+            data_preprocessor = DetDataPreprocessor(**dp)
+        else:
+            assert dp_type == 'PoseDataPreprocessor'
+            data_preprocessor = PoseDataPreprocessor(**dp)

     backend_pose_model = __BACKEND_MODEL.build(
         dict(
             type=model_type,
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import mspn_head
+from . import mspn_head, yolox_pose_head  # noqa: F401,F403

-__all__ = ['mspn_head']
+__all__ = ['mspn_head', 'yolox_pose_head']
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+from mmengine.config import ConfigDict
+from torch import Tensor
+
+from mmdeploy.codebase.mmdet import get_post_processing_params
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmcv.ops.nms import multiclass_nms
+from mmdeploy.utils import Backend
+from mmdeploy.utils.config_utils import get_backend_config
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
+                                     'YOLOXPoseHead.predict')
+def predict(self,
+            x: Tuple[Tensor],
+            batch_data_samples=None,
+            rescale: bool = True):
+    """Get predictions and transform to bbox and keypoints results.
+
+    Args:
+        x (Tuple[Tensor]): The input tensor from upstream network.
+        batch_data_samples: Batch image meta info. Defaults to None.
+        rescale: If True, return boxes in original image space.
+            Defaults to True.
+
+    Returns:
+        Tuple[Tensor]: Predicted bbox and keypoint results.
+        - dets (Tensor): Predicted bboxes and scores, a 3D Tensor of shape
+          (batch_size, num_instances, 5), the last dimension 5 arranged as
+          (x1, y1, x2, y2, score).
+        - pred_kpts (Tensor): Predicted keypoints and scores, a 4D Tensor of
+          shape (batch_size, num_instances, num_keypoints, 3), the last
+          dimension 3 arranged as (x, y, score).
+    """
+    outs = self(x)
+    predictions = self.predict_by_feat(
+        *outs, batch_img_metas=batch_data_samples, rescale=rescale)
+    return predictions
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
+                                     'YOLOXPoseHead.predict_by_feat')
+def yolox_pose_head__predict_by_feat(
+        self,
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        objectnesses: Optional[List[Tensor]] = None,
+        kpt_preds: Optional[List[Tensor]] = None,
+        vis_preds: Optional[List[Tensor]] = None,
+        batch_img_metas: Optional[List[dict]] = None,
+        cfg: Optional[ConfigDict] = None,
+        rescale: bool = True,
+        with_nms: bool = True) -> Tuple[Tensor]:
+    """Transform a batch of output features extracted by the head into bbox
+    and keypoint results.
+
+    In addition to the base class method, keypoint predictions are also
+    calculated in this method.
+
+    Args:
+        cls_scores (List[Tensor]): Classification scores for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * num_classes, H, W).
+        bbox_preds (List[Tensor]): Box energies / deltas for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * 4, H, W).
+        objectnesses (Optional[List[Tensor]]): Score factors for
+            all scale levels, each is a 4D-tensor, has shape
+            (batch_size, 1, H, W).
+        kpt_preds (Optional[List[Tensor]]): Keypoints for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_keypoints * 2, H, W).
+        vis_preds (Optional[List[Tensor]]): Keypoint scores for
+            all scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_keypoints, H, W).
+        batch_img_metas (Optional[List[dict]]): Batch image meta
+            info. Defaults to None.
+        cfg (Optional[ConfigDict]): Test / postprocessing
+            configuration. If None, test_cfg is used.
+            Defaults to None.
+        rescale (bool): If True, return boxes in original image space.
+            Defaults to True.
+        with_nms (bool): If True, do nms before returning boxes.
+            Defaults to True.
+
+    Returns:
+        Tuple[Tensor]: Predicted bbox and keypoint results.
+        - dets (Tensor): Predicted bboxes and scores, a 3D Tensor of shape
+          (batch_size, num_instances, 5), the last dimension 5 arranged as
+          (x1, y1, x2, y2, score).
+        - pred_kpts (Tensor): Predicted keypoints and scores, a 4D Tensor of
+          shape (batch_size, num_instances, num_keypoints, 3), the last
+          dimension 3 arranged as (x, y, score).
+    """
+    ctx = FUNCTION_REWRITER.get_context()
+    deploy_cfg = ctx.cfg
+    dtype = cls_scores[0].dtype
+    device = cls_scores[0].device
+    bbox_decoder = self.bbox_coder.decode
+
+    assert len(cls_scores) == len(bbox_preds)
+    cfg = self.test_cfg if cfg is None else cfg
+
+    num_imgs = cls_scores[0].shape[0]
+    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
+
+    self.mlvl_priors = self.prior_generator.grid_priors(
+        featmap_sizes, dtype=dtype, device=device)
+
+    flatten_priors = torch.cat(self.mlvl_priors)
+
+    mlvl_strides = [
+        flatten_priors.new_full(
+            (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
+            stride)
+        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
+    ]
+    flatten_stride = torch.cat(mlvl_strides)
+
+    # flatten cls_scores, bbox_preds and objectness
+    flatten_cls_scores = [
+        cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
+        for cls_score in cls_scores
+    ]
+    cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
+
+    flatten_bbox_preds = [
+        bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+        for bbox_pred in bbox_preds
+    ]
+    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+
+    if objectnesses is not None:
+        flatten_objectness = [
+            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
+            for objectness in objectnesses
+        ]
+        flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
+        cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
+
+    scores = cls_scores
+    bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
+                          flatten_stride)
+
+    # deal with keypoints
+    priors = torch.cat(self.mlvl_priors)
+    strides = [
+        priors.new_full((featmap_size.numel() * self.num_base_priors, ),
+                        stride)
+        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
+    ]
+    strides = torch.cat(strides)
+    kpt_preds = torch.cat([
+        kpt_pred.permute(0, 2, 3, 1).reshape(
+            num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds
+    ],
+                          dim=1)
+    flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)
+
+    vis_preds = torch.cat([
+        vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_keypoints,
+                                             1) for vis_pred in vis_preds
+    ],
+                          dim=1).sigmoid()
+
+    pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
+
+    backend_config = get_backend_config(deploy_cfg)
+    if backend_config.type == Backend.TENSORRT.value:
+        # pad a zero dummy detection so padding indices emitted by
+        # TensorRT batched NMS select a zero-score entry
+        bboxes = torch.cat(
+            [bboxes,
+             bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
+            dim=1)
+        scores = torch.cat(
+            [scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1)
+        pred_kpts = torch.cat([
+            pred_kpts,
+            pred_kpts.new_zeros((pred_kpts.shape[0], 1, pred_kpts.shape[2],
+                                 pred_kpts.shape[3]))
+        ],
+                              dim=1)
+
+    # nms
+    post_params = get_post_processing_params(deploy_cfg)
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.get('pre_top_k', -1)
+    keep_top_k = post_params.get('keep_top_k', -1)
+    # do nms
+    _, _, nms_indices = multiclass_nms(
+        bboxes,
+        scores,
+        max_output_boxes_per_class,
+        iou_threshold,
+        score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k,
+        output_index=True)
+
+    batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1)
+    dets = torch.cat([bboxes, scores], dim=2)
+    dets = dets[batch_inds, nms_indices, ...]
+    pred_kpts = pred_kpts[batch_inds, nms_indices, ...]
+
+    return dets, pred_kpts
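The final gather in this rewriter, `dets[batch_inds, nms_indices, ...]`, relies on broadcasting a (num_imgs, 1) batch index against the per-image NMS indices. A self-contained sketch of that indexing pattern, with made-up sizes:

# Broadcasted advanced indexing: batch_inds (2, 1) against indices (2, 3)
# selects keep_top_k rows per image from a batched tensor.
import torch

num_imgs, num_boxes, keep_top_k = 2, 8, 3
dets = torch.randn(num_imgs, num_boxes, 5)
nms_indices = torch.randint(0, num_boxes, (num_imgs, keep_top_k))
batch_inds = torch.arange(num_imgs).view(-1, 1)  # shape (2, 1), broadcasts
selected = dets[batch_inds, nms_indices, ...]
assert selected.shape == (num_imgs, keep_top_k, 5)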
@@ -186,7 +186,9 @@ def _select_nms_index(scores: torch.Tensor,
                       boxes: torch.Tensor,
                       nms_index: torch.Tensor,
                       batch_size: int,
-                      keep_top_k: int = -1):
+                      keep_top_k: int = -1,
+                      pre_inds: torch.Tensor = None,
+                      output_index: bool = False):
     """Transform NMS output.

     Args:
@@ -197,6 +199,10 @@ def _select_nms_index(scores: torch.Tensor,
         batch_size (int): Batch size of the input image.
         keep_top_k (int): Number of top K boxes to keep after nms.
             Defaults to -1.
+        pre_inds (Tensor): The pre-topk indices of boxes before nms.
+            Defaults to None.
+        output_index (bool): Whether to return indices of the original bboxes.
+            Defaults to False.

     Returns:
         tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
@@ -230,7 +236,13 @@ def _select_nms_index(scores: torch.Tensor,
                                  1)
     batched_labels = torch.cat((batched_labels, batched_labels.new_zeros(
         (N, 1))), 1)

+    if output_index and pre_inds is not None:
+        # batch all
+        pre_inds = pre_inds[batch_inds, box_inds]
+        pre_inds = pre_inds.unsqueeze(0).repeat(batch_size, 1)
+        pre_inds = pre_inds.where((batch_inds == batch_template.unsqueeze(1)),
+                                  pre_inds.new_zeros(1))
+        pre_inds = torch.cat((pre_inds, pre_inds.new_zeros((N, 1))), 1)
     # sort
     is_use_topk = keep_top_k > 0 and \
         (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
@@ -243,7 +255,11 @@ def _select_nms_index(scores: torch.Tensor,
                                      device=topk_inds.device).view(-1, 1)
         batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
         batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]

+        if output_index:
+            if pre_inds is not None:
+                topk_inds = pre_inds[topk_batch_inds, topk_inds, ...]
+            topk_inds = topk_inds[:, :-1]
+            return batched_dets, batched_labels, topk_inds
     # slice and recover the tensor
     return batched_dets, batched_labels
@@ -263,7 +279,6 @@ def _multiclass_nms(boxes: Tensor,
     shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
     4).
     """
-    assert not output_index, 'output_index is not supported on this backend.'
     if version.parse(torch.__version__) < version.parse('1.13.0'):
         max_output_boxes_per_class = torch.LongTensor(
             [max_output_boxes_per_class])
@@ -274,7 +289,8 @@ def _multiclass_nms(boxes: Tensor,
     if pre_top_k > 0:
         max_scores, _ = scores.max(-1)
         _, topk_inds = max_scores.topk(pre_top_k)
-        batch_inds = torch.arange(batch_size).view(-1, 1).long()
+        batch_inds = torch.arange(
+            batch_size, device=scores.device).view(-1, 1).long()
         boxes = boxes[batch_inds, topk_inds, :]
         scores = scores[batch_inds, topk_inds, :]
@@ -283,10 +299,14 @@ def _multiclass_nms(boxes: Tensor,
                                                max_output_boxes_per_class,
                                                iou_threshold, score_threshold)

-    dets, labels = _select_nms_index(
-        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
-
-    return dets, labels
+    return _select_nms_index(
+        scores,
+        boxes,
+        selected_indices,
+        batch_size,
+        keep_top_k=keep_top_k,
+        pre_inds=topk_inds,
+        output_index=output_index)


 def _multiclass_nms_single(boxes: Tensor,
@@ -43,6 +43,7 @@ class Codebase(AdvancedEnum):
     MMROTATE = 'mmrotate'
     MMACTION = 'mmaction'
     MMRAZOR = 'mmrazor'
+    MMYOLO = 'mmyolo'


 class IR(AdvancedEnum):
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import mmengine
 import pytest
 import torch

@@ -188,3 +189,45 @@ def test_scale_forward(backend_type: Backend):
         deploy_cfg=deploy_cfg,
         run_with_backend=False)
     torch_assert_close(rewrite_outputs, model_outputs)
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_yolox_pose_head(backend_type: Backend):
+    try:
+        from models import yolox_pose_head  # noqa: F401,F403
+    except ImportError:
+        pytest.skip(
+            'mmpose/projects/yolox-pose is not installed.',
+            allow_module_level=True)
+    deploy_cfg = mmengine.Config.fromfile(
+        'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
+    check_backend(backend_type, True)
+    model = yolox_pose_head.YOLOXPoseHead(
+        head_module=dict(
+            type='YOLOXPoseHeadModule',
+            num_classes=1,
+            in_channels=256,
+            feat_channels=256,
+            widen_factor=0.5,
+            stacked_convs=2,
+            num_keypoints=17,
+            featmap_strides=(8, 16, 32),
+            use_depthwise=False,
+            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+            act_cfg=dict(type='SiLU', inplace=True),
+        ))
+    model.cpu().eval()
+    model_inputs = [
+        torch.randn(2, 128, 80, 80),
+        torch.randn(2, 128, 40, 40),
+        torch.randn(2, 128, 20, 20)
+    ]
+    pytorch_output = model(model_inputs)
+    wrapped_model = WrapModel(model, 'forward')
+    rewrite_inputs = {'inputs': model_inputs}
+    rewrite_outputs, _ = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        run_with_backend=False,
+        deploy_cfg=deploy_cfg)
+    torch_assert_close(rewrite_outputs, pytorch_output)
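The feature-map sizes in this test match a 640x640 input divided by the head's featmap_strides, and the 128 input channels equal in_channels=256 scaled by widen_factor=0.5. A quick consistency check (the 640 input size is an assumption based on the backend configs above):

# Dummy FPN sizes in the test: 640 / (8, 16, 32) -> 80, 40, 20.
input_size = 640
strides = (8, 16, 32)  # featmap_strides of the test head
assert [input_size // s for s in strides] == [80, 40, 20]
assert int(256 * 0.5) == 128  # channels after widen_factor scaling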
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import mmengine
+import numpy
 import torch
 from mmengine.structures import InstanceData, PixelData

@@ -16,7 +17,7 @@ def generate_datasample(img_size, heatmap_size=(64, 48)):
         input_size=(h, w),
         heatmap_size=heatmap_size)
     pred_instances = InstanceData()
-    pred_instances.bboxes = torch.rand((1, 4)).numpy()
+    pred_instances.bboxes = numpy.array([[0.0, 0.0, 1.0, 1.0]])
     pred_instances.bbox_scales = torch.ones(1, 2).numpy()
     pred_instances.bbox_scores = torch.ones(1).numpy()
     pred_instances.bbox_centers = torch.ones(1, 2).numpy()