# Copyright (c) Alibaba, Inc. and its affiliates.
import json

import mmcv
import numpy as np
import torch
from mmcv.image import imwrite
from mmcv.utils.path import is_filepath
from mmcv.visualization.image import imshow

from easycv.core.visualization import imshow_bboxes, imshow_keypoints
from easycv.datasets.pose.data_sources.top_down import DatasetInfo
from easycv.datasets.pose.pipelines.transforms import bbox_cs2xyxy
from easycv.predictors.builder import PREDICTORS, build_predictor
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.misc import deprecated
from .base import InputProcessor, OutputProcessor, PredictorV2


def _box2cs(image_size, box):
    """Encode a bbox (x, y, w, h) into (center, scale).

    Args:
        image_size (tuple | list): Model input size (width, height), used to
            keep the aspect ratio of the encoded box.
        box (list | np.ndarray): Bounding box in (x, y, w, h) format; any
            elements beyond the first four are ignored.

    Returns:
        tuple: A tuple containing center and scale.

        - np.ndarray[float32](2,): Center of the bbox (x, y).
        - np.ndarray[float32](2,): Scale of the bbox w & h.
    """
    x, y, w, h = box[:4]
    aspect_ratio = image_size[0] / image_size[1]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)

    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio

    # pixel std is 200.0
    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
    # enlarge the box by 25% so the crop keeps some context around the person
    scale = scale * 1.25

    return center, scale
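
# Illustrative example (not part of the original file): encoding a
# (10, 20, 100, 200) box for a 192x256 model input. The box is first widened
# to the 0.75 aspect ratio (w becomes 150), then divided by the pixel std of
# 200 and enlarged by 1.25:
#   >>> _box2cs((192, 256), [10, 20, 100, 200])
#   (array([ 60., 120.], dtype=float32), array([0.9375, 1.25], dtype=float32))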


def vis_pose_result(model,
                    img,
                    result,
                    radius=4,
                    thickness=1,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    dataset_info=None,
                    out_file=None,
                    pose_kpt_color=None,
                    pose_link_color=None,
                    text_color='white',
                    font_scale=0.5,
                    bbox_thickness=1,
                    win_name='',
                    show=False,
                    wait_time=0):
    """Visualize the pose estimation results on the image.

    Args:
        model (nn.Module): The loaded pose model.
        img (str | np.ndarray): Image filename or loaded image.
        result (dict): The results to draw over `img`, with a 'keypoints'
            entry and optional 'bbox' and 'label' entries.
        radius (int): Radius of keypoint circles.
        thickness (int): Thickness of skeleton lines.
        kpt_score_thr (float): The threshold to visualize the keypoints.
        bbox_color (str): Color of the bounding boxes.
        dataset_info (DatasetInfo | None): Dataset meta info providing the
            skeleton and keypoint/link colors; if None, it is read from
            `model.cfg`.
        out_file (str | None): The filename to write the image.
            Default: None.
        pose_kpt_color, pose_link_color: Currently overridden by the colors
            from `dataset_info`.
        text_color (str): Color of the label texts.
        font_scale (float): Font scale of the label texts.
        bbox_thickness (int): Thickness of the bounding box lines.
        win_name (str): Window name used when `show` is True.
        show (bool): Whether to show the image. Default: False.
        wait_time (int): Value of the waitKey param. Default: 0.

    Returns:
        np.ndarray: The image with the drawn results.
    """
    # get dataset info
    if (dataset_info is None and hasattr(model, 'cfg')
            and 'dataset_info' in model.cfg):
        dataset_info = DatasetInfo(model.cfg.dataset_info)

    if not dataset_info:
        raise ValueError('Please provide `dataset_info`!')

    skeleton = dataset_info.skeleton
    pose_kpt_color = dataset_info.pose_kpt_color
    pose_link_color = dataset_info.pose_link_color

    if hasattr(model, 'module'):
        model = model.module

    img = mmcv.imread(img)
    img = img.copy()

    bbox_result = result.get('bbox', [])
    pose_result = result['keypoints']

    if len(bbox_result) > 0:
        bboxes = np.vstack(bbox_result)
        labels = None
        if 'label' in result:
            labels = result['label']
        # draw bounding boxes
        imshow_bboxes(
            img,
            bboxes,
            labels=labels,
            colors=bbox_color,
            text_color=text_color,
            thickness=bbox_thickness,
            font_scale=font_scale,
            show=False)

    imshow_keypoints(img, pose_result, skeleton, kpt_score_thr, pose_kpt_color,
                     pose_link_color, radius, thickness)

    if show:
        imshow(img, win_name, wait_time)

    if out_file is not None:
        imwrite(img, out_file)

    return img
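
# Illustrative usage sketch (not in the original file); assumes `predictor`
# is a PoseTopDownPredictor and `result` is one element of its output list:
#   >>> result = predictor(['demo.jpg'])[0]
#   >>> vis_pose_result(predictor.model, 'demo.jpg', result,
#   ...                 dataset_info=predictor.dataset_info,
#   ...                 out_file='demo_vis.jpg')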


class PoseTopDownInputProcessor(InputProcessor):
    """Input processor for top-down pose estimation.

    Runs a detection predictor on each input image, filters the detected
    boxes, and builds one pose-model sample per detected person.
    """

    def __init__(self,
                 cfg,
                 dataset_info,
                 detection_predictor_config,
                 bbox_thr=None,
                 pipelines=None,
                 batch_size=1,
                 cat_id=None,
                 mode='BGR'):
        self.detection_predictor = build_predictor(detection_predictor_config)
        self.dataset_info = dataset_info
        self.bbox_thr = bbox_thr
        self.cat_id = cat_id
        super().__init__(
            cfg,
            pipelines=pipelines,
            batch_size=batch_size,
            threads=1,
            mode=mode)

    def get_detection_outputs(self, input, cat_id=None):
        det_results = self.detection_predictor(input['img'], keep_inputs=False)
        person_results = self._process_detection_results(
            det_results, cat_id=cat_id)
        return person_results

    def _process_detection_results(self, det_results, cat_id=None):
        """Process detection results and return a list of person bboxes.

        Args:
            det_results (list | tuple): Detection results.
            cat_id (int | str): Category id or name to keep; if None, keep
                all detection results.

        Returns:
            person_results (list): A list of detected bounding boxes.
        """
        # Only support one sample/image
        if isinstance(det_results, (tuple, list)):
            det_results = det_results[0]

        bboxes = det_results['detection_boxes']
        scores = det_results['detection_scores']
        classes = det_results['detection_classes']

        if cat_id is not None:
            if isinstance(cat_id, str):
                assert cat_id in self.detection_predictor.cfg.CLASSES, (
                    f'cat_id "{cat_id}" not in detection classes list: '
                    f'{self.detection_predictor.cfg.CLASSES}')
                assert det_results.get('detection_class_names',
                                       None) is not None
                detection_class_names = det_results['detection_class_names']
                keeped_ids = [
                    i for i in range(len(detection_class_names))
                    if str(detection_class_names[i]) == str(cat_id)
                ]
            else:
                keeped_ids = classes == cat_id
            bboxes = bboxes[keeped_ids]
            scores = scores[keeped_ids]
            classes = classes[keeped_ids]

        person_results = []
        for idx, bbox in enumerate(bboxes):
            person = {}
            # append the detection score so downstream filtering can use it
            bbox = np.append(bbox, scores[idx])
            person['bbox'] = bbox
            person_results.append(person)

        return person_results
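
    # Illustrative return layout of `_process_detection_results` (not in the
    # original file): one dict per kept detection, with the score appended
    # to the box:
    #   [{'bbox': array([x1, y1, x2, y2, score])}, ...]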

    def process_single(self, input):
        output = super()._load_input(input)
        person_results = self.get_detection_outputs(output, cat_id=self.cat_id)

        box_id = 0
        # Select bboxes by score threshold
        bboxes = np.array([res['bbox'] for res in person_results])
        if self.bbox_thr is not None:
            assert bboxes.shape[1] == 5
            valid_idx = np.where(bboxes[:, 4] > self.bbox_thr)[0]
            bboxes = bboxes[valid_idx]
            person_results = [person_results[i] for i in valid_idx]

        output_person_info = []
        for person_result in person_results:
            box = person_result['bbox']  # x,y,x,y
            box = [box[0], box[1], box[2] - box[0], box[3] - box[1]]  # x,y,w,h
            center, scale = _box2cs(self.cfg.data_cfg['image_size'], box)
            data = {
                'image_id': 0,
                'center': center,
                'scale': scale,
                'bbox': box,
                'bbox_score': box[4] if len(box) == 5 else 1,
                'bbox_id': box_id,  # need to be assigned if batch_size > 1
                'joints_3d':
                np.zeros((self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
                'joints_3d_visible':
                np.zeros((self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
                'rotation': 0,
                'flip_pairs': self.dataset_info.flip_pairs,
                'ann_info': {
                    'image_size': np.array(self.cfg.data_cfg['image_size']),
                    'num_joints': self.cfg.data_cfg['num_joints'],
                },
                'image_file': output['filename'],
                'img': output['img'],
                'img_shape': output['img_shape'],
                'ori_shape': output['ori_shape'],
                'img_fields': output['img_fields'],
            }
            box_id += 1
            output_person_info.append(data)

        results = []
        for output in output_person_info:
            results.append(self.processor(output))
        return results

    def __call__(self, inputs):
        """Process the list of inputs, then collate them into a batch and put
        the batch on the target device. Reimplement this if you need custom
        ops to load or process a batch of samples.
        """
        batch_outputs = []
        for inp in inputs:
            for res in self.process_single(inp):
                batch_outputs.append(res)

        if len(batch_outputs) < 1:
            return batch_outputs

        batch_outputs = self._collate_fn(batch_outputs)
        # flatten the nested per-image meta lists into a single batch list
        batch_outputs['img_metas']._data = [[
            img_meta[i] for img_meta in batch_outputs['img_metas']._data
            for i in range(len(img_meta))
        ]]
        return batch_outputs


class PoseTopDownOutputProcessor(OutputProcessor):

    def __call__(self, inputs):
        output = {}
        output['keypoints'] = inputs['preds']
        output['bbox'] = inputs['boxes']  # c1, c2, s1, s2, area, score

        # convert (center, scale) boxes back to (x1, y1, x2, y2)
        for i, bbox in enumerate(output['bbox']):
            center, scale = bbox[:2], bbox[2:4]
            output['bbox'][i][:4] = bbox_cs2xyxy(center, scale)
        # keep (x1, y1, x2, y2, score) and drop the area column
        output['bbox'] = output['bbox'][:, [0, 1, 2, 3, 5]]

        return output
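
# Illustrative output layout (not in the original file): for N detected
# persons with K keypoints each, the processor returns roughly
#   {'keypoints': array of shape (N, K, 3),  # x, y, score per keypoint
#    'bbox': array of shape (N, 5)}          # x1, y1, x2, y2, score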


@PREDICTORS.register_module()
class PoseTopDownPredictor(PredictorV2):
    """Pose top-down predictor.

    Args:
        model_path (str): Path of the pose model file.
        config_file (Optional[str]): Config file path for the model and
            processor to init. Defaults to None.
        detection_predictor_config (dict): Config dict of the person
            detection predictor, e.g.
            ``dict(type="", model_path="", config_file="", ...)``.
        batch_size (int): Batch size for forward.
        bbox_thr (float): Bounding box score threshold to filter the outputs
            of the detection model.
        cat_id (int | str): Category id or name to filter target objects.
        device (str | torch.device): Support str('cuda' or 'cpu') or
            torch.device; if None, detect the device automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when
            `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        mode (str): The image mode fed into the model.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 detection_predictor_config=None,
                 batch_size=1,
                 bbox_thr=None,
                 cat_id=None,
                 device=None,
                 pipelines=None,
                 save_results=False,
                 save_path=None,
                 mode='BGR',
                 *args,
                 **kwargs):
        assert batch_size == 1, 'Only support batch_size=1 now!'
        self.cat_id = cat_id
        self.bbox_thr = bbox_thr
        self.detection_predictor_config = detection_predictor_config

        super(PoseTopDownPredictor, self).__init__(
            model_path,
            config_file=config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=1,
            mode=mode,
            *args,
            **kwargs)

        if hasattr(self.cfg, 'dataset_info'):
            dataset_info = self.cfg.dataset_info
            if is_filepath(dataset_info):
                cfg = mmcv_config_fromfile(dataset_info)
                dataset_info = cfg._cfg_dict['dataset_info']
        else:
            from easycv.datasets.pose.data_sources.coco import COCO_DATASET_INFO
            dataset_info = COCO_DATASET_INFO
        self.dataset_info = DatasetInfo(dataset_info)

    def model_forward(self, inputs, return_heatmap=False):
        with torch.no_grad():
            result = self.model(
                **inputs, mode='test', return_heatmap=return_heatmap)
        return result

    def get_input_processor(self):
        return PoseTopDownInputProcessor(
            cfg=self.cfg,
            dataset_info=self.dataset_info,
            detection_predictor_config=self.detection_predictor_config,
            bbox_thr=self.bbox_thr,
            pipelines=self.pipelines,
            batch_size=self.batch_size,
            cat_id=self.cat_id,
            mode=self.mode)

    def get_output_processor(self):
        return PoseTopDownOutputProcessor()

    def show_result(self,
                    image,
                    keypoints,
                    radius=4,
                    thickness=3,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    show=False,
                    save_path=None):
        vis_result = vis_pose_result(
            self.model,
            image,
            keypoints,
            dataset_info=self.dataset_info,
            radius=radius,
            thickness=thickness,
            kpt_score_thr=kpt_score_thr,
            bbox_color=bbox_color,
            show=show,
            out_file=save_path)
        return vis_result
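
# Illustrative usage sketch (not in the original file; the model paths and
# the detection predictor config below are placeholders):
#   >>> predictor = PoseTopDownPredictor(
#   ...     model_path='path/to/pose_model.pth',
#   ...     detection_predictor_config=dict(
#   ...         type='YoloXPredictor', model_path='path/to/det_model.pth'),
#   ...     bbox_thr=0.3,
#   ...     cat_id='person')
#   >>> results = predictor(['demo.jpg'])
#   >>> predictor.show_result('demo.jpg', results[0], save_path='vis.jpg')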


class _TorchPoseTopDownOutputProcessor(PoseTopDownOutputProcessor):

    def __call__(self, inputs):
        output = super(_TorchPoseTopDownOutputProcessor, self).__call__(inputs)

        bbox = output['bbox']
        keypoints = output['keypoints']
        results = []
        # repack the parallel arrays into one dict per detected person
        for i in range(len(keypoints)):
            results.append({'bbox': bbox[i], 'keypoints': keypoints[i]})
        return {'pose_results': results}


@deprecated(reason='Please use PoseTopDownPredictor.')
@PREDICTORS.register_module()
class TorchPoseTopDownPredictorWithDetector(PoseTopDownPredictor):

    def __init__(
        self,
        model_path,
        model_config={
            'pose': {
                'bbox_thr': 0.3,
                'format': 'xywh'
            },
            'detection': {
                'model_type': None,
                'reserved_classes': [],
                'score_thresh': 0.0,
            }
        },
    ):
        """Init the model.

        Args:
            model_path: Pose and detection model file paths, separated by
                `,`; make sure the first is the pose model and the second is
                the detection model.
            model_config: Config for model init, a dict or a string in json
                format.
        """
        if isinstance(model_config, str):
            model_config = json.loads(model_config)

        reserved_classes = model_config['detection'].pop(
            'reserved_classes', [])
        if len(reserved_classes) == 0:
            reserved_classes = None
        else:
            assert len(reserved_classes) == 1
            reserved_classes = reserved_classes[0]

        model_list = model_path.split(',')
        assert len(model_list) == 2
        # first is pose model, second is detection model
        pose_model_path, detection_model_path = model_list

        detection_model_type = model_config['detection'].pop('model_type')
        if detection_model_type == 'TorchYoloXPredictor':
            detection_predictor_config = dict(
                type=detection_model_type,
                model_path=detection_model_path,
                model_config=model_config['detection'])
        else:
            detection_predictor_config = dict(
                model_path=detection_model_path, **model_config['detection'])

        pose_kwargs = model_config['pose']
        pose_kwargs.pop('format', None)

        super().__init__(
            model_path=pose_model_path,
            detection_predictor_config=detection_predictor_config,
            cat_id=reserved_classes,
            **pose_kwargs,
        )

    def get_output_processor(self):
        return _TorchPoseTopDownOutputProcessor()

    def show_result(self,
                    image_path,
                    keypoints,
                    radius=4,
                    thickness=1,
                    kpt_score_thr=0.3,
                    bbox_color='green',
                    show=False,
                    save_path=None):
        dataset_info = self.dataset_info
        # get dataset info
        if (dataset_info is None and hasattr(self.model, 'cfg')
                and 'dataset_info' in self.model.cfg):
            dataset_info = DatasetInfo(self.model.cfg.dataset_info)
        if not dataset_info:
            raise ValueError('Please provide `dataset_info`!')

        skeleton = dataset_info.skeleton
        pose_kpt_color = dataset_info.pose_kpt_color
        pose_link_color = dataset_info.pose_link_color

        if hasattr(self.model, 'module'):
            self.model = self.model.module

        img = self.model.show_result(
            image_path,
            keypoints,
            skeleton,
            radius=radius,
            thickness=thickness,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            kpt_score_thr=kpt_score_thr,
            bbox_color=bbox_color,
            show=show,
            out_file=save_path)
        return img
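
# Illustrative usage of the deprecated wrapper (not in the original file;
# paths are placeholders). Note that `model_path` packs both models into one
# comma-separated string, pose model first:
#   >>> predictor = TorchPoseTopDownPredictorWithDetector(
#   ...     model_path='path/to/pose_model.pth,path/to/det_model.pth',
#   ...     model_config={
#   ...         'pose': {'bbox_thr': 0.3, 'format': 'xywh'},
#   ...         'detection': {'model_type': 'TorchYoloXPredictor',
#   ...                       'reserved_classes': ['person'],
#   ...                       'score_thresh': 0.0}})
#   >>> results = predictor(['demo.jpg'])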