Add the coco-wholebody-hand dataset; add PCK, AUC, EPE and NME evaluation metrics

Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9790242
liangting.zl 2022-08-24 19:19:33 +08:00 committed by jiangnana.jnn
parent 0f74adb848
commit 2bf7c9f6ff
18 changed files with 3312 additions and 5 deletions

View File

@ -0,0 +1,176 @@
# oss_io_config = dict(
# ak_id='your oss ak id',
# ak_secret='your oss ak secret',
# hosts='oss-cn-zhangjiakou.aliyuncs.com', # your oss hosts
# buckets=['your_bucket']) # your oss buckets
oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])
log_level = 'INFO'
load_from = None
resume_from = None
dist_params = dict(backend='nccl')
workflow = [('train', 1)]
checkpoint_config = dict(interval=10)
optimizer = dict(type='Adam', lr=5e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[170, 200])
total_epochs = 210
log_config = dict(
interval=50,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
channel_cfg = dict(
num_output_channels=21,
dataset_joints=21,
dataset_channel=[
[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20
],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20
])
# model settings
model = dict(
type='TopDown',
pretrained=False,
backbone=dict(
type='LiteHRNet',
in_channels=3,
extra=dict(
stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
num_stages=3,
stages_spec=dict(
num_modules=(3, 8, 3),
num_branches=(2, 3, 4),
num_blocks=(2, 2, 2),
module_type=('LITE', 'LITE', 'LITE'),
with_fuse=(True, True, True),
reduce_ratios=(8, 8, 8),
num_channels=(
(40, 80),
(40, 80, 160),
(40, 80, 160, 320),
)),
with_head=True,
)),
keypoint_head=dict(
type='TopdownHeatmapSimpleHead',
in_channels=40,
out_channels=channel_cfg['num_output_channels'],
num_deconv_layers=0,
extra=dict(final_conv_kernel=1, ),
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process='default',
shift_heatmap=True,
modulate_kernel=11))
data_root = 'data/coco'
data_cfg = dict(
image_size=[256, 256],
heatmap_size=[64, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
)
train_pipeline = [
# dict(type='TopDownGetBboxCenterScale', padding=1.25),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=30,
scale_factor=0.25),
dict(type='TopDownAffine'),
dict(type='MMToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(type='TopDownGenerateTarget', sigma=3),
dict(
type='PoseCollect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'image_id', 'joints_3d', 'joints_3d_visible',
'center', 'scale', 'rotation', 'flip_pairs'
])
]
val_pipeline = [
dict(type='TopDownAffine'),
dict(type='MMToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='PoseCollect',
keys=['img'],
meta_keys=[
'image_file', 'image_id', 'center', 'scale', 'rotation',
'flip_pairs'
])
]
test_pipeline = val_pipeline
data_source_cfg = dict(type='HandCocoPoseTopDownSource', data_cfg=data_cfg)
data = dict(
imgs_per_gpu=32, # for train
workers_per_gpu=2, # for train
# imgs_per_gpu=1, # for test
# workers_per_gpu=1, # for test
val_dataloader=dict(samples_per_gpu=32),
test_dataloader=dict(samples_per_gpu=32),
train=dict(
type='HandCocoWholeBodyDataset',
data_source=dict(
ann_file=f'{data_root}/annotations/coco_wholebody_train_v1.0.json',
img_prefix=f'{data_root}/train2017/',
**data_source_cfg),
pipeline=train_pipeline),
val=dict(
type='HandCocoWholeBodyDataset',
data_source=dict(
ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
img_prefix=f'{data_root}/val2017/',
test_mode=True,
**data_source_cfg),
pipeline=val_pipeline),
test=dict(
type='HandCocoWholeBodyDataset',
data_source=dict(
ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
img_prefix=f'{data_root}/val2017/',
test_mode=True,
**data_source_cfg),
pipeline=val_pipeline),
)
eval_config = dict(interval=10, metric='PCK', save_best='PCK')
evaluator_args = dict(
metric_names=['PCK', 'AUC', 'EPE', 'NME'], pck_thr=0.2, auc_nor=30)
eval_pipelines = [
dict(
mode='test',
data=dict(**data['val'], imgs_per_gpu=1),
evaluators=[dict(type='KeyPointEvaluator', **evaluator_args)])
]
export = dict(use_jit=False)
checkpoint_sync_export = True
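
For reference, the eval pipeline above reduces to the following minimal sketch, using only the KeyPointEvaluator added in this commit (predictions and the ground-truth db come from HandCocoWholeBodyDataset.evaluate, defined later in this diff):

from easycv.core.evaluation import KeyPointEvaluator

# mirrors `evaluator_args` above
evaluator = KeyPointEvaluator(
    metric_names=['PCK', 'AUC', 'EPE', 'NME'], pck_thr=0.2, auc_nor=30)
# eval_res = evaluator.evaluate(preds, data_source.db)
# -> {'PCK': ..., 'AUC': ..., 'EPE': ..., 'NME': ...}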

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c8207a06044306b0d271488a22e1a174af5a22e951a710e25a556cf5d212d5c
size 160632

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:feadc69a8190787088fda0ac12971d91badc93dbe06057645050fdbec1ce6911
size 204232

View File

@ -4,8 +4,10 @@ from .base_evaluator import Evaluator
from .classification_eval import ClsEvaluator
from .coco_evaluation import CocoDetectionEvaluator, CoCoPoseTopDownEvaluator
from .faceid_pair_eval import FaceIDPairEvaluator
+from .keypoint_eval import KeyPointEvaluator
from .mse_eval import MSEEvaluator
from .retrival_topk_eval import RetrivalTopKEvaluator
from .segmentation_eval import SegmentationEvaluator
-from .top_down_eval import (keypoint_pck_accuracy, keypoints_from_heatmaps,
-                            pose_pck_accuracy)
+from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
+                            keypoint_pck_accuracy, keypoints_from_heatmaps,
+                            pose_pck_accuracy)

View File

@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py
import numpy as np
from .base_evaluator import Evaluator
from .builder import EVALUATORS
from .metric_registry import METRICS
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
keypoint_pck_accuracy)
@EVALUATORS.register_module
class KeyPointEvaluator(Evaluator):
""" KeyPoint evaluator.
"""
def __init__(self,
dataset_name=None,
metric_names=['PCK', 'PCKh', 'AUC', 'EPE', 'NME'],
pck_thr=0.2,
pckh_thr=0.7,
auc_nor=30):
"""
Args:
dataset_name: eval dataset name
metric_names: eval metrics name
pck_thr (float): PCK threshold, default as 0.2.
pckh_thr (float): PCKh threshold, default as 0.7.
auc_nor (float): AUC normalization factor, default as 30 pixel.
"""
super(KeyPointEvaluator, self).__init__(dataset_name, metric_names)
self._pck_thr = pck_thr
self._pckh_thr = pckh_thr
self._auc_nor = auc_nor
self.dataset_name = dataset_name
allowed_metrics = ['PCK', 'PCKh', 'AUC', 'EPE', 'NME']
for metric in metric_names:
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
def _evaluate_impl(self, preds, coco_db, **kwargs):
        ''' Keypoint evaluation code, run after all batches of test
        data have been predicted.
        Args:
            preds: list of predictions, each a dict with key ``keypoints``
                whose value is an array of shape Kx3
            coco_db: the db of the wholebody coco data source, sorted by 'bbox_id'
        Return:
            a dict mapping each metric name to its value
        '''
assert len(preds) == len(coco_db)
eval_res = {}
outputs = []
gts = []
masks = []
box_sizes = []
threshold_bbox = []
threshold_head_box = []
for pred, item in zip(preds, coco_db):
outputs.append(np.array(pred['keypoints'])[:, :-1])
gts.append(np.array(item['joints_3d'])[:, :-1])
masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)
if 'PCK' in self.metric_names:
bbox = np.array(item['bbox'])
bbox_thr = np.max(bbox[2:])
threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
if 'PCKh' in self.metric_names:
head_box_thr = item['head_size']
threshold_head_box.append(
np.array([head_box_thr, head_box_thr]))
box_sizes.append(item.get('box_size', 1))
outputs = np.array(outputs)
gts = np.array(gts)
masks = np.array(masks)
threshold_bbox = np.array(threshold_bbox)
threshold_head_box = np.array(threshold_head_box)
box_sizes = np.array(box_sizes).reshape([-1, 1])
if 'PCK' in self.metric_names:
_, pck, _ = keypoint_pck_accuracy(outputs, gts, masks,
self._pck_thr, threshold_bbox)
eval_res['PCK'] = pck
if 'PCKh' in self.metric_names:
_, pckh, _ = keypoint_pck_accuracy(outputs, gts, masks,
self._pckh_thr,
threshold_head_box)
eval_res['PCKh'] = pckh
if 'AUC' in self.metric_names:
eval_res['AUC'] = keypoint_auc(outputs, gts, masks, self._auc_nor)
if 'EPE' in self.metric_names:
eval_res['EPE'] = keypoint_epe(outputs, gts, masks)
if 'NME' in self.metric_names:
normalize_factor = self._get_normalize_factor(
gts=gts, box_sizes=box_sizes)
eval_res['NME'] = keypoint_nme(outputs, gts, masks,
normalize_factor)
return eval_res
def _get_normalize_factor(self, gts, *args, **kwargs):
"""Get the normalize factor. generally inter-ocular distance measured
as the Euclidean distance between the outer corners of the eyes is
used. This function should be overrode, to measure NME.
Args:
gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.
Returns:
np.ndarray[N, 2]: normalized factor
"""
return np.ones([gts.shape[0], 2], dtype=np.float32)
METRICS.register_default_best_metric(KeyPointEvaluator, 'PCK', 'max')
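
Since the default _get_normalize_factor returns all ones, NME here degenerates to the plain per-pixel mean error, i.e. it equals EPE (the unit test below relies on this). A hypothetical subclass, e.g. for face landmarks, could override it with the inter-ocular distance; the eye-corner indices below are illustrative and not part of this commit:

import numpy as np

from easycv.core.evaluation.keypoint_eval import KeyPointEvaluator

class FaceKeyPointEvaluator(KeyPointEvaluator):

    def _get_normalize_factor(self, gts, *args, **kwargs):
        left_eye, right_eye = 0, 1  # hypothetical outer eye-corner indices
        interocular = np.linalg.norm(
            gts[:, left_eye, :] - gts[:, right_eye, :], axis=1, keepdims=True)
        return np.tile(interocular, [1, 2])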

View File

@ -178,6 +178,86 @@ def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
return acc, avg_acc, cnt
def keypoint_auc(pred, gt, mask, normalize, num_step=20):
"""Calculate the pose accuracy of PCK for each individual keypoint and the
averaged accuracy across all keypoints for coordinates.
Note:
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
        normalize (float): Normalization factor.
        num_step (int): Number of normalized thresholds to average over.
            Default: 20.
Returns:
float: Area under curve.
"""
nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1))
x = [1.0 * i / num_step for i in range(num_step)]
y = []
for thr in x:
_, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
y.append(avg_acc)
auc = 0
for i in range(num_step):
auc += 1.0 / num_step * y[i]
return auc
def keypoint_nme(pred, gt, mask, normalize_factor):
"""Calculate the normalized mean error (NME).
Note:
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
normalize_factor (np.ndarray[N, 2]): Normalization factor.
Returns:
        float: Normalized mean error.
"""
distances = _calc_distances(pred, gt, mask, normalize_factor)
distance_valid = distances[distances != -1]
return distance_valid.sum() / max(1, len(distance_valid))
def keypoint_epe(pred, gt, mask):
"""Calculate the end-point error.
Note:
- batch_size: N
- num_keypoints: K
Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
Returns:
float: Average end-point error.
"""
distances = _calc_distances(
pred, gt, mask,
np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
distance_valid = distances[distances != -1]
return distance_valid.sum() / max(1, len(distance_valid))
def _taylor(heatmap, coord):
"""Distribution aware coordinate decoding method.

View File

@ -83,7 +83,7 @@ def fliplr_regression(regression,
allowed_center_mode = {'static', 'root'}
assert center_mode in allowed_center_mode, 'Get invalid center_mode ' \
        f'{center_mode}, allowed choices are {allowed_center_mode}'
if center_mode == 'static':
x_c = center_x

View File

@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import data_sources # pylint: disable=unused-import
from . import pipelines # pylint: disable=unused-import
+from .hand_coco_wholebody_dataset import HandCocoWholeBodyDataset
from .top_down import PoseTopDownDataset
-__all__ = ['PoseTopDownDataset']
+__all__ = ['PoseTopDownDataset', 'HandCocoWholeBodyDataset']

View File

@ -1,5 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .coco import PoseTopDownSourceCoco
+from .hand import HandCocoPoseTopDownSource
from .top_down import PoseTopDownSource
-__all__ = ['PoseTopDownSourceCoco', 'PoseTopDownSource']
+__all__ = [
+    'PoseTopDownSourceCoco', 'PoseTopDownSource', 'HandCocoPoseTopDownSource'
+]

View File

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .coco_hand import HandCocoPoseTopDownSource

View File

@ -0,0 +1,276 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py
import logging
import os.path as osp
import numpy as np
from easycv.datasets.registry import DATASOURCES
from ..top_down import PoseTopDownSource
COCO_WHOLEBODY_HAND_DATASET_INFO = dict(
dataset_name='coco_wholebody_hand',
paper_info=dict(
author='Jin, Sheng and Xu, Lumin and Xu, Jin and '
'Wang, Can and Liu, Wentao and '
'Qian, Chen and Ouyang, Wanli and Luo, Ping',
title='Whole-Body Human Pose Estimation in the Wild',
container='Proceedings of the European '
'Conference on Computer Vision (ECCV)',
year='2020',
homepage='https://github.com/jin-s13/COCO-WholeBody/',
),
keypoint_info={
0:
dict(name='wrist', id=0, color=[255, 255, 255], type='', swap=''),
1:
dict(name='thumb1', id=1, color=[255, 128, 0], type='', swap=''),
2:
dict(name='thumb2', id=2, color=[255, 128, 0], type='', swap=''),
3:
dict(name='thumb3', id=3, color=[255, 128, 0], type='', swap=''),
4:
dict(name='thumb4', id=4, color=[255, 128, 0], type='', swap=''),
5:
dict(
name='forefinger1', id=5, color=[255, 153, 255], type='', swap=''),
6:
dict(
name='forefinger2', id=6, color=[255, 153, 255], type='', swap=''),
7:
dict(
name='forefinger3', id=7, color=[255, 153, 255], type='', swap=''),
8:
dict(
name='forefinger4', id=8, color=[255, 153, 255], type='', swap=''),
9:
dict(
name='middle_finger1',
id=9,
color=[102, 178, 255],
type='',
swap=''),
10:
dict(
name='middle_finger2',
id=10,
color=[102, 178, 255],
type='',
swap=''),
11:
dict(
name='middle_finger3',
id=11,
color=[102, 178, 255],
type='',
swap=''),
12:
dict(
name='middle_finger4',
id=12,
color=[102, 178, 255],
type='',
swap=''),
13:
dict(
name='ring_finger1', id=13, color=[255, 51, 51], type='', swap=''),
14:
dict(
name='ring_finger2', id=14, color=[255, 51, 51], type='', swap=''),
15:
dict(
name='ring_finger3', id=15, color=[255, 51, 51], type='', swap=''),
16:
dict(
name='ring_finger4', id=16, color=[255, 51, 51], type='', swap=''),
17:
dict(name='pinky_finger1', id=17, color=[0, 255, 0], type='', swap=''),
18:
dict(name='pinky_finger2', id=18, color=[0, 255, 0], type='', swap=''),
19:
dict(name='pinky_finger3', id=19, color=[0, 255, 0], type='', swap=''),
20:
dict(name='pinky_finger4', id=20, color=[0, 255, 0], type='', swap='')
},
skeleton_info={
0:
dict(link=('wrist', 'thumb1'), id=0, color=[255, 128, 0]),
1:
dict(link=('thumb1', 'thumb2'), id=1, color=[255, 128, 0]),
2:
dict(link=('thumb2', 'thumb3'), id=2, color=[255, 128, 0]),
3:
dict(link=('thumb3', 'thumb4'), id=3, color=[255, 128, 0]),
4:
dict(link=('wrist', 'forefinger1'), id=4, color=[255, 153, 255]),
5:
dict(link=('forefinger1', 'forefinger2'), id=5, color=[255, 153, 255]),
6:
dict(link=('forefinger2', 'forefinger3'), id=6, color=[255, 153, 255]),
7:
dict(link=('forefinger3', 'forefinger4'), id=7, color=[255, 153, 255]),
8:
dict(link=('wrist', 'middle_finger1'), id=8, color=[102, 178, 255]),
9:
dict(
link=('middle_finger1', 'middle_finger2'),
id=9,
color=[102, 178, 255]),
10:
dict(
link=('middle_finger2', 'middle_finger3'),
id=10,
color=[102, 178, 255]),
11:
dict(
link=('middle_finger3', 'middle_finger4'),
id=11,
color=[102, 178, 255]),
12:
dict(link=('wrist', 'ring_finger1'), id=12, color=[255, 51, 51]),
13:
dict(
link=('ring_finger1', 'ring_finger2'), id=13, color=[255, 51, 51]),
14:
dict(
link=('ring_finger2', 'ring_finger3'), id=14, color=[255, 51, 51]),
15:
dict(
link=('ring_finger3', 'ring_finger4'), id=15, color=[255, 51, 51]),
16:
dict(link=('wrist', 'pinky_finger1'), id=16, color=[0, 255, 0]),
17:
dict(
link=('pinky_finger1', 'pinky_finger2'), id=17, color=[0, 255, 0]),
18:
dict(
link=('pinky_finger2', 'pinky_finger3'), id=18, color=[0, 255, 0]),
19:
dict(
link=('pinky_finger3', 'pinky_finger4'), id=19, color=[0, 255, 0])
},
joint_weights=[1.] * 21,
sigmas=[
0.029, 0.022, 0.035, 0.037, 0.047, 0.026, 0.025, 0.024, 0.035, 0.018,
0.024, 0.022, 0.026, 0.017, 0.021, 0.021, 0.032, 0.02, 0.019, 0.022,
0.031
])
@DATASOURCES.register_module()
class HandCocoPoseTopDownSource(PoseTopDownSource):
"""Coco Whole-Body-Hand Source for top-down hand pose estimation.
"Whole-Body Human Pose Estimation in the Wild", ECCV'2020.
More details can be found in the `paper
<https://arxiv.org/abs/2007.11858>`__ .
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
COCO-WholeBody Hand keypoint indexes::
0: 'wrist',
1: 'thumb1',
2: 'thumb2',
3: 'thumb3',
4: 'thumb4',
5: 'forefinger1',
6: 'forefinger2',
7: 'forefinger3',
8: 'forefinger4',
9: 'middle_finger1',
10: 'middle_finger2',
11: 'middle_finger3',
12: 'middle_finger4',
13: 'ring_finger1',
14: 'ring_finger2',
15: 'ring_finger3',
16: 'ring_finger4',
17: 'pinky_finger1',
18: 'pinky_finger2',
19: 'pinky_finger3',
20: 'pinky_finger4'
Args:
ann_file (str): Path to the annotation file.
img_prefix (str): Path to a directory where images are held.
Default: None.
        data_cfg (dict): Dataset config dict.
dataset_info (DatasetInfo): A class containing all dataset info.
test_mode (bool): Store True when building test or
validation dataset. Default: False.
"""
def __init__(self,
ann_file,
img_prefix,
data_cfg,
dataset_info=None,
test_mode=False):
if dataset_info is None:
            logging.info(
                'dataset_info is missing, using the default COCO-WholeBody hand dataset info'
            )
dataset_info = COCO_WHOLEBODY_HAND_DATASET_INFO
super().__init__(
ann_file,
img_prefix,
data_cfg,
dataset_info=dataset_info,
test_mode=test_mode)
self.ann_info['use_different_joint_weights'] = False
self.db = self._get_db()
print(f'=> num_images: {self.num_images}')
print(f'=> load {len(self.db)} samples')
def _get_db(self):
"""Load dataset."""
gt_db = []
bbox_id = 0
num_joints = self.ann_info['num_joints']
for img_id in self.img_ids:
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
objs = self.coco.loadAnns(ann_ids)
for obj in objs:
for type in ['left', 'right']:
if obj[f'{type}hand_valid'] and max(
obj[f'{type}hand_kpts']) > 0:
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
joints_3d_visible = np.zeros((num_joints, 3),
dtype=np.float32)
keypoints = np.array(obj[f'{type}hand_kpts']).reshape(
-1, 3)
joints_3d[:, :2] = keypoints[:, :2]
joints_3d_visible[:, :2] = np.minimum(
1, keypoints[:, 2:3])
image_file = osp.join(self.img_prefix,
self.id2name[img_id])
center, scale = self._xywh2cs(
*obj[f'{type}hand_box'][:4])
gt_db.append({
'image_file': image_file,
'image_id': img_id,
'rotation': 0,
'center': center,
'scale': scale,
'joints_3d': joints_3d,
'joints_3d_visible': joints_3d_visible,
'dataset': self.dataset_name,
'bbox': obj[f'{type}hand_box'],
'bbox_score': 1,
'bbox_id': bbox_id
})
bbox_id = bbox_id + 1
gt_db = sorted(gt_db, key=lambda x: x['bbox_id'])
return gt_db
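
The (center, scale) pair stored in each record comes from the inherited PoseTopDownSource._xywh2cs. A sketch of the usual mmpose-style conversion, assuming the common defaults (pixel_std=200, padding=1.25) rather than code from this diff:

import numpy as np

def xywh2cs_sketch(x, y, w, h, aspect_ratio=1.0, pixel_std=200.0,
                   padding=1.25):
    """Convert an xywh box to the (center, scale) used by top-down pose."""
    center = np.array([x + 0.5 * w, y + 0.5 * h], dtype=np.float32)
    # pad the box to the target aspect ratio before scaling
    if w > aspect_ratio * h:
        h = w / aspect_ratio
    else:
        w = h * aspect_ratio
    scale = np.array([w / pixel_std, h / pixel_std], dtype=np.float32)
    return center, scale * padding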

View File

@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py
from easycv.core.evaluation.keypoint_eval import KeyPointEvaluator
from easycv.datasets.pose.data_sources.coco import PoseTopDownSource
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
@DATASETS.register_module()
class HandCocoWholeBodyDataset(BaseDataset):
"""CocoWholeBodyDataset for top-down hand pose estimation.
Args:
data_source: Data_source config dict
pipeline: Pipeline config list
profiling: If set True, will print pipeline time
"""
def __init__(self, data_source, pipeline, profiling=False):
super(HandCocoWholeBodyDataset, self).__init__(data_source, pipeline,
profiling)
if not isinstance(self.data_source, PoseTopDownSource):
            raise ValueError('Only support `PoseTopDownSource`, but got %s' %
                             self.data_source)
def evaluate(self, outputs, evaluators, **kwargs):
if len(evaluators) > 1 or not isinstance(evaluators[0],
KeyPointEvaluator):
            raise ValueError(
                'HandCocoWholeBodyDataset only supports a single `KeyPointEvaluator` for now, '
                'but got %s' % evaluators)
evaluator = evaluators[0]
image_ids = outputs['image_ids']
preds = outputs['preds']
boxes = outputs['boxes']
bbox_ids = outputs['bbox_ids']
kpts = []
for i, image_id in enumerate(image_ids):
kpts.append({
'keypoints': preds[i],
'center': boxes[i][0:2],
'scale': boxes[i][2:4],
'area': boxes[i][4],
'score': boxes[i][5],
'image_id': image_id,
'bbox_id': bbox_ids[i]
})
kpts = self._sort_and_unique_bboxes(kpts)
eval_res = evaluator.evaluate(kpts, self.data_source.db)
return eval_res
def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
"""sort kpts and remove the repeated ones."""
kpts = sorted(kpts, key=lambda x: x[key])
num = len(kpts)
for i in range(num - 1, 0, -1):
if kpts[i][key] == kpts[i - 1][key]:
del kpts[i]
return kpts
def __getitem__(self, idx):
"""Get the sample given index."""
results = self.data_source.get_sample(idx)
return self.pipeline(results)
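
For reference, the structure of the `outputs` dict consumed by evaluate above, inferred from the indexing in the loop (N predicted instances, K joints):

outputs = dict(
    image_ids=[],  # N image ids
    preds=[],      # N arrays of shape (K, 3): x, y, score per keypoint
    boxes=[],      # N arrays: [center_x, center_y, scale_x, scale_y, area, score]
    bbox_ids=[],   # N ints, used by _sort_and_unique_bboxes to drop duplicates
)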

View File

@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/pipelines/top_down_transform.py
import cv2
import numpy as np
from mmcv.parallel import DataContainer as DC
from torchvision.transforms import functional as F
from easycv.core.post_processing import (affine_transform, fliplr_joints,
get_affine_transform, get_warp_matrix,

View File

@ -0,0 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
from easycv.core.evaluation import KeyPointEvaluator
class KeyPointEvaluatorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_keypoint_evaluator_pck(self):
evaluator = KeyPointEvaluator(pck_thr=0.5, pckh_thr=0.5, auc_nor=30)
output = np.zeros((5, 3))
target = np.zeros((5, 3))
mask = np.zeros((5, 3))
mask[:, :2] = 1
# first channel
output[0] = [10, 0, 0]
target[0] = [10, 0, 0]
# second channel
output[1] = [20, 20, 0]
target[1] = [10, 10, 0]
# third channel
output[2] = [0, 0, 0]
target[2] = [-1, 0, 0]
# fourth channel
output[3] = [30, 30, 0]
target[3] = [30, 30, 0]
# fifth channel
output[4] = [0, 10, 0]
target[4] = [0, 10, 0]
preds = {'keypoints': output}
db = {
'joints_3d': target,
'joints_3d_visible': mask,
'bbox': [10, 10, 10, 10],
'head_size': 10
}
eval_res = evaluator.evaluate([preds, preds], [db, db])
self.assertAlmostEqual(eval_res['PCK'], 0.8)
self.assertAlmostEqual(eval_res['PCKh'], 0.8)
self.assertAlmostEqual(eval_res['EPE'], 3.0284271240234375)
self.assertAlmostEqual(eval_res['AUC'], 0.86)
self.assertAlmostEqual(eval_res['NME'], 3.0284271240234375)
if __name__ == '__main__':
unittest.main()
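
The expected values can be checked by hand (both list entries are identical, so averaging over the two samples changes nothing, and PCKh equals PCK here because head_size matches the bbox side and both thresholds are 0.5). Per-keypoint errors are [0, sqrt(200), 1, 0, 0]:

import numpy as np

d = np.array([0.0, np.hypot(10, 10), 1.0, 0.0, 0.0])  # per-keypoint errors
print(d.mean())  # EPE ~= 3.0284; NME matches it because the default
                 # normalization factor is all ones
print((d / 10 < 0.5).mean())  # PCK: bbox side 10, thr 0.5 -> 4/5 = 0.8
thrs = np.arange(20) / 20
print(np.mean([(d / 30 < t).mean() for t in thrs]))  # AUC, nor=30 -> 0.86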

View File

@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import unittest
import numpy as np
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT
from easycv.datasets.pose.data_sources import HandCocoPoseTopDownSource
_DATA_CFG = dict(
image_size=[256, 256],
heatmap_size=[64, 64],
num_output_channels=21,
num_joints=21,
dataset_channel=[list(range(21))],
inference_channel=list(range(21)),
)
class HandCocoPoseSourceCocoTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_top_down_source_coco(self):
data_source = HandCocoPoseTopDownSource(
data_cfg=_DATA_CFG,
ann_file=
f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/')
index_list = random.choices(list(range(4)), k=3)
for idx in index_list:
data = data_source.get_sample(idx)
self.assertIn('image_file', data)
self.assertIn('image_id', data)
self.assertIn('bbox_score', data)
self.assertIn('bbox_id', data)
self.assertEqual(data['center'].shape, (2, ))
self.assertEqual(data['scale'].shape, (2, ))
self.assertEqual(len(data['bbox']), 4)
self.assertEqual(data['joints_3d'].shape, (21, 3))
self.assertEqual(data['joints_3d_visible'].shape, (21, 3))
self.assertEqual(data['img'].shape[-1], 3)
ann_info = data['ann_info']
            self.assertTrue(
                np.array_equal(ann_info['image_size'], np.array([256, 256])))
            self.assertTrue(
                np.array_equal(ann_info['heatmap_size'], np.array([64, 64])))
self.assertEqual(ann_info['num_joints'], 21)
self.assertEqual(len(ann_info['inference_channel']), 21)
self.assertEqual(ann_info['num_output_channels'], 21)
break
self.assertEqual(len(data_source), 4)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,75 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT
from easycv.datasets.pose import HandCocoWholeBodyDataset
_DATA_CFG = dict(
image_size=[256, 256],
heatmap_size=[64, 64],
num_output_channels=21,
num_joints=21,
dataset_channel=[list(range(21))],
inference_channel=list(range(21)))
_DATASET_ARGS = [{
'data_source':
dict(
type='HandCocoPoseTopDownSource',
data_cfg=_DATA_CFG,
ann_file=
f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/'),
'pipeline': [
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(type='TopDownAffine'),
dict(type='MMToTensor'),
dict(type='TopDownGenerateTarget', sigma=3),
dict(
type='PoseCollect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'flip_pairs', 'joints_3d_visible',
'center', 'scale', 'rotation', 'bbox_score'
])
]
}, {}]
class PoseTopDownDatasetTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@staticmethod
def build_dataset(index):
dataset = HandCocoWholeBodyDataset(
data_source=_DATASET_ARGS[index].get('data_source', None),
pipeline=_DATASET_ARGS[index].get('pipeline', None))
return dataset
def test_0(self, index=0):
dataset = self.build_dataset(index)
ann_info = dataset.data_source.ann_info
self.assertEqual(len(dataset), 4)
for i, batch in enumerate(dataset):
self.assertEqual(
batch['img'].shape,
torch.Size([3] + list(ann_info['image_size'][::-1])))
self.assertEqual(batch['target'].shape,
(ann_info['num_joints'], ) +
tuple(ann_info['heatmap_size'][::-1]))
self.assertEqual(batch['img_metas'].data['joints_3d'].shape,
(ann_info['num_joints'], 3))
self.assertIn('center', batch['img_metas'].data)
self.assertIn('scale', batch['img_metas'].data)
break
if __name__ == '__main__':
unittest.main()

View File

@ -118,3 +118,4 @@ PRETRAINED_MODEL_SEGFORMER = os.path.join(
)
MODEL_CONFIG_SEGFORMER = (
'./configs/segmentation/segformer/segformer_b0_coco.py')
SMALL_COCO_WHOLE_BODY_HAND_ROOT = 'data/test/pose/hand/small_whole_body_hand_coco'
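
Per the ann_file and img_prefix arguments used in the new tests, this root is expected to contain:

data/test/pose/hand/small_whole_body_hand_coco/
├── annotations/small_whole_body_hand_coco.json
└── train2017/  # the images referenced by the annotation file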