mirror of https://github.com/alibaba/EasyCV.git

feat: add hand keypoints predictor (pull/191/head)

Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9935447

* feat: add hand keypoints predictor

parent 2bf3b55655
commit a5988732cc
@@ -187,3 +187,4 @@ eval_pipelines = [
 ]
 export = dict(use_jit=False)
 checkpoint_sync_export = True
+predict = dict(type='HandKeypointsPredictor')
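The new `predict` entry points this pose config at the predictor class added in this commit. A minimal sketch of resolving it through the PREDICTORS registry (a hedged example, not part of the commit; paths are borrowed from the unit test at the end of this diff):

from easycv.predictors.builder import build_predictor

predictor = build_predictor(
    dict(
        type='HandKeypointsPredictor',
        model_path='data/test/pose/hand/hrnet_w18_256x256.pth',
        config_file='configs/pose/hand/hrnet_w18_coco_wholebody_hand_256x256_dark.py',
        detection_predictor_config=dict(
            type='DetectionPredictor',
            model_path='https://download.openmmlab.com/mmpose/mmdet_pretrained/'
            'ssdlite_mobilenetv2_scratch_600e_onehand-4f9f8686_20220523.pth',
            config_file='data/test/pose/hand/configs/hand_keypoints_predictor.py',
            score_threshold=0.5)))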
@@ -0,0 +1,86 @@
model = dict(
    type='SingleStageDetector',
    backbone=dict(
        type='MobileNetV2',
        out_indices=(4, 7),
        norm_cfg=dict(type='BN', eps=0.001, momentum=0.03),
        init_cfg=dict(type='TruncNormal', layer='Conv2d', std=0.03)),
    neck=dict(
        type='SSDNeck',
        in_channels=(96, 1280),
        out_channels=(96, 1280, 512, 256, 256, 128),
        level_strides=(2, 2, 2, 2),
        level_paddings=(1, 1, 1, 1),
        l2_norm_scale=None,
        use_depthwise=True,
        norm_cfg=dict(type='BN', eps=0.001, momentum=0.03),
        act_cfg=dict(type='ReLU6'),
        init_cfg=dict(type='TruncNormal', layer='Conv2d', std=0.03)),
    bbox_head=dict(
        type='SSDHead',
        in_channels=(96, 1280, 512, 256, 256, 128),
        num_classes=1,
        use_depthwise=True,
        norm_cfg=dict(type='BN', eps=0.001, momentum=0.03),
        act_cfg=dict(type='ReLU6'),
        init_cfg=dict(type='Normal', layer='Conv2d', std=0.001),

        # set anchor size manually instead of using the predefined
        # SSD300 setting.
        anchor_generator=dict(
            type='SSDAnchorGenerator',
            scale_major=False,
            strides=[16, 32, 64, 107, 160, 320],
            ratios=[[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]],
            min_sizes=[48, 100, 150, 202, 253, 304],
            max_sizes=[100, 150, 202, 253, 304, 320]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2])),
    # model training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.,
            ignore_iof_thr=-1,
            gt_max_assign_all=False),
        smoothl1_beta=1.,
        allowed_border=-1,
        pos_weight=-1,
        neg_pos_ratio=3,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        nms=dict(type='nms', iou_threshold=0.45),
        min_bbox_size=0,
        score_thr=0.02,
        max_per_img=200))

classes = ('hand', )
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(320, 320),
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=False),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='MMPad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
load_from = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \
            'ssdlite_mobilenetv2_scratch_600e_onehand-4f9f8686_20220523.pth'
mmlab_modules = [
    dict(type='mmdet', name='SingleStageDetector', module='model'),
    dict(type='mmdet', name='MobileNetV2', module='backbone'),
    dict(type='mmdet', name='SSDNeck', module='neck'),
    dict(type='mmdet', name='SSDHead', module='head'),
]
predictor = dict(type='DetectionPredictor', score_threshold=0.5)
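A short usage sketch for this detection config (not part of the commit; it mirrors the unit test below and assumes this file is the one the test references as data/test/pose/hand/configs/hand_keypoints_predictor.py):

from easycv.predictors.detector import DetectionPredictor

det = DetectionPredictor(
    model_path='https://download.openmmlab.com/mmpose/mmdet_pretrained/'
    'ssdlite_mobilenetv2_scratch_600e_onehand-4f9f8686_20220523.pth',
    config_file='data/test/pose/hand/configs/hand_keypoints_predictor.py',
    score_threshold=0.5)
# with keep_inputs=True each result carries the source inputs alongside
# 'detection_boxes' / 'detection_scores' / 'detection_classes'
det_results = det(['data/test/pose/hand/data/hand.jpg'], keep_inputs=True)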
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c05d58edee7398de37b8e479410676d6b97cfde69cc003e8356a348067e71988
size 7750
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8570f45c7e642288b23a1c8722ba2b9b40939f1d55c962d13c789157b16edf01
size 117072344
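The two Git LFS pointers above presumably track the binary test assets used by the unit test at the end of this diff: the small object the sample image data/test/pose/hand/data/hand.jpg, the large (~117 MB) object the hrnet_w18_256x256.pth checkpoint. Run `git lfs pull` after cloning to fetch the actual files.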
@@ -1,11 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .classifier import TorchClassifier
-from .detector import (TorchFaceDetector, TorchYoloXClassifierPredictor,
-                       TorchYoloXPredictor)
+from .detector import (DetectionPredictor, TorchFaceDetector,
+                       TorchYoloXClassifierPredictor, TorchYoloXPredictor)
 from .face_keypoints_predictor import FaceKeypointsPredictor
 from .feature_extractor import (TorchFaceAttrExtractor,
                                 TorchFaceFeatureExtractor,
                                 TorchFeatureExtractor)
+from .hand_keypoints_predictor import HandKeypointsPredictor
 from .pose_predictor import (TorchPoseTopDownPredictor,
                              TorchPoseTopDownPredictorWithDetector)
 from .segmentation import (Mask2formerPredictor, SegFormerPredictor,
@@ -14,6 +14,7 @@ from easycv.models.builder import build_model
 from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.config_tools import mmcv_config_fromfile
 from easycv.utils.constant import CACHE_DIR
+from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
 from easycv.utils.registry import build_from_cfg
@@ -151,7 +152,8 @@ class PredictorV2(object):
     def _build_model(self):
         if self.cfg is None:
             raise ValueError('Please provide "config_file"!')
-
+        # Use mmdet model
+        dynamic_adapt_for_mmlab(self.cfg)
         model = build_model(self.cfg.model)
         return model
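For context, a sketch of the flow `_build_model` now follows (a hedged example, not part of the commit; it assumes `dynamic_adapt_for_mmlab` registers the mmdet classes listed under `mmlab_modules` in the config, as in the SSDLite config above, so `build_model` can resolve them):

from easycv.models.builder import build_model
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab

cfg = mmcv_config_fromfile(
    'data/test/pose/hand/configs/hand_keypoints_predictor.py')
dynamic_adapt_for_mmlab(cfg)  # makes SingleStageDetector, SSDNeck, ... resolvable
model = build_model(cfg.model)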
@@ -3,7 +3,6 @@ import json
 import os
-from glob import glob
 
 import cv2
 import numpy as np
 import torch
 from mmcv.ops import RoIPool

@@ -22,6 +21,7 @@ from easycv.utils.config_tools import mmcv_config_fromfile
 from easycv.utils.constant import CACHE_DIR
 from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
 from easycv.utils.registry import build_from_cfg
+from .base import PredictorV2
 from .builder import PREDICTORS
 from .classifier import TorchClassifier
@@ -36,6 +36,45 @@ except Exception:
     from easycv.thirdparty.mtcnn import FaceDetector


+@PREDICTORS.register_module()
+class DetectionPredictor(PredictorV2):
+    """Generic detection predictor, which filters bbox results by
+    ``score_threshold``.
+    """
+
+    def __init__(self,
+                 model_path=None,
+                 config_file=None,
+                 batch_size=1,
+                 device=None,
+                 save_results=False,
+                 save_path=None,
+                 mode='rgb',
+                 score_threshold=0.5):
+        super(DetectionPredictor, self).__init__(
+            model_path,
+            config_file=config_file,
+            batch_size=batch_size,
+            device=device,
+            save_results=save_results,
+            save_path=save_path,
+            mode=mode,
+        )
+        self.score_thresh = score_threshold
+
+    def postprocess(self, inputs, *args, **kwargs):
+        # keep only detections above the score threshold; iterate over the
+        # actual number of samples so a short final batch is handled too
+        for batch_index in range(len(inputs['detection_scores'])):
+            this_detection_scores = inputs['detection_scores'][batch_index]
+            sel_ids = this_detection_scores > self.score_thresh
+            inputs['detection_scores'][batch_index] = this_detection_scores[
+                sel_ids]
+            inputs['detection_boxes'][batch_index] = inputs['detection_boxes'][
+                batch_index][sel_ids]
+            inputs['detection_classes'][batch_index] = inputs[
+                'detection_classes'][batch_index][sel_ids]
+            # TODO: class label remapping
+        return inputs
+
+
 @PREDICTORS.register_module()
 class TorchYoloXPredictor(PredictorInterface):
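The `postprocess` above boils down to NumPy boolean-mask indexing; a toy illustration (not part of the commit):

import numpy as np

scores = np.array([0.9, 0.3, 0.7])
boxes = np.arange(12).reshape(3, 4)
keep = scores > 0.5                        # array([ True, False,  True])
scores, boxes = scores[keep], boxes[keep]  # rows below the threshold drop out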
@@ -0,0 +1,221 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import mmcv
import numpy as np

from easycv.predictors.builder import PREDICTORS, build_predictor
from ..datasets.pose.data_sources.hand.coco_hand import \
    COCO_WHOLEBODY_HAND_DATASET_INFO
from ..datasets.pose.data_sources.top_down import DatasetInfo
from .base import PredictorV2
from .pose_predictor import _box2cs

# bone connections between the 21 COCO-WholeBody hand keypoints
# (0 = wrist, 1-4 = thumb, 5-8 = index, 9-12 = middle, 13-16 = ring,
# 17-20 = little finger)
HAND_SKELETON = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7],
                 [7, 8], [9, 10], [10, 11], [11, 12], [13, 14], [14, 15],
                 [15, 16], [0, 17], [17, 18], [18, 19], [19, 20], [5, 9],
                 [9, 13], [13, 17]]

@PREDICTORS.register_module()
class HandKeypointsPredictor(PredictorV2):
    """HandKeypointsPredictor

    Attributes:
        model_path: path of the keypoint model
        config_file: path or ``Config`` of the config file
        detection_predictor_config: dict config for the hand detection
            predictor, e.g. ``dict(type="", model_path="", config_file="", ...)``
        batch_size: batch size for inference
        save_results: whether to save the prediction results
        save_path: path of the result image
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 detection_predictor_config=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 mode='rgb',
                 *args,
                 **kwargs):
        super(HandKeypointsPredictor, self).__init__(
            model_path,
            config_file=config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            mode=mode,
            *args,
            **kwargs)
        self.dataset_info = DatasetInfo(COCO_WHOLEBODY_HAND_DATASET_INFO)
        assert detection_predictor_config is not None, \
            f"{self.__class__.__name__} needs 'detection_predictor_config' " \
            'to build the hand detection model'
        self.detection_predictor = build_predictor(detection_predictor_config)

    def _load_input(self, input):
        """Load images and convert detection results to top-down style inputs.

        Args:
            input (dict):
                {
                    "inputs": image paths,
                    "results": {
                        "detection_boxes": B*ndarray(N*4),
                        "detection_scores": B*ndarray(N,),
                        "detection_classes": B*ndarray(N,)
                    }
                }
        """
        image_paths = input['inputs']
        batch_data = []
        box_id = 0
        for batch_index, image_path in enumerate(image_paths):
            det_bbox_result = input['results']['detection_boxes'][batch_index]
            det_bbox_scores = input['results']['detection_scores'][batch_index]
            img = mmcv.imread(image_path, 'color', self.mode)
            # every detected hand box becomes one top-down sample
            for bbox, score in zip(det_bbox_result, det_bbox_scores):
                center, scale = _box2cs(self.cfg.data_cfg['image_size'], bbox)
                # prepare data
                data = {
                    'image_file': image_path,
                    'img': img,
                    'image_id': batch_index,
                    'center': center,
                    'scale': scale,
                    'bbox_score': score,
                    'bbox_id': box_id,  # need to be assigned if batch_size > 1
                    'dataset': 'coco_wholebody_hand',
                    'joints_3d': np.zeros((self.cfg.data_cfg.num_joints, 3),
                                          dtype=np.float32),
                    'joints_3d_visible': np.zeros(
                        (self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
                    'rotation': 0,
                    'flip_pairs': self.dataset_info.flip_pairs,
                    'ann_info': {
                        'image_size': np.array(self.cfg.data_cfg['image_size']),
                        'num_joints': self.cfg.data_cfg['num_joints']
                    }
                }
                batch_data.append(data)
                box_id += 1
        return batch_data

    def preprocess_single(self, input, *args, **kwargs):
        results = []
        outputs = self._load_input(input)
        for output in outputs:
            results.append(self.processor(output))
        return results

    def preprocess(self, inputs, *args, **kwargs):
        """Process the whole input list: collate the samples into a batch and
        move it to the target device. Reimplement this if you need custom ops
        to load or process a batch of samples.
        """
        batch_outputs = []
        for i in inputs:
            for res in self.preprocess_single(i, *args, **kwargs):
                batch_outputs.append(res)
        batch_outputs = self._collate_fn(batch_outputs)
        batch_outputs = self._to_device(batch_outputs)
        return batch_outputs

    def postprocess(self, inputs, *args, **kwargs):
        output = {}
        output['keypoints'] = inputs['preds']
        output['boxes'] = inputs['boxes']
        # decode the (center, scale) box encoding back to (x1, y1, x2, y2)
        for i, bbox in enumerate(output['boxes']):
            center, scale = bbox[:2], bbox[2:4]
            output['boxes'][i][:4] = bbox_cs2xyxy(center, scale)
        output['boxes'] = output['boxes'][:, :4]
        return output

    def __call__(self, inputs, keep_inputs=False):
        if isinstance(inputs, str):
            inputs = [inputs]

        results_list = []
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i:i + self.batch_size]
            # hand detection; keep_inputs=True returns the source images too
            det_results = self.detection_predictor(batch, keep_inputs=True)
            # hand keypoints
            batch_outputs = self.preprocess(det_results)
            batch_outputs = self.forward(batch_outputs)
            results = self.postprocess(batch_outputs)
            if keep_inputs:
                results = {'inputs': batch, 'results': results}
            # if dumped, the outputs are not added to the return value
            # to prevent taking up too much memory
            if self.save_results:
                self.dump([results], self.save_path, mode='ab+')
            else:
                results_list.append(results)

        return results_list

    def show_result(self,
                    image_path,
                    keypoints,
                    boxes=None,
                    scale=4,
                    save_path=None):
        """Draw the predicted keypoints and skeleton over the image.

        Args:
            image_path (str): filepath of the image
            keypoints (ndarray): N*21*3 keypoint array
        """
        point_color = [120, 225, 240]
        sk_color = [0, 255, 0]
        img = mmcv.imread(image_path)
        img = img.copy()
        img_h, img_w = img.shape[:2]

        for kpts in keypoints:
            # points
            for x, y, _ in kpts:
                cv2.circle(img, (int(x), int(y)), scale, point_color, -1)
            # skeleton
            for sk in HAND_SKELETON:
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))

                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
                        or pos1[1] >= img_h or pos2[0] <= 0 or pos2[0] >= img_w
                        or pos2[1] <= 0 or pos2[1] >= img_h):
                    # skip links that fall outside the image
                    continue
                cv2.line(img, pos1, pos2, sk_color, thickness=1)

        if boxes is not None:
            bboxes = np.vstack(boxes)
            mmcv.imshow_bboxes(
                img, bboxes, colors='green', top_k=-1, thickness=2, show=False)

        if save_path is not None:
            mmcv.imwrite(img, save_path)
        return img

def bbox_cs2xyxy(center, scale, padding=1., pixel_std=200.):
    # invert the (center, scale) box encoding; the 0.8 factor (1 / 1.25)
    # undoes the padding applied when the box was encoded by `_box2cs`
    wh = scale * 0.8 / padding * pixel_std
    xy = center - 0.5 * wh
    x1, y1 = xy
    w, h = wh
    return np.r_[x1, y1, x1 + w, y1 + h]
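A quick worked check of `bbox_cs2xyxy` (hand-picked values, not part of the commit): with the defaults padding=1.0 and pixel_std=200.0, a scale of 1.0 decodes to a 160-pixel side, since 1.0 * 0.8 * 200 = 160.

import numpy as np

center = np.array([160., 160.])
scale = np.array([1., 1.])
print(bbox_cs2xyxy(center, scale))  # -> [ 80.  80. 240. 240.]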
@@ -0,0 +1,44 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from easycv.predictors.hand_keypoints_predictor import HandKeypointsPredictor
from easycv.utils.config_tools import mmcv_config_fromfile

MM_DEFAULT_HAND_DETECTION_SSDLITE_MODEL_PATH = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \
    'ssdlite_mobilenetv2_scratch_600e_onehand-4f9f8686_20220523.pth'
MM_DEFAULT_HAND_DETECTION_SSDLITE_CONFIG_FILE = 'data/test/pose/hand/configs/hand_keypoints_predictor.py'


class HandKeypointsPredictorTest(unittest.TestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.image_path = 'data/test/pose/hand/data/hand.jpg'
        self.save_image_path = 'data/test/pose/hand/data/hand_result.jpg'
        self.model_path = 'data/test/pose/hand/hrnet_w18_256x256.pth'
        self.model_config_path = 'configs/pose/hand/hrnet_w18_coco_wholebody_hand_256x256_dark.py'

    def test_single(self):
        config = mmcv_config_fromfile(self.model_config_path)
        predict_pipeline = HandKeypointsPredictor(
            model_path=self.model_path,
            config_file=config,
            detection_predictor_config=dict(
                type='DetectionPredictor',
                model_path=MM_DEFAULT_HAND_DETECTION_SSDLITE_MODEL_PATH,
                config_file=MM_DEFAULT_HAND_DETECTION_SSDLITE_CONFIG_FILE,
                score_threshold=0.5))

        output = predict_pipeline(self.image_path)[0]
        keypoints = output['keypoints']
        boxes = output['boxes']
        image_show = predict_pipeline.show_result(
            self.image_path, keypoints, boxes, save_path=self.save_image_path)
        # one hand, 21 keypoints, (x, y, score) per keypoint
        self.assertEqual(keypoints.shape[0], 1)
        self.assertEqual(keypoints.shape[1], 21)
        self.assertEqual(keypoints.shape[2], 3)


if __name__ == '__main__':
    unittest.main()