mirror of https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00

Add the coco-wholebody-hand dataset and the PCK, AUC, EPE, and NME evaluation metrics
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9790242
This commit is contained in:
parent 0f74adb848
commit 2bf7c9f6ff
configs/pose/hand/litehrnet_30_coco_wholebody_hand_256x256.py (176 lines, new file)
@@ -0,0 +1,176 @@
# oss_io_config = dict(
#     ak_id='your oss ak id',
#     ak_secret='your oss ak secret',
#     hosts='oss-cn-zhangjiakou.aliyuncs.com',  # your oss hosts
#     buckets=['your_bucket'])  # your oss buckets

oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])

log_level = 'INFO'
load_from = None
resume_from = None
dist_params = dict(backend='nccl')
workflow = [('train', 1)]
checkpoint_config = dict(interval=10)

optimizer = dict(type='Adam', lr=5e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[170, 200])
total_epochs = 210
log_config = dict(
    interval=50,
    hooks=[dict(type='TextLoggerHook'),
           dict(type='TensorboardLoggerHook')])
channel_cfg = dict(
    num_output_channels=21,
    dataset_joints=21,
    dataset_channel=[
        [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20
        ],
    ],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        20
    ])

# model settings
model = dict(
    type='TopDown',
    pretrained=False,
    backbone=dict(
        type='LiteHRNet',
        in_channels=3,
        extra=dict(
            stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
            num_stages=3,
            stages_spec=dict(
                num_modules=(3, 8, 3),
                num_branches=(2, 3, 4),
                num_blocks=(2, 2, 2),
                module_type=('LITE', 'LITE', 'LITE'),
                with_fuse=(True, True, True),
                reduce_ratios=(8, 8, 8),
                num_channels=(
                    (40, 80),
                    (40, 80, 160),
                    (40, 80, 160, 320),
                )),
            with_head=True,
        )),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=40,
        out_channels=channel_cfg['num_output_channels'],
        num_deconv_layers=0,
        extra=dict(final_conv_kernel=1, ),
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process='default',
        shift_heatmap=True,
        modulate_kernel=11))

data_root = 'data/coco'

data_cfg = dict(
    image_size=[256, 256],
    heatmap_size=[64, 64],
    num_output_channels=channel_cfg['num_output_channels'],
    num_joints=channel_cfg['dataset_joints'],
    dataset_channel=channel_cfg['dataset_channel'],
    inference_channel=channel_cfg['inference_channel'],
)

train_pipeline = [
    # dict(type='TopDownGetBboxCenterScale', padding=1.25),
    dict(type='TopDownRandomFlip', flip_prob=0.5),
    dict(
        type='TopDownGetRandomScaleRotation', rot_factor=30,
        scale_factor=0.25),
    dict(type='TopDownAffine'),
    dict(type='MMToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(type='TopDownGenerateTarget', sigma=3),
    dict(
        type='PoseCollect',
        keys=['img', 'target', 'target_weight'],
        meta_keys=[
            'image_file', 'image_id', 'joints_3d', 'joints_3d_visible',
            'center', 'scale', 'rotation', 'flip_pairs'
        ])
]

val_pipeline = [
    dict(type='TopDownAffine'),
    dict(type='MMToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(
        type='PoseCollect',
        keys=['img'],
        meta_keys=[
            'image_file', 'image_id', 'center', 'scale', 'rotation',
            'flip_pairs'
        ])
]

test_pipeline = val_pipeline
data_source_cfg = dict(type='HandCocoPoseTopDownSource', data_cfg=data_cfg)

data = dict(
    imgs_per_gpu=32,  # for train
    workers_per_gpu=2,  # for train
    # imgs_per_gpu=1,  # for test
    # workers_per_gpu=1,  # for test
    val_dataloader=dict(samples_per_gpu=32),
    test_dataloader=dict(samples_per_gpu=32),
    train=dict(
        type='HandCocoWholeBodyDataset',
        data_source=dict(
            ann_file=f'{data_root}/annotations/coco_wholebody_train_v1.0.json',
            img_prefix=f'{data_root}/train2017/',
            **data_source_cfg),
        pipeline=train_pipeline),
    val=dict(
        type='HandCocoWholeBodyDataset',
        data_source=dict(
            ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
            img_prefix=f'{data_root}/val2017/',
            test_mode=True,
            **data_source_cfg),
        pipeline=val_pipeline),
    test=dict(
        type='HandCocoWholeBodyDataset',
        data_source=dict(
            ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
            img_prefix=f'{data_root}/val2017/',
            test_mode=True,
            **data_source_cfg),
        pipeline=val_pipeline),
)

eval_config = dict(interval=10, metric='PCK', save_best='PCK')
evaluator_args = dict(
    metric_names=['PCK', 'AUC', 'EPE', 'NME'], pck_thr=0.2, auc_nor=30)
eval_pipelines = [
    dict(
        mode='test',
        data=dict(**data['val'], imgs_per_gpu=1),
        evaluators=[dict(type='KeyPointEvaluator', **evaluator_args)])
]
export = dict(use_jit=False)
checkpoint_sync_export = True
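
For orientation, a minimal sketch (not part of the commit) of loading and inspecting this config, assuming the mmcv-style config tooling that EasyCV builds on:

# Illustrative only; `Config.fromfile` is mmcv's standard loader for
# python-file configs like the one above.
from mmcv import Config

cfg = Config.fromfile(
    'configs/pose/hand/litehrnet_30_coco_wholebody_hand_256x256.py')
print(cfg.model.backbone.type)  # 'LiteHRNet'
print(cfg.evaluator_args)       # the PCK/AUC/EPE/NME evaluator settings
print(cfg.data.train.data_source.ann_file)  # resolved f-string path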
(File diff suppressed because it is too large)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c8207a06044306b0d271488a22e1a174af5a22e951a710e25a556cf5d212d5c
size 160632
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:feadc69a8190787088fda0ac12971d91badc93dbe06057645050fdbec1ce6911
size 204232
easycv/core/evaluation/__init__.py
@@ -4,8 +4,10 @@ from .base_evaluator import Evaluator
 from .classification_eval import ClsEvaluator
 from .coco_evaluation import CocoDetectionEvaluator, CoCoPoseTopDownEvaluator
 from .faceid_pair_eval import FaceIDPairEvaluator
+from .keypoint_eval import KeyPointEvaluator
 from .mse_eval import MSEEvaluator
 from .retrival_topk_eval import RetrivalTopKEvaluator
 from .segmentation_eval import SegmentationEvaluator
-from .top_down_eval import (keypoint_pck_accuracy, keypoints_from_heatmaps,
+from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
+                            keypoint_pck_accuracy, keypoints_from_heatmaps,
                             pose_pck_accuracy)
easycv/core/evaluation/keypoint_eval.py (123 lines, new file)
@@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py
import numpy as np

from .base_evaluator import Evaluator
from .builder import EVALUATORS
from .metric_registry import METRICS
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
                            keypoint_pck_accuracy)


@EVALUATORS.register_module
class KeyPointEvaluator(Evaluator):
    """ KeyPoint evaluator.
    """

    def __init__(self,
                 dataset_name=None,
                 metric_names=['PCK', 'PCKh', 'AUC', 'EPE', 'NME'],
                 pck_thr=0.2,
                 pckh_thr=0.7,
                 auc_nor=30):
        """
        Args:
            dataset_name: eval dataset name
            metric_names: eval metrics name
            pck_thr (float): PCK threshold, default as 0.2.
            pckh_thr (float): PCKh threshold, default as 0.7.
            auc_nor (float): AUC normalization factor, default as 30 pixels.
        """
        super(KeyPointEvaluator, self).__init__(dataset_name, metric_names)
        self._pck_thr = pck_thr
        self._pckh_thr = pckh_thr
        self._auc_nor = auc_nor
        self.dataset_name = dataset_name
        allowed_metrics = ['PCK', 'PCKh', 'AUC', 'EPE', 'NME']
        for metric in metric_names:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

    def _evaluate_impl(self, preds, coco_db, **kwargs):
        ''' keypoint evaluation code which will be run after
        all test batched data are predicted

        Args:
            preds: list of dicts with key ``keypoints``, each of shape
                Kx3 (x, y, score)
            coco_db: the db of wholebody coco datasource, sorted by 'bbox_id'

        Return:
            a dict, each key is metric_name, value is metric value
        '''
        assert len(preds) == len(coco_db)
        eval_res = {}

        outputs = []
        gts = []
        masks = []
        box_sizes = []
        threshold_bbox = []
        threshold_head_box = []

        for pred, item in zip(preds, coco_db):
            outputs.append(np.array(pred['keypoints'])[:, :-1])
            gts.append(np.array(item['joints_3d'])[:, :-1])
            masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)
            if 'PCK' in self.metric_names:
                bbox = np.array(item['bbox'])
                bbox_thr = np.max(bbox[2:])
                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
            if 'PCKh' in self.metric_names:
                head_box_thr = item['head_size']
                threshold_head_box.append(
                    np.array([head_box_thr, head_box_thr]))
            box_sizes.append(item.get('box_size', 1))

        outputs = np.array(outputs)
        gts = np.array(gts)
        masks = np.array(masks)
        threshold_bbox = np.array(threshold_bbox)
        threshold_head_box = np.array(threshold_head_box)
        box_sizes = np.array(box_sizes).reshape([-1, 1])

        if 'PCK' in self.metric_names:
            _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks,
                                              self._pck_thr, threshold_bbox)
            eval_res['PCK'] = pck

        if 'PCKh' in self.metric_names:
            _, pckh, _ = keypoint_pck_accuracy(outputs, gts, masks,
                                               self._pckh_thr,
                                               threshold_head_box)
            eval_res['PCKh'] = pckh

        if 'AUC' in self.metric_names:
            eval_res['AUC'] = keypoint_auc(outputs, gts, masks, self._auc_nor)

        if 'EPE' in self.metric_names:
            eval_res['EPE'] = keypoint_epe(outputs, gts, masks)

        if 'NME' in self.metric_names:
            normalize_factor = self._get_normalize_factor(
                gts=gts, box_sizes=box_sizes)
            eval_res['NME'] = keypoint_nme(outputs, gts, masks,
                                           normalize_factor)
        return eval_res

    def _get_normalize_factor(self, gts, *args, **kwargs):
        """Get the normalize factor. Generally the inter-ocular distance,
        measured as the Euclidean distance between the outer corners of the
        eyes, is used. This function should be overridden to measure NME.

        Args:
            gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.

        Returns:
            np.ndarray[N, 2]: normalization factor
        """
        return np.ones([gts.shape[0], 2], dtype=np.float32)


METRICS.register_default_best_metric(KeyPointEvaluator, 'PCK', 'max')
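
A minimal usage sketch of the new evaluator; the preds/db layout mirrors the unit test added later in this commit, and the dummy arrays are illustrative only:

import numpy as np

from easycv.core.evaluation import KeyPointEvaluator

evaluator = KeyPointEvaluator(metric_names=['PCK', 'EPE'], pck_thr=0.2)
pred = {'keypoints': np.random.rand(21, 3)}  # per-sample Kx3: x, y, score
gt = {
    'joints_3d': np.random.rand(21, 3),
    'joints_3d_visible': np.ones((21, 3)),
    'bbox': [0, 0, 100, 100],  # x, y, w, h; max(w, h) scales the PCK threshold
}
print(evaluator.evaluate([pred], [gt]))  # -> {'PCK': ..., 'EPE': ...}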
easycv/core/evaluation/top_down_eval.py
@@ -178,6 +178,86 @@ def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
    return acc, avg_acc, cnt


def keypoint_auc(pred, gt, mask, normalize, num_step=20):
    """Calculate the Area Under the Curve (AUC) of keypoint PCK accuracy,
    i.e. the PCK averaged over a range of distance thresholds.

    Note:
        - batch_size: N
        - num_keypoints: K

    Args:
        pred (np.ndarray[N, K, 2]): Predicted keypoint location.
        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
        mask (np.ndarray[N, K]): Visibility of the target. False for invisible
            joints, and True for visible. Invisible joints will be ignored for
            accuracy calculation.
        normalize (float): Normalization factor.
        num_step (int): Number of evenly spaced thresholds used to sample
            the PCK curve.

    Returns:
        float: Area under curve.
    """
    nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1))
    x = [1.0 * i / num_step for i in range(num_step)]
    y = []
    for thr in x:
        _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
        y.append(avg_acc)

    auc = 0
    for i in range(num_step):
        auc += 1.0 / num_step * y[i]
    return auc


def keypoint_nme(pred, gt, mask, normalize_factor):
    """Calculate the normalized mean error (NME).

    Note:
        - batch_size: N
        - num_keypoints: K

    Args:
        pred (np.ndarray[N, K, 2]): Predicted keypoint location.
        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
        mask (np.ndarray[N, K]): Visibility of the target. False for invisible
            joints, and True for visible. Invisible joints will be ignored for
            accuracy calculation.
        normalize_factor (np.ndarray[N, 2]): Normalization factor.

    Returns:
        float: normalized mean error
    """
    distances = _calc_distances(pred, gt, mask, normalize_factor)
    distance_valid = distances[distances != -1]
    return distance_valid.sum() / max(1, len(distance_valid))


def keypoint_epe(pred, gt, mask):
    """Calculate the end-point error.

    Note:
        - batch_size: N
        - num_keypoints: K

    Args:
        pred (np.ndarray[N, K, 2]): Predicted keypoint location.
        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
        mask (np.ndarray[N, K]): Visibility of the target. False for invisible
            joints, and True for visible. Invisible joints will be ignored for
            accuracy calculation.

    Returns:
        float: Average end-point error.
    """

    distances = _calc_distances(
        pred, gt, mask,
        np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
    distance_valid = distances[distances != -1]
    return distance_valid.sum() / max(1, len(distance_valid))


def _taylor(heatmap, coord):
    """Distribution aware coordinate decoding method.
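
As a quick sanity check on the metric definitions above, a small NumPy sketch with dummy values (illustrative, not part of the commit):

import numpy as np

# One sample (N=1) with two visible keypoints (K=2).
pred = np.array([[[20., 20.], [0., 10.]]])
gt = np.array([[[10., 10.], [0., 10.]]])
mask = np.array([[True, True]])

# EPE: mean Euclidean distance over visible joints.
dists = np.linalg.norm(pred - gt, axis=-1)  # [[14.142..., 0.]]
print(dists[mask].mean())                   # ~7.07, what keypoint_epe returns

# AUC: the mean of PCK over num_step evenly spaced thresholds,
# e.g. for hypothetical per-threshold PCK values y:
y = [0.2, 0.5, 0.8, 0.9]
print(sum(y) / len(y))  # 0.6, the discrete area under the PCK curve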
@@ -83,7 +83,7 @@ def fliplr_regression(regression,

     allowed_center_mode = {'static', 'root'}
     assert center_mode in allowed_center_mode, 'Get invalid center_mode ' \
-        f'{center_mode}, allowed choices are {allowed_center_mode}'
+        f'{center_mode}, allowed choices are {allowed_center_mode}'

     if center_mode == 'static':
         x_c = center_x
easycv/datasets/pose/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from . import data_sources  # pylint: disable=unused-import
 from . import pipelines  # pylint: disable=unused-import
+from .hand_coco_wholebody_dataset import HandCocoWholeBodyDataset
 from .top_down import PoseTopDownDataset

-__all__ = ['PoseTopDownDataset']
+__all__ = ['PoseTopDownDataset', 'HandCocoWholeBodyDataset']
easycv/datasets/pose/data_sources/__init__.py
@@ -1,5 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .coco import PoseTopDownSourceCoco
+from .hand import HandCocoPoseTopDownSource
 from .top_down import PoseTopDownSource

-__all__ = ['PoseTopDownSourceCoco', 'PoseTopDownSource']
+__all__ = [
+    'PoseTopDownSourceCoco', 'PoseTopDownSource', 'HandCocoPoseTopDownSource'
+]
easycv/datasets/pose/data_sources/hand/__init__.py (3 lines, new file)
@@ -0,0 +1,3 @@
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
from .coco_hand import HandCocoPoseTopDownSource
easycv/datasets/pose/data_sources/hand/coco_hand.py (276 lines, new file)
@@ -0,0 +1,276 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py
import logging
import os.path as osp

import numpy as np

from easycv.datasets.registry import DATASOURCES
from ..top_down import PoseTopDownSource

COCO_WHOLEBODY_HAND_DATASET_INFO = dict(
    dataset_name='coco_wholebody_hand',
    paper_info=dict(
        author='Jin, Sheng and Xu, Lumin and Xu, Jin and '
        'Wang, Can and Liu, Wentao and '
        'Qian, Chen and Ouyang, Wanli and Luo, Ping',
        title='Whole-Body Human Pose Estimation in the Wild',
        container='Proceedings of the European '
        'Conference on Computer Vision (ECCV)',
        year='2020',
        homepage='https://github.com/jin-s13/COCO-WholeBody/',
    ),
    keypoint_info={
        0:
        dict(name='wrist', id=0, color=[255, 255, 255], type='', swap=''),
        1:
        dict(name='thumb1', id=1, color=[255, 128, 0], type='', swap=''),
        2:
        dict(name='thumb2', id=2, color=[255, 128, 0], type='', swap=''),
        3:
        dict(name='thumb3', id=3, color=[255, 128, 0], type='', swap=''),
        4:
        dict(name='thumb4', id=4, color=[255, 128, 0], type='', swap=''),
        5:
        dict(
            name='forefinger1', id=5, color=[255, 153, 255], type='', swap=''),
        6:
        dict(
            name='forefinger2', id=6, color=[255, 153, 255], type='', swap=''),
        7:
        dict(
            name='forefinger3', id=7, color=[255, 153, 255], type='', swap=''),
        8:
        dict(
            name='forefinger4', id=8, color=[255, 153, 255], type='', swap=''),
        9:
        dict(
            name='middle_finger1',
            id=9,
            color=[102, 178, 255],
            type='',
            swap=''),
        10:
        dict(
            name='middle_finger2',
            id=10,
            color=[102, 178, 255],
            type='',
            swap=''),
        11:
        dict(
            name='middle_finger3',
            id=11,
            color=[102, 178, 255],
            type='',
            swap=''),
        12:
        dict(
            name='middle_finger4',
            id=12,
            color=[102, 178, 255],
            type='',
            swap=''),
        13:
        dict(
            name='ring_finger1', id=13, color=[255, 51, 51], type='', swap=''),
        14:
        dict(
            name='ring_finger2', id=14, color=[255, 51, 51], type='', swap=''),
        15:
        dict(
            name='ring_finger3', id=15, color=[255, 51, 51], type='', swap=''),
        16:
        dict(
            name='ring_finger4', id=16, color=[255, 51, 51], type='', swap=''),
        17:
        dict(name='pinky_finger1', id=17, color=[0, 255, 0], type='', swap=''),
        18:
        dict(name='pinky_finger2', id=18, color=[0, 255, 0], type='', swap=''),
        19:
        dict(name='pinky_finger3', id=19, color=[0, 255, 0], type='', swap=''),
        20:
        dict(name='pinky_finger4', id=20, color=[0, 255, 0], type='', swap='')
    },
    skeleton_info={
        0:
        dict(link=('wrist', 'thumb1'), id=0, color=[255, 128, 0]),
        1:
        dict(link=('thumb1', 'thumb2'), id=1, color=[255, 128, 0]),
        2:
        dict(link=('thumb2', 'thumb3'), id=2, color=[255, 128, 0]),
        3:
        dict(link=('thumb3', 'thumb4'), id=3, color=[255, 128, 0]),
        4:
        dict(link=('wrist', 'forefinger1'), id=4, color=[255, 153, 255]),
        5:
        dict(link=('forefinger1', 'forefinger2'), id=5, color=[255, 153, 255]),
        6:
        dict(link=('forefinger2', 'forefinger3'), id=6, color=[255, 153, 255]),
        7:
        dict(link=('forefinger3', 'forefinger4'), id=7, color=[255, 153, 255]),
        8:
        dict(link=('wrist', 'middle_finger1'), id=8, color=[102, 178, 255]),
        9:
        dict(
            link=('middle_finger1', 'middle_finger2'),
            id=9,
            color=[102, 178, 255]),
        10:
        dict(
            link=('middle_finger2', 'middle_finger3'),
            id=10,
            color=[102, 178, 255]),
        11:
        dict(
            link=('middle_finger3', 'middle_finger4'),
            id=11,
            color=[102, 178, 255]),
        12:
        dict(link=('wrist', 'ring_finger1'), id=12, color=[255, 51, 51]),
        13:
        dict(
            link=('ring_finger1', 'ring_finger2'), id=13, color=[255, 51, 51]),
        14:
        dict(
            link=('ring_finger2', 'ring_finger3'), id=14, color=[255, 51, 51]),
        15:
        dict(
            link=('ring_finger3', 'ring_finger4'), id=15, color=[255, 51, 51]),
        16:
        dict(link=('wrist', 'pinky_finger1'), id=16, color=[0, 255, 0]),
        17:
        dict(
            link=('pinky_finger1', 'pinky_finger2'), id=17, color=[0, 255, 0]),
        18:
        dict(
            link=('pinky_finger2', 'pinky_finger3'), id=18, color=[0, 255, 0]),
        19:
        dict(
            link=('pinky_finger3', 'pinky_finger4'), id=19, color=[0, 255, 0])
    },
    joint_weights=[1.] * 21,
    sigmas=[
        0.029, 0.022, 0.035, 0.037, 0.047, 0.026, 0.025, 0.024, 0.035, 0.018,
        0.024, 0.022, 0.026, 0.017, 0.021, 0.021, 0.032, 0.02, 0.019, 0.022,
        0.031
    ])


@DATASOURCES.register_module()
class HandCocoPoseTopDownSource(PoseTopDownSource):
    """Coco Whole-Body-Hand Source for top-down hand pose estimation.

    "Whole-Body Human Pose Estimation in the Wild", ECCV'2020.
    More details can be found in the `paper
    <https://arxiv.org/abs/2007.11858>`__ .

    The dataset loads raw features and applies specified transforms
    to return a dict containing the image tensors and other information.

    COCO-WholeBody Hand keypoint indexes::

        0: 'wrist',
        1: 'thumb1',
        2: 'thumb2',
        3: 'thumb3',
        4: 'thumb4',
        5: 'forefinger1',
        6: 'forefinger2',
        7: 'forefinger3',
        8: 'forefinger4',
        9: 'middle_finger1',
        10: 'middle_finger2',
        11: 'middle_finger3',
        12: 'middle_finger4',
        13: 'ring_finger1',
        14: 'ring_finger2',
        15: 'ring_finger3',
        16: 'ring_finger4',
        17: 'pinky_finger1',
        18: 'pinky_finger2',
        19: 'pinky_finger3',
        20: 'pinky_finger4'

    Args:
        ann_file (str): Path to the annotation file.
        img_prefix (str): Path to a directory where images are held.
            Default: None.
        data_cfg (dict): config
        dataset_info (DatasetInfo): A class containing all dataset info.
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 dataset_info=None,
                 test_mode=False):

        if dataset_info is None:
            logging.info(
                'dataset_info is missing, use default coco wholebody hand dataset info'
            )
            dataset_info = COCO_WHOLEBODY_HAND_DATASET_INFO

        super().__init__(
            ann_file,
            img_prefix,
            data_cfg,
            dataset_info=dataset_info,
            test_mode=test_mode)

        self.ann_info['use_different_joint_weights'] = False
        self.db = self._get_db()

        print(f'=> num_images: {self.num_images}')
        print(f'=> load {len(self.db)} samples')

    def _get_db(self):
        """Load dataset."""
        gt_db = []
        bbox_id = 0
        num_joints = self.ann_info['num_joints']
        for img_id in self.img_ids:

            ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
            objs = self.coco.loadAnns(ann_ids)

            for obj in objs:
                for type in ['left', 'right']:
                    if obj[f'{type}hand_valid'] and max(
                            obj[f'{type}hand_kpts']) > 0:
                        joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
                        joints_3d_visible = np.zeros((num_joints, 3),
                                                     dtype=np.float32)

                        keypoints = np.array(obj[f'{type}hand_kpts']).reshape(
                            -1, 3)
                        joints_3d[:, :2] = keypoints[:, :2]
                        joints_3d_visible[:, :2] = np.minimum(
                            1, keypoints[:, 2:3])

                        image_file = osp.join(self.img_prefix,
                                              self.id2name[img_id])
                        center, scale = self._xywh2cs(
                            *obj[f'{type}hand_box'][:4])
                        gt_db.append({
                            'image_file': image_file,
                            'image_id': img_id,
                            'rotation': 0,
                            'center': center,
                            'scale': scale,
                            'joints_3d': joints_3d,
                            'joints_3d_visible': joints_3d_visible,
                            'dataset': self.dataset_name,
                            'bbox': obj[f'{type}hand_box'],
                            'bbox_score': 1,
                            'bbox_id': bbox_id
                        })
                        bbox_id = bbox_id + 1
        gt_db = sorted(gt_db, key=lambda x: x['bbox_id'])

        return gt_db
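
A minimal instantiation sketch for the new data source (not part of the commit); the annotation and image paths are placeholders matching the config above, and the small fixture used by the unit tests below follows the same layout:

from easycv.datasets.pose.data_sources import HandCocoPoseTopDownSource

data_cfg = dict(
    image_size=[256, 256],
    heatmap_size=[64, 64],
    num_output_channels=21,
    num_joints=21,
    dataset_channel=[list(range(21))],
    inference_channel=list(range(21)))

source = HandCocoPoseTopDownSource(
    ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
    img_prefix='data/coco/val2017/',
    data_cfg=data_cfg)
sample = source.get_sample(0)  # dict with 'img', 'joints_3d', 'center', 'scale', ...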
easycv/datasets/pose/hand_coco_wholebody_dataset.py (70 lines, new file)
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py

from easycv.core.evaluation.keypoint_eval import KeyPointEvaluator
from easycv.datasets.pose.data_sources.coco import PoseTopDownSource
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset


@DATASETS.register_module()
class HandCocoWholeBodyDataset(BaseDataset):
    """CocoWholeBodyDataset for top-down hand pose estimation.

    Args:
        data_source: Data_source config dict
        pipeline: Pipeline config list
        profiling: If set True, will print pipeline time
    """

    def __init__(self, data_source, pipeline, profiling=False):
        super(HandCocoWholeBodyDataset, self).__init__(data_source, pipeline,
                                                       profiling)

        if not isinstance(self.data_source, PoseTopDownSource):
            raise ValueError('Only support `PoseTopDownSource`, but got %s' %
                             self.data_source)

    def evaluate(self, outputs, evaluators, **kwargs):
        if len(evaluators) > 1 or not isinstance(evaluators[0],
                                                 KeyPointEvaluator):
            raise ValueError(
                'HandCocoWholeBodyDataset only supports one `KeyPointEvaluator` now, '
                'but got %s' % evaluators)
        evaluator = evaluators[0]

        image_ids = outputs['image_ids']
        preds = outputs['preds']
        boxes = outputs['boxes']
        bbox_ids = outputs['bbox_ids']

        kpts = []
        for i, image_id in enumerate(image_ids):
            kpts.append({
                'keypoints': preds[i],
                'center': boxes[i][0:2],
                'scale': boxes[i][2:4],
                'area': boxes[i][4],
                'score': boxes[i][5],
                'image_id': image_id,
                'bbox_id': bbox_ids[i]
            })
        kpts = self._sort_and_unique_bboxes(kpts)
        eval_res = evaluator.evaluate(kpts, self.data_source.db)
        return eval_res

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """Sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts

    def __getitem__(self, idx):
        """Get the sample given index."""
        results = self.data_source.get_sample(idx)
        return self.pipeline(results)
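
And a sketch of building the dataset from a registered config dict, the way the `data.train` entry in the config above is consumed. This assumes EasyCV's DATASETS registry is mmcv-compatible (`build_from_cfg`); the paths and the one-step pipeline are placeholders:

from mmcv.utils import build_from_cfg

from easycv.datasets.registry import DATASETS

data_cfg = dict(
    image_size=[256, 256], heatmap_size=[64, 64], num_output_channels=21,
    num_joints=21, dataset_channel=[list(range(21))],
    inference_channel=list(range(21)))

dataset = build_from_cfg(
    dict(
        type='HandCocoWholeBodyDataset',
        data_source=dict(
            type='HandCocoPoseTopDownSource',
            ann_file='data/coco/annotations/coco_wholebody_train_v1.0.json',
            img_prefix='data/coco/train2017/',
            data_cfg=data_cfg),
        pipeline=[dict(type='MMToTensor')]),  # minimal pipeline for illustration
    DATASETS)
print(len(dataset))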
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/pipelines/top_down_transform.py

import cv2
import numpy as np
from mmcv.parallel import DataContainer as DC
from torchvision.transforms import functional as F

from easycv.core.post_processing import (affine_transform, fliplr_joints,
                                         get_affine_transform, get_warp_matrix,
tests/core/evaluation/test_keypoint_eval.py (51 lines, new file)
@@ -0,0 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from easycv.core.evaluation import KeyPointEvaluator


class KeyPointEvaluatorTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_keypoint_evaluator_pck(self):
        evaluator = KeyPointEvaluator(pck_thr=0.5, pckh_thr=0.5, auc_nor=30)
        output = np.zeros((5, 3))
        target = np.zeros((5, 3))
        mask = np.zeros((5, 3))
        mask[:, :2] = 1
        # first channel
        output[0] = [10, 0, 0]
        target[0] = [10, 0, 0]
        # second channel
        output[1] = [20, 20, 0]
        target[1] = [10, 10, 0]
        # third channel
        output[2] = [0, 0, 0]
        target[2] = [-1, 0, 0]
        # fourth channel
        output[3] = [30, 30, 0]
        target[3] = [30, 30, 0]
        # fifth channel
        output[4] = [0, 10, 0]
        target[4] = [0, 10, 0]
        preds = {'keypoints': output}
        db = {
            'joints_3d': target,
            'joints_3d_visible': mask,
            'bbox': [10, 10, 10, 10],
            'head_size': 10
        }
        eval_res = evaluator.evaluate([preds, preds], [db, db])
        self.assertAlmostEqual(eval_res['PCK'], 0.8)
        self.assertAlmostEqual(eval_res['PCKh'], 0.8)
        self.assertAlmostEqual(eval_res['EPE'], 3.0284271240234375)
        self.assertAlmostEqual(eval_res['AUC'], 0.86)
        self.assertAlmostEqual(eval_res['NME'], 3.0284271240234375)


if __name__ == '__main__':
    unittest.main()
tests/datasets/pose/data_sources/test_coco_hand.py (59 lines, new file)
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import unittest

import numpy as np
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT

from easycv.datasets.pose.data_sources import HandCocoPoseTopDownSource

_DATA_CFG = dict(
    image_size=[256, 256],
    heatmap_size=[64, 64],
    num_output_channels=21,
    num_joints=21,
    dataset_channel=[list(range(21))],
    inference_channel=list(range(21)),
)


class HandCocoPoseSourceCocoTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_top_down_source_coco(self):
        data_source = HandCocoPoseTopDownSource(
            data_cfg=_DATA_CFG,
            ann_file=
            f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
            img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/')
        index_list = random.choices(list(range(4)), k=3)
        for idx in index_list:
            data = data_source.get_sample(idx)
            self.assertIn('image_file', data)
            self.assertIn('image_id', data)
            self.assertIn('bbox_score', data)
            self.assertIn('bbox_id', data)
            self.assertIn('image_id', data)
            self.assertEqual(data['center'].shape, (2, ))
            self.assertEqual(data['scale'].shape, (2, ))
            self.assertEqual(len(data['bbox']), 4)
            self.assertEqual(data['joints_3d'].shape, (21, 3))
            self.assertEqual(data['joints_3d_visible'].shape, (21, 3))
            self.assertEqual(data['img'].shape[-1], 3)
            ann_info = data['ann_info']
            self.assertEqual(ann_info['image_size'].all(),
                             np.array([256, 256]).all())
            self.assertEqual(ann_info['heatmap_size'].all(),
                             np.array([64, 64]).all())
            self.assertEqual(ann_info['num_joints'], 21)
            self.assertEqual(len(ann_info['inference_channel']), 21)
            self.assertEqual(ann_info['num_output_channels'], 21)
            break

        self.assertEqual(len(data_source), 4)


if __name__ == '__main__':
    unittest.main()
tests/datasets/pose/test_coco_whole_body_hand_dataset.py (75 lines, new file)
@@ -0,0 +1,75 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import torch
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT

from easycv.datasets.pose import HandCocoWholeBodyDataset

_DATA_CFG = dict(
    image_size=[256, 256],
    heatmap_size=[64, 64],
    num_output_channels=21,
    num_joints=21,
    dataset_channel=[list(range(21))],
    inference_channel=list(range(21)))

_DATASET_ARGS = [{
    'data_source':
    dict(
        type='HandCocoPoseTopDownSource',
        data_cfg=_DATA_CFG,
        ann_file=
        f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
        img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/'),
    'pipeline': [
        dict(type='TopDownRandomFlip', flip_prob=0.5),
        dict(type='TopDownAffine'),
        dict(type='MMToTensor'),
        dict(type='TopDownGenerateTarget', sigma=3),
        dict(
            type='PoseCollect',
            keys=['img', 'target', 'target_weight'],
            meta_keys=[
                'image_file', 'joints_3d', 'flip_pairs', 'joints_3d_visible',
                'center', 'scale', 'rotation', 'bbox_score'
            ])
    ]
}, {}]


class PoseTopDownDatasetTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    @staticmethod
    def build_dataset(index):
        dataset = HandCocoWholeBodyDataset(
            data_source=_DATASET_ARGS[index].get('data_source', None),
            pipeline=_DATASET_ARGS[index].get('pipeline', None))

        return dataset

    def test_0(self, index=0):
        dataset = self.build_dataset(index)
        ann_info = dataset.data_source.ann_info

        self.assertEqual(len(dataset), 4)
        for i, batch in enumerate(dataset):
            self.assertEqual(
                batch['img'].shape,
                torch.Size([3] + list(ann_info['image_size'][::-1])))
            self.assertEqual(batch['target'].shape,
                             (ann_info['num_joints'], ) +
                             tuple(ann_info['heatmap_size'][::-1]))
            self.assertEqual(batch['img_metas'].data['joints_3d'].shape,
                             (ann_info['num_joints'], 3))
            self.assertIn('center', batch['img_metas'].data)
            self.assertIn('scale', batch['img_metas'].data)

            break


if __name__ == '__main__':
    unittest.main()
tests/ut_config.py
@@ -118,3 +118,4 @@ PRETRAINED_MODEL_SEGFORMER = os.path.join(
 )
 MODEL_CONFIG_SEGFORMER = (
     './configs/segmentation/segformer/segformer_b0_coco.py')
+SMALL_COCO_WHOLE_BODY_HAND_ROOT = 'data/test/pose/hand/small_whole_body_hand_coco'