mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
新增coco-wholebody-hand数据集,新增pck auc epe nme评价指标
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9790242
This commit is contained in:
parent
0f74adb848
commit
2bf7c9f6ff
176
configs/pose/hand/litehrnet_30_coco_wholebody_hand_256x256.py
Normal file
176
configs/pose/hand/litehrnet_30_coco_wholebody_hand_256x256.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
# oss_io_config = dict(
|
||||||
|
# ak_id='your oss ak id',
|
||||||
|
# ak_secret='your oss ak secret',
|
||||||
|
# hosts='oss-cn-zhangjiakou.aliyuncs.com', # your oss hosts
|
||||||
|
# buckets=['your_bucket']) # your oss buckets
|
||||||
|
|
||||||
|
oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])
|
||||||
|
|
||||||
|
log_level = 'INFO'
|
||||||
|
load_from = None
|
||||||
|
resume_from = None
|
||||||
|
dist_params = dict(backend='nccl')
|
||||||
|
workflow = [('train', 1)]
|
||||||
|
checkpoint_config = dict(interval=10)
|
||||||
|
|
||||||
|
optimizer = dict(type='Adam', lr=5e-4)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(
|
||||||
|
policy='step',
|
||||||
|
warmup='linear',
|
||||||
|
warmup_iters=500,
|
||||||
|
warmup_ratio=0.001,
|
||||||
|
step=[170, 200])
|
||||||
|
total_epochs = 210
|
||||||
|
log_config = dict(
|
||||||
|
interval=50,
|
||||||
|
hooks=[dict(type='TextLoggerHook'),
|
||||||
|
dict(type='TensorboardLoggerHook')])
|
||||||
|
channel_cfg = dict(
|
||||||
|
num_output_channels=21,
|
||||||
|
dataset_joints=21,
|
||||||
|
dataset_channel=[
|
||||||
|
[
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||||
|
19, 20
|
||||||
|
],
|
||||||
|
],
|
||||||
|
inference_channel=[
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
|
||||||
|
20
|
||||||
|
])
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
model = dict(
|
||||||
|
type='TopDown',
|
||||||
|
pretrained=False,
|
||||||
|
backbone=dict(
|
||||||
|
type='LiteHRNet',
|
||||||
|
in_channels=3,
|
||||||
|
extra=dict(
|
||||||
|
stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
|
||||||
|
num_stages=3,
|
||||||
|
stages_spec=dict(
|
||||||
|
num_modules=(3, 8, 3),
|
||||||
|
num_branches=(2, 3, 4),
|
||||||
|
num_blocks=(2, 2, 2),
|
||||||
|
module_type=('LITE', 'LITE', 'LITE'),
|
||||||
|
with_fuse=(True, True, True),
|
||||||
|
reduce_ratios=(8, 8, 8),
|
||||||
|
num_channels=(
|
||||||
|
(40, 80),
|
||||||
|
(40, 80, 160),
|
||||||
|
(40, 80, 160, 320),
|
||||||
|
)),
|
||||||
|
with_head=True,
|
||||||
|
)),
|
||||||
|
keypoint_head=dict(
|
||||||
|
type='TopdownHeatmapSimpleHead',
|
||||||
|
in_channels=40,
|
||||||
|
out_channels=channel_cfg['num_output_channels'],
|
||||||
|
num_deconv_layers=0,
|
||||||
|
extra=dict(final_conv_kernel=1, ),
|
||||||
|
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
|
||||||
|
train_cfg=dict(),
|
||||||
|
test_cfg=dict(
|
||||||
|
flip_test=True,
|
||||||
|
post_process='default',
|
||||||
|
shift_heatmap=True,
|
||||||
|
modulate_kernel=11))
|
||||||
|
|
||||||
|
data_root = 'data/coco'
|
||||||
|
|
||||||
|
data_cfg = dict(
|
||||||
|
image_size=[256, 256],
|
||||||
|
heatmap_size=[64, 64],
|
||||||
|
num_output_channels=channel_cfg['num_output_channels'],
|
||||||
|
num_joints=channel_cfg['dataset_joints'],
|
||||||
|
dataset_channel=channel_cfg['dataset_channel'],
|
||||||
|
inference_channel=channel_cfg['inference_channel'],
|
||||||
|
)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
# dict(type='TopDownGetBboxCenterScale', padding=1.25),
|
||||||
|
dict(type='TopDownRandomFlip', flip_prob=0.5),
|
||||||
|
dict(
|
||||||
|
type='TopDownGetRandomScaleRotation', rot_factor=30,
|
||||||
|
scale_factor=0.25),
|
||||||
|
dict(type='TopDownAffine'),
|
||||||
|
dict(type='MMToTensor'),
|
||||||
|
dict(
|
||||||
|
type='NormalizeTensor',
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]),
|
||||||
|
dict(type='TopDownGenerateTarget', sigma=3),
|
||||||
|
dict(
|
||||||
|
type='PoseCollect',
|
||||||
|
keys=['img', 'target', 'target_weight'],
|
||||||
|
meta_keys=[
|
||||||
|
'image_file', 'image_id', 'joints_3d', 'joints_3d_visible',
|
||||||
|
'center', 'scale', 'rotation', 'flip_pairs'
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
|
val_pipeline = [
|
||||||
|
dict(type='TopDownAffine'),
|
||||||
|
dict(type='MMToTensor'),
|
||||||
|
dict(
|
||||||
|
type='NormalizeTensor',
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]),
|
||||||
|
dict(
|
||||||
|
type='PoseCollect',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=[
|
||||||
|
'image_file', 'image_id', 'center', 'scale', 'rotation',
|
||||||
|
'flip_pairs'
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = val_pipeline
|
||||||
|
data_source_cfg = dict(type='HandCocoPoseTopDownSource', data_cfg=data_cfg)
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
imgs_per_gpu=32, # for train
|
||||||
|
workers_per_gpu=2, # for train
|
||||||
|
# imgs_per_gpu=1, # for test
|
||||||
|
# workers_per_gpu=1, # for test
|
||||||
|
val_dataloader=dict(samples_per_gpu=32),
|
||||||
|
test_dataloader=dict(samples_per_gpu=32),
|
||||||
|
train=dict(
|
||||||
|
type='HandCocoWholeBodyDataset',
|
||||||
|
data_source=dict(
|
||||||
|
ann_file=f'{data_root}/annotations/coco_wholebody_train_v1.0.json',
|
||||||
|
img_prefix=f'{data_root}/train2017/',
|
||||||
|
**data_source_cfg),
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
val=dict(
|
||||||
|
type='HandCocoWholeBodyDataset',
|
||||||
|
data_source=dict(
|
||||||
|
ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
|
||||||
|
img_prefix=f'{data_root}/val2017/',
|
||||||
|
test_mode=True,
|
||||||
|
**data_source_cfg),
|
||||||
|
pipeline=val_pipeline),
|
||||||
|
test=dict(
|
||||||
|
type='HandCocoWholeBodyDataset',
|
||||||
|
data_source=dict(
|
||||||
|
ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
|
||||||
|
img_prefix=f'{data_root}/val2017/',
|
||||||
|
test_mode=True,
|
||||||
|
**data_source_cfg),
|
||||||
|
pipeline=val_pipeline),
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_config = dict(interval=10, metric='PCK', save_best='PCK')
|
||||||
|
evaluator_args = dict(
|
||||||
|
metric_names=['PCK', 'AUC', 'EPE', 'NME'], pck_thr=0.2, auc_nor=30)
|
||||||
|
eval_pipelines = [
|
||||||
|
dict(
|
||||||
|
mode='test',
|
||||||
|
data=dict(**data['val'], imgs_per_gpu=1),
|
||||||
|
evaluators=[dict(type='KeyPointEvaluator', **evaluator_args)])
|
||||||
|
]
|
||||||
|
export = dict(use_jit=False)
|
||||||
|
checkpoint_sync_export = True
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:6c8207a06044306b0d271488a22e1a174af5a22e951a710e25a556cf5d212d5c
|
||||||
|
size 160632
|
@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:feadc69a8190787088fda0ac12971d91badc93dbe06057645050fdbec1ce6911
|
||||||
|
size 204232
|
@ -4,8 +4,10 @@ from .base_evaluator import Evaluator
|
|||||||
from .classification_eval import ClsEvaluator
|
from .classification_eval import ClsEvaluator
|
||||||
from .coco_evaluation import CocoDetectionEvaluator, CoCoPoseTopDownEvaluator
|
from .coco_evaluation import CocoDetectionEvaluator, CoCoPoseTopDownEvaluator
|
||||||
from .faceid_pair_eval import FaceIDPairEvaluator
|
from .faceid_pair_eval import FaceIDPairEvaluator
|
||||||
|
from .keypoint_eval import KeyPointEvaluator
|
||||||
from .mse_eval import MSEEvaluator
|
from .mse_eval import MSEEvaluator
|
||||||
from .retrival_topk_eval import RetrivalTopKEvaluator
|
from .retrival_topk_eval import RetrivalTopKEvaluator
|
||||||
from .segmentation_eval import SegmentationEvaluator
|
from .segmentation_eval import SegmentationEvaluator
|
||||||
from .top_down_eval import (keypoint_pck_accuracy, keypoints_from_heatmaps,
|
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
|
||||||
|
keypoint_pck_accuracy, keypoints_from_heatmaps,
|
||||||
pose_pck_accuracy)
|
pose_pck_accuracy)
|
||||||
|
123
easycv/core/evaluation/keypoint_eval.py
Normal file
123
easycv/core/evaluation/keypoint_eval.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
# Adapt from
|
||||||
|
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base_evaluator import Evaluator
|
||||||
|
from .builder import EVALUATORS
|
||||||
|
from .metric_registry import METRICS
|
||||||
|
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
|
||||||
|
keypoint_pck_accuracy)
|
||||||
|
|
||||||
|
|
||||||
|
@EVALUATORS.register_module
|
||||||
|
class KeyPointEvaluator(Evaluator):
|
||||||
|
""" KeyPoint evaluator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dataset_name=None,
|
||||||
|
metric_names=['PCK', 'PCKh', 'AUC', 'EPE', 'NME'],
|
||||||
|
pck_thr=0.2,
|
||||||
|
pckh_thr=0.7,
|
||||||
|
auc_nor=30):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_name: eval dataset name
|
||||||
|
metric_names: eval metrics name
|
||||||
|
pck_thr (float): PCK threshold, default as 0.2.
|
||||||
|
pckh_thr (float): PCKh threshold, default as 0.7.
|
||||||
|
auc_nor (float): AUC normalization factor, default as 30 pixel.
|
||||||
|
"""
|
||||||
|
super(KeyPointEvaluator, self).__init__(dataset_name, metric_names)
|
||||||
|
self._pck_thr = pck_thr
|
||||||
|
self._pckh_thr = pckh_thr
|
||||||
|
self._auc_nor = auc_nor
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
allowed_metrics = ['PCK', 'PCKh', 'AUC', 'EPE', 'NME']
|
||||||
|
for metric in metric_names:
|
||||||
|
if metric not in allowed_metrics:
|
||||||
|
raise KeyError(f'metric {metric} is not supported')
|
||||||
|
|
||||||
|
def _evaluate_impl(self, preds, coco_db, **kwargs):
|
||||||
|
''' keypoint evaluation code which will be run after
|
||||||
|
all test batched data are predicted
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preds: dict with key ``keypoints`` whose shape is Nx3
|
||||||
|
coco_db: the db of wholebody coco datasource, sorted by 'bbox_id'
|
||||||
|
|
||||||
|
Return:
|
||||||
|
a dict, each key is metric_name, value is metric value
|
||||||
|
'''
|
||||||
|
assert len(preds) == len(coco_db)
|
||||||
|
eval_res = {}
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
gts = []
|
||||||
|
masks = []
|
||||||
|
box_sizes = []
|
||||||
|
threshold_bbox = []
|
||||||
|
threshold_head_box = []
|
||||||
|
|
||||||
|
for pred, item in zip(preds, coco_db):
|
||||||
|
outputs.append(np.array(pred['keypoints'])[:, :-1])
|
||||||
|
gts.append(np.array(item['joints_3d'])[:, :-1])
|
||||||
|
masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0)
|
||||||
|
if 'PCK' in self.metric_names:
|
||||||
|
bbox = np.array(item['bbox'])
|
||||||
|
bbox_thr = np.max(bbox[2:])
|
||||||
|
threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
|
||||||
|
if 'PCKh' in self.metric_names:
|
||||||
|
head_box_thr = item['head_size']
|
||||||
|
threshold_head_box.append(
|
||||||
|
np.array([head_box_thr, head_box_thr]))
|
||||||
|
box_sizes.append(item.get('box_size', 1))
|
||||||
|
|
||||||
|
outputs = np.array(outputs)
|
||||||
|
gts = np.array(gts)
|
||||||
|
masks = np.array(masks)
|
||||||
|
threshold_bbox = np.array(threshold_bbox)
|
||||||
|
threshold_head_box = np.array(threshold_head_box)
|
||||||
|
box_sizes = np.array(box_sizes).reshape([-1, 1])
|
||||||
|
|
||||||
|
if 'PCK' in self.metric_names:
|
||||||
|
_, pck, _ = keypoint_pck_accuracy(outputs, gts, masks,
|
||||||
|
self._pck_thr, threshold_bbox)
|
||||||
|
eval_res['PCK'] = pck
|
||||||
|
|
||||||
|
if 'PCKh' in self.metric_names:
|
||||||
|
_, pckh, _ = keypoint_pck_accuracy(outputs, gts, masks,
|
||||||
|
self._pckh_thr,
|
||||||
|
threshold_head_box)
|
||||||
|
eval_res['PCKh'] = pckh
|
||||||
|
|
||||||
|
if 'AUC' in self.metric_names:
|
||||||
|
eval_res['AUC'] = keypoint_auc(outputs, gts, masks, self._auc_nor)
|
||||||
|
|
||||||
|
if 'EPE' in self.metric_names:
|
||||||
|
eval_res['EPE'] = keypoint_epe(outputs, gts, masks)
|
||||||
|
|
||||||
|
if 'NME' in self.metric_names:
|
||||||
|
normalize_factor = self._get_normalize_factor(
|
||||||
|
gts=gts, box_sizes=box_sizes)
|
||||||
|
eval_res['NME'] = keypoint_nme(outputs, gts, masks,
|
||||||
|
normalize_factor)
|
||||||
|
return eval_res
|
||||||
|
|
||||||
|
def _get_normalize_factor(self, gts, *args, **kwargs):
|
||||||
|
"""Get the normalize factor. generally inter-ocular distance measured
|
||||||
|
as the Euclidean distance between the outer corners of the eyes is
|
||||||
|
used. This function should be overrode, to measure NME.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gts (np.ndarray[N, K, 2]): Groundtruth keypoint location.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray[N, 2]: normalized factor
|
||||||
|
"""
|
||||||
|
return np.ones([gts.shape[0], 2], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
METRICS.register_default_best_metric(KeyPointEvaluator, 'PCK', 'max')
|
@ -178,6 +178,86 @@ def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
|
|||||||
return acc, avg_acc, cnt
|
return acc, avg_acc, cnt
|
||||||
|
|
||||||
|
|
||||||
|
def keypoint_auc(pred, gt, mask, normalize, num_step=20):
|
||||||
|
"""Calculate the pose accuracy of PCK for each individual keypoint and the
|
||||||
|
averaged accuracy across all keypoints for coordinates.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- batch_size: N
|
||||||
|
- num_keypoints: K
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
|
||||||
|
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
|
||||||
|
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
|
||||||
|
joints, and True for visible. Invisible joints will be ignored for
|
||||||
|
accuracy calculation.
|
||||||
|
normalize (float): Normalization factor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Area under curve.
|
||||||
|
"""
|
||||||
|
nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1))
|
||||||
|
x = [1.0 * i / num_step for i in range(num_step)]
|
||||||
|
y = []
|
||||||
|
for thr in x:
|
||||||
|
_, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
|
||||||
|
y.append(avg_acc)
|
||||||
|
|
||||||
|
auc = 0
|
||||||
|
for i in range(num_step):
|
||||||
|
auc += 1.0 / num_step * y[i]
|
||||||
|
return auc
|
||||||
|
|
||||||
|
|
||||||
|
def keypoint_nme(pred, gt, mask, normalize_factor):
|
||||||
|
"""Calculate the normalized mean error (NME).
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- batch_size: N
|
||||||
|
- num_keypoints: K
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
|
||||||
|
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
|
||||||
|
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
|
||||||
|
joints, and True for visible. Invisible joints will be ignored for
|
||||||
|
accuracy calculation.
|
||||||
|
normalize_factor (np.ndarray[N, 2]): Normalization factor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: normalized mean error
|
||||||
|
"""
|
||||||
|
distances = _calc_distances(pred, gt, mask, normalize_factor)
|
||||||
|
distance_valid = distances[distances != -1]
|
||||||
|
return distance_valid.sum() / max(1, len(distance_valid))
|
||||||
|
|
||||||
|
|
||||||
|
def keypoint_epe(pred, gt, mask):
|
||||||
|
"""Calculate the end-point error.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- batch_size: N
|
||||||
|
- num_keypoints: K
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
|
||||||
|
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
|
||||||
|
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
|
||||||
|
joints, and True for visible. Invisible joints will be ignored for
|
||||||
|
accuracy calculation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Average end-point error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
distances = _calc_distances(
|
||||||
|
pred, gt, mask,
|
||||||
|
np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
|
||||||
|
distance_valid = distances[distances != -1]
|
||||||
|
return distance_valid.sum() / max(1, len(distance_valid))
|
||||||
|
|
||||||
|
|
||||||
def _taylor(heatmap, coord):
|
def _taylor(heatmap, coord):
|
||||||
"""Distribution aware coordinate decoding method.
|
"""Distribution aware coordinate decoding method.
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ def fliplr_regression(regression,
|
|||||||
|
|
||||||
allowed_center_mode = {'static', 'root'}
|
allowed_center_mode = {'static', 'root'}
|
||||||
assert center_mode in allowed_center_mode, 'Get invalid center_mode ' \
|
assert center_mode in allowed_center_mode, 'Get invalid center_mode ' \
|
||||||
f'{center_mode}, allowed choices are {allowed_center_mode}'
|
f'{center_mode}, allowed choices are {allowed_center_mode}'
|
||||||
|
|
||||||
if center_mode == 'static':
|
if center_mode == 'static':
|
||||||
x_c = center_x
|
x_c = center_x
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from . import data_sources # pylint: disable=unused-import
|
from . import data_sources # pylint: disable=unused-import
|
||||||
from . import pipelines # pylint: disable=unused-import
|
from . import pipelines # pylint: disable=unused-import
|
||||||
|
from .hand_coco_wholebody_dataset import HandCocoWholeBodyDataset
|
||||||
from .top_down import PoseTopDownDataset
|
from .top_down import PoseTopDownDataset
|
||||||
|
|
||||||
__all__ = ['PoseTopDownDataset']
|
__all__ = ['PoseTopDownDataset', 'HandCocoWholeBodyDataset']
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from .coco import PoseTopDownSourceCoco
|
from .coco import PoseTopDownSourceCoco
|
||||||
|
from .hand import HandCocoPoseTopDownSource
|
||||||
from .top_down import PoseTopDownSource
|
from .top_down import PoseTopDownSource
|
||||||
|
|
||||||
__all__ = ['PoseTopDownSourceCoco', 'PoseTopDownSource']
|
__all__ = [
|
||||||
|
'PoseTopDownSourceCoco', 'PoseTopDownSource', 'HandCocoPoseTopDownSource'
|
||||||
|
]
|
||||||
|
3
easycv/datasets/pose/data_sources/hand/__init__.py
Normal file
3
easycv/datasets/pose/data_sources/hand/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
from .coco_hand import HandCocoPoseTopDownSource
|
276
easycv/datasets/pose/data_sources/hand/coco_hand.py
Normal file
276
easycv/datasets/pose/data_sources/hand/coco_hand.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
# Adapt from
|
||||||
|
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py
|
||||||
|
import logging
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from easycv.datasets.registry import DATASOURCES
|
||||||
|
from ..top_down import PoseTopDownSource
|
||||||
|
|
||||||
|
COCO_WHOLEBODY_HAND_DATASET_INFO = dict(
|
||||||
|
dataset_name='coco_wholebody_hand',
|
||||||
|
paper_info=dict(
|
||||||
|
author='Jin, Sheng and Xu, Lumin and Xu, Jin and '
|
||||||
|
'Wang, Can and Liu, Wentao and '
|
||||||
|
'Qian, Chen and Ouyang, Wanli and Luo, Ping',
|
||||||
|
title='Whole-Body Human Pose Estimation in the Wild',
|
||||||
|
container='Proceedings of the European '
|
||||||
|
'Conference on Computer Vision (ECCV)',
|
||||||
|
year='2020',
|
||||||
|
homepage='https://github.com/jin-s13/COCO-WholeBody/',
|
||||||
|
),
|
||||||
|
keypoint_info={
|
||||||
|
0:
|
||||||
|
dict(name='wrist', id=0, color=[255, 255, 255], type='', swap=''),
|
||||||
|
1:
|
||||||
|
dict(name='thumb1', id=1, color=[255, 128, 0], type='', swap=''),
|
||||||
|
2:
|
||||||
|
dict(name='thumb2', id=2, color=[255, 128, 0], type='', swap=''),
|
||||||
|
3:
|
||||||
|
dict(name='thumb3', id=3, color=[255, 128, 0], type='', swap=''),
|
||||||
|
4:
|
||||||
|
dict(name='thumb4', id=4, color=[255, 128, 0], type='', swap=''),
|
||||||
|
5:
|
||||||
|
dict(
|
||||||
|
name='forefinger1', id=5, color=[255, 153, 255], type='', swap=''),
|
||||||
|
6:
|
||||||
|
dict(
|
||||||
|
name='forefinger2', id=6, color=[255, 153, 255], type='', swap=''),
|
||||||
|
7:
|
||||||
|
dict(
|
||||||
|
name='forefinger3', id=7, color=[255, 153, 255], type='', swap=''),
|
||||||
|
8:
|
||||||
|
dict(
|
||||||
|
name='forefinger4', id=8, color=[255, 153, 255], type='', swap=''),
|
||||||
|
9:
|
||||||
|
dict(
|
||||||
|
name='middle_finger1',
|
||||||
|
id=9,
|
||||||
|
color=[102, 178, 255],
|
||||||
|
type='',
|
||||||
|
swap=''),
|
||||||
|
10:
|
||||||
|
dict(
|
||||||
|
name='middle_finger2',
|
||||||
|
id=10,
|
||||||
|
color=[102, 178, 255],
|
||||||
|
type='',
|
||||||
|
swap=''),
|
||||||
|
11:
|
||||||
|
dict(
|
||||||
|
name='middle_finger3',
|
||||||
|
id=11,
|
||||||
|
color=[102, 178, 255],
|
||||||
|
type='',
|
||||||
|
swap=''),
|
||||||
|
12:
|
||||||
|
dict(
|
||||||
|
name='middle_finger4',
|
||||||
|
id=12,
|
||||||
|
color=[102, 178, 255],
|
||||||
|
type='',
|
||||||
|
swap=''),
|
||||||
|
13:
|
||||||
|
dict(
|
||||||
|
name='ring_finger1', id=13, color=[255, 51, 51], type='', swap=''),
|
||||||
|
14:
|
||||||
|
dict(
|
||||||
|
name='ring_finger2', id=14, color=[255, 51, 51], type='', swap=''),
|
||||||
|
15:
|
||||||
|
dict(
|
||||||
|
name='ring_finger3', id=15, color=[255, 51, 51], type='', swap=''),
|
||||||
|
16:
|
||||||
|
dict(
|
||||||
|
name='ring_finger4', id=16, color=[255, 51, 51], type='', swap=''),
|
||||||
|
17:
|
||||||
|
dict(name='pinky_finger1', id=17, color=[0, 255, 0], type='', swap=''),
|
||||||
|
18:
|
||||||
|
dict(name='pinky_finger2', id=18, color=[0, 255, 0], type='', swap=''),
|
||||||
|
19:
|
||||||
|
dict(name='pinky_finger3', id=19, color=[0, 255, 0], type='', swap=''),
|
||||||
|
20:
|
||||||
|
dict(name='pinky_finger4', id=20, color=[0, 255, 0], type='', swap='')
|
||||||
|
},
|
||||||
|
skeleton_info={
|
||||||
|
0:
|
||||||
|
dict(link=('wrist', 'thumb1'), id=0, color=[255, 128, 0]),
|
||||||
|
1:
|
||||||
|
dict(link=('thumb1', 'thumb2'), id=1, color=[255, 128, 0]),
|
||||||
|
2:
|
||||||
|
dict(link=('thumb2', 'thumb3'), id=2, color=[255, 128, 0]),
|
||||||
|
3:
|
||||||
|
dict(link=('thumb3', 'thumb4'), id=3, color=[255, 128, 0]),
|
||||||
|
4:
|
||||||
|
dict(link=('wrist', 'forefinger1'), id=4, color=[255, 153, 255]),
|
||||||
|
5:
|
||||||
|
dict(link=('forefinger1', 'forefinger2'), id=5, color=[255, 153, 255]),
|
||||||
|
6:
|
||||||
|
dict(link=('forefinger2', 'forefinger3'), id=6, color=[255, 153, 255]),
|
||||||
|
7:
|
||||||
|
dict(link=('forefinger3', 'forefinger4'), id=7, color=[255, 153, 255]),
|
||||||
|
8:
|
||||||
|
dict(link=('wrist', 'middle_finger1'), id=8, color=[102, 178, 255]),
|
||||||
|
9:
|
||||||
|
dict(
|
||||||
|
link=('middle_finger1', 'middle_finger2'),
|
||||||
|
id=9,
|
||||||
|
color=[102, 178, 255]),
|
||||||
|
10:
|
||||||
|
dict(
|
||||||
|
link=('middle_finger2', 'middle_finger3'),
|
||||||
|
id=10,
|
||||||
|
color=[102, 178, 255]),
|
||||||
|
11:
|
||||||
|
dict(
|
||||||
|
link=('middle_finger3', 'middle_finger4'),
|
||||||
|
id=11,
|
||||||
|
color=[102, 178, 255]),
|
||||||
|
12:
|
||||||
|
dict(link=('wrist', 'ring_finger1'), id=12, color=[255, 51, 51]),
|
||||||
|
13:
|
||||||
|
dict(
|
||||||
|
link=('ring_finger1', 'ring_finger2'), id=13, color=[255, 51, 51]),
|
||||||
|
14:
|
||||||
|
dict(
|
||||||
|
link=('ring_finger2', 'ring_finger3'), id=14, color=[255, 51, 51]),
|
||||||
|
15:
|
||||||
|
dict(
|
||||||
|
link=('ring_finger3', 'ring_finger4'), id=15, color=[255, 51, 51]),
|
||||||
|
16:
|
||||||
|
dict(link=('wrist', 'pinky_finger1'), id=16, color=[0, 255, 0]),
|
||||||
|
17:
|
||||||
|
dict(
|
||||||
|
link=('pinky_finger1', 'pinky_finger2'), id=17, color=[0, 255, 0]),
|
||||||
|
18:
|
||||||
|
dict(
|
||||||
|
link=('pinky_finger2', 'pinky_finger3'), id=18, color=[0, 255, 0]),
|
||||||
|
19:
|
||||||
|
dict(
|
||||||
|
link=('pinky_finger3', 'pinky_finger4'), id=19, color=[0, 255, 0])
|
||||||
|
},
|
||||||
|
joint_weights=[1.] * 21,
|
||||||
|
sigmas=[
|
||||||
|
0.029, 0.022, 0.035, 0.037, 0.047, 0.026, 0.025, 0.024, 0.035, 0.018,
|
||||||
|
0.024, 0.022, 0.026, 0.017, 0.021, 0.021, 0.032, 0.02, 0.019, 0.022,
|
||||||
|
0.031
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@DATASOURCES.register_module()
|
||||||
|
class HandCocoPoseTopDownSource(PoseTopDownSource):
|
||||||
|
"""Coco Whole-Body-Hand Source for top-down hand pose estimation.
|
||||||
|
|
||||||
|
"Whole-Body Human Pose Estimation in the Wild", ECCV'2020.
|
||||||
|
More details can be found in the `paper
|
||||||
|
<https://arxiv.org/abs/2007.11858>`__ .
|
||||||
|
|
||||||
|
The dataset loads raw features and apply specified transforms
|
||||||
|
to return a dict containing the image tensors and other information.
|
||||||
|
|
||||||
|
COCO-WholeBody Hand keypoint indexes::
|
||||||
|
|
||||||
|
0: 'wrist',
|
||||||
|
1: 'thumb1',
|
||||||
|
2: 'thumb2',
|
||||||
|
3: 'thumb3',
|
||||||
|
4: 'thumb4',
|
||||||
|
5: 'forefinger1',
|
||||||
|
6: 'forefinger2',
|
||||||
|
7: 'forefinger3',
|
||||||
|
8: 'forefinger4',
|
||||||
|
9: 'middle_finger1',
|
||||||
|
10: 'middle_finger2',
|
||||||
|
11: 'middle_finger3',
|
||||||
|
12: 'middle_finger4',
|
||||||
|
13: 'ring_finger1',
|
||||||
|
14: 'ring_finger2',
|
||||||
|
15: 'ring_finger3',
|
||||||
|
16: 'ring_finger4',
|
||||||
|
17: 'pinky_finger1',
|
||||||
|
18: 'pinky_finger2',
|
||||||
|
19: 'pinky_finger3',
|
||||||
|
20: 'pinky_finger4'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ann_file (str): Path to the annotation file.
|
||||||
|
img_prefix (str): Path to a directory where images are held.
|
||||||
|
Default: None.
|
||||||
|
data_cfg (dict): config
|
||||||
|
dataset_info (DatasetInfo): A class containing all dataset info.
|
||||||
|
test_mode (bool): Store True when building test or
|
||||||
|
validation dataset. Default: False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
ann_file,
|
||||||
|
img_prefix,
|
||||||
|
data_cfg,
|
||||||
|
dataset_info=None,
|
||||||
|
test_mode=False):
|
||||||
|
|
||||||
|
if dataset_info is None:
|
||||||
|
logging.info(
|
||||||
|
'dataset_info is missing, use default coco wholebody hand dataset info'
|
||||||
|
)
|
||||||
|
dataset_info = COCO_WHOLEBODY_HAND_DATASET_INFO
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
ann_file,
|
||||||
|
img_prefix,
|
||||||
|
data_cfg,
|
||||||
|
dataset_info=dataset_info,
|
||||||
|
test_mode=test_mode)
|
||||||
|
|
||||||
|
self.ann_info['use_different_joint_weights'] = False
|
||||||
|
self.db = self._get_db()
|
||||||
|
|
||||||
|
print(f'=> num_images: {self.num_images}')
|
||||||
|
print(f'=> load {len(self.db)} samples')
|
||||||
|
|
||||||
|
def _get_db(self):
|
||||||
|
"""Load dataset."""
|
||||||
|
gt_db = []
|
||||||
|
bbox_id = 0
|
||||||
|
num_joints = self.ann_info['num_joints']
|
||||||
|
for img_id in self.img_ids:
|
||||||
|
|
||||||
|
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
||||||
|
objs = self.coco.loadAnns(ann_ids)
|
||||||
|
|
||||||
|
for obj in objs:
|
||||||
|
for type in ['left', 'right']:
|
||||||
|
if obj[f'{type}hand_valid'] and max(
|
||||||
|
obj[f'{type}hand_kpts']) > 0:
|
||||||
|
joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
|
||||||
|
joints_3d_visible = np.zeros((num_joints, 3),
|
||||||
|
dtype=np.float32)
|
||||||
|
|
||||||
|
keypoints = np.array(obj[f'{type}hand_kpts']).reshape(
|
||||||
|
-1, 3)
|
||||||
|
joints_3d[:, :2] = keypoints[:, :2]
|
||||||
|
joints_3d_visible[:, :2] = np.minimum(
|
||||||
|
1, keypoints[:, 2:3])
|
||||||
|
|
||||||
|
image_file = osp.join(self.img_prefix,
|
||||||
|
self.id2name[img_id])
|
||||||
|
center, scale = self._xywh2cs(
|
||||||
|
*obj[f'{type}hand_box'][:4])
|
||||||
|
gt_db.append({
|
||||||
|
'image_file': image_file,
|
||||||
|
'image_id': img_id,
|
||||||
|
'rotation': 0,
|
||||||
|
'center': center,
|
||||||
|
'scale': scale,
|
||||||
|
'joints_3d': joints_3d,
|
||||||
|
'joints_3d_visible': joints_3d_visible,
|
||||||
|
'dataset': self.dataset_name,
|
||||||
|
'bbox': obj[f'{type}hand_box'],
|
||||||
|
'bbox_score': 1,
|
||||||
|
'bbox_id': bbox_id
|
||||||
|
})
|
||||||
|
bbox_id = bbox_id + 1
|
||||||
|
gt_db = sorted(gt_db, key=lambda x: x['bbox_id'])
|
||||||
|
|
||||||
|
return gt_db
|
70
easycv/datasets/pose/hand_coco_wholebody_dataset.py
Normal file
70
easycv/datasets/pose/hand_coco_wholebody_dataset.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
# Adapt from
|
||||||
|
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py
|
||||||
|
|
||||||
|
from easycv.core.evaluation.keypoint_eval import KeyPointEvaluator
|
||||||
|
from easycv.datasets.pose.data_sources.coco import PoseTopDownSource
|
||||||
|
from easycv.datasets.registry import DATASETS
|
||||||
|
from easycv.datasets.shared.base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class HandCocoWholeBodyDataset(BaseDataset):
|
||||||
|
"""CocoWholeBodyDataset for top-down hand pose estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source: Data_source config dict
|
||||||
|
pipeline: Pipeline config list
|
||||||
|
profiling: If set True, will print pipeline time
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_source, pipeline, profiling=False):
|
||||||
|
super(HandCocoWholeBodyDataset, self).__init__(data_source, pipeline,
|
||||||
|
profiling)
|
||||||
|
|
||||||
|
if not isinstance(self.data_source, PoseTopDownSource):
|
||||||
|
raise ValueError('Only support `PoseTopDownSource`, but get %s' %
|
||||||
|
self.data_source)
|
||||||
|
|
||||||
|
def evaluate(self, outputs, evaluators, **kwargs):
|
||||||
|
if len(evaluators) > 1 or not isinstance(evaluators[0],
|
||||||
|
KeyPointEvaluator):
|
||||||
|
raise ValueError(
|
||||||
|
'HandCocoWholeBodyDataset only support one `KeyPointEvaluator` now, '
|
||||||
|
'but get %s' % evaluators)
|
||||||
|
evaluator = evaluators[0]
|
||||||
|
|
||||||
|
image_ids = outputs['image_ids']
|
||||||
|
preds = outputs['preds']
|
||||||
|
boxes = outputs['boxes']
|
||||||
|
bbox_ids = outputs['bbox_ids']
|
||||||
|
|
||||||
|
kpts = []
|
||||||
|
for i, image_id in enumerate(image_ids):
|
||||||
|
kpts.append({
|
||||||
|
'keypoints': preds[i],
|
||||||
|
'center': boxes[i][0:2],
|
||||||
|
'scale': boxes[i][2:4],
|
||||||
|
'area': boxes[i][4],
|
||||||
|
'score': boxes[i][5],
|
||||||
|
'image_id': image_id,
|
||||||
|
'bbox_id': bbox_ids[i]
|
||||||
|
})
|
||||||
|
kpts = self._sort_and_unique_bboxes(kpts)
|
||||||
|
eval_res = evaluator.evaluate(kpts, self.data_source.db)
|
||||||
|
return eval_res
|
||||||
|
|
||||||
|
def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
|
||||||
|
"""sort kpts and remove the repeated ones."""
|
||||||
|
kpts = sorted(kpts, key=lambda x: x[key])
|
||||||
|
num = len(kpts)
|
||||||
|
for i in range(num - 1, 0, -1):
|
||||||
|
if kpts[i][key] == kpts[i - 1][key]:
|
||||||
|
del kpts[i]
|
||||||
|
|
||||||
|
return kpts
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""Get the sample given index."""
|
||||||
|
results = self.data_source.get_sample(idx)
|
||||||
|
return self.pipeline(results)
|
@ -1,9 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/pipelines/top_down_transform.py
|
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/pipelines/top_down_transform.py
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.parallel import DataContainer as DC
|
from mmcv.parallel import DataContainer as DC
|
||||||
from torchvision.transforms import functional as F
|
|
||||||
|
|
||||||
from easycv.core.post_processing import (affine_transform, fliplr_joints,
|
from easycv.core.post_processing import (affine_transform, fliplr_joints,
|
||||||
get_affine_transform, get_warp_matrix,
|
get_affine_transform, get_warp_matrix,
|
||||||
|
51
tests/core/evaluation/test_keypoint_eval.py
Normal file
51
tests/core/evaluation/test_keypoint_eval.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from easycv.core.evaluation import KeyPointEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
class KeyPointEvaluatorTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
|
def test_keypoint_evaluator_pck(self):
|
||||||
|
evaluator = KeyPointEvaluator(pck_thr=0.5, pckh_thr=0.5, auc_nor=30)
|
||||||
|
output = np.zeros((5, 3))
|
||||||
|
target = np.zeros((5, 3))
|
||||||
|
mask = np.zeros((5, 3))
|
||||||
|
mask[:, :2] = 1
|
||||||
|
# first channel
|
||||||
|
output[0] = [10, 0, 0]
|
||||||
|
target[0] = [10, 0, 0]
|
||||||
|
# second channel
|
||||||
|
output[1] = [20, 20, 0]
|
||||||
|
target[1] = [10, 10, 0]
|
||||||
|
# third channel
|
||||||
|
output[2] = [0, 0, 0]
|
||||||
|
target[2] = [-1, 0, 0]
|
||||||
|
# fourth channel
|
||||||
|
output[3] = [30, 30, 0]
|
||||||
|
target[3] = [30, 30, 0]
|
||||||
|
# fifth channel
|
||||||
|
output[4] = [0, 10, 0]
|
||||||
|
target[4] = [0, 10, 0]
|
||||||
|
preds = {'keypoints': output}
|
||||||
|
db = {
|
||||||
|
'joints_3d': target,
|
||||||
|
'joints_3d_visible': mask,
|
||||||
|
'bbox': [10, 10, 10, 10],
|
||||||
|
'head_size': 10
|
||||||
|
}
|
||||||
|
eval_res = evaluator.evaluate([preds, preds], [db, db])
|
||||||
|
self.assertAlmostEqual(eval_res['PCK'], 0.8)
|
||||||
|
self.assertAlmostEqual(eval_res['PCKh'], 0.8)
|
||||||
|
self.assertAlmostEqual(eval_res['EPE'], 3.0284271240234375)
|
||||||
|
self.assertAlmostEqual(eval_res['AUC'], 0.86)
|
||||||
|
self.assertAlmostEqual(eval_res['NME'], 3.0284271240234375)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
59
tests/datasets/pose/data_sources/test_coco_hand.py
Normal file
59
tests/datasets/pose/data_sources/test_coco_hand.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT
|
||||||
|
|
||||||
|
from easycv.datasets.pose.data_sources import HandCocoPoseTopDownSource
|
||||||
|
|
||||||
|
_DATA_CFG = dict(
|
||||||
|
image_size=[256, 256],
|
||||||
|
heatmap_size=[64, 64],
|
||||||
|
num_output_channels=21,
|
||||||
|
num_joints=21,
|
||||||
|
dataset_channel=[list(range(21))],
|
||||||
|
inference_channel=list(range(21)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HandCocoPoseSourceCocoTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
|
def test_top_down_source_coco(self):
|
||||||
|
data_source = HandCocoPoseTopDownSource(
|
||||||
|
data_cfg=_DATA_CFG,
|
||||||
|
ann_file=
|
||||||
|
f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
|
||||||
|
img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/')
|
||||||
|
index_list = random.choices(list(range(4)), k=3)
|
||||||
|
for idx in index_list:
|
||||||
|
data = data_source.get_sample(idx)
|
||||||
|
self.assertIn('image_file', data)
|
||||||
|
self.assertIn('image_id', data)
|
||||||
|
self.assertIn('bbox_score', data)
|
||||||
|
self.assertIn('bbox_id', data)
|
||||||
|
self.assertIn('image_id', data)
|
||||||
|
self.assertEqual(data['center'].shape, (2, ))
|
||||||
|
self.assertEqual(data['scale'].shape, (2, ))
|
||||||
|
self.assertEqual(len(data['bbox']), 4)
|
||||||
|
self.assertEqual(data['joints_3d'].shape, (21, 3))
|
||||||
|
self.assertEqual(data['joints_3d_visible'].shape, (21, 3))
|
||||||
|
self.assertEqual(data['img'].shape[-1], 3)
|
||||||
|
ann_info = data['ann_info']
|
||||||
|
self.assertEqual(ann_info['image_size'].all(),
|
||||||
|
np.array([256, 256]).all())
|
||||||
|
self.assertEqual(ann_info['heatmap_size'].all(),
|
||||||
|
np.array([64, 64]).all())
|
||||||
|
self.assertEqual(ann_info['num_joints'], 21)
|
||||||
|
self.assertEqual(len(ann_info['inference_channel']), 21)
|
||||||
|
self.assertEqual(ann_info['num_output_channels'], 21)
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertEqual(len(data_source), 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
75
tests/datasets/pose/test_coco_whole_body_hand_dataset.py
Normal file
75
tests/datasets/pose/test_coco_whole_body_hand_dataset.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT
|
||||||
|
|
||||||
|
from easycv.datasets.pose import HandCocoWholeBodyDataset
|
||||||
|
|
||||||
|
_DATA_CFG = dict(
|
||||||
|
image_size=[256, 256],
|
||||||
|
heatmap_size=[64, 64],
|
||||||
|
num_output_channels=21,
|
||||||
|
num_joints=21,
|
||||||
|
dataset_channel=[list(range(21))],
|
||||||
|
inference_channel=list(range(21)))
|
||||||
|
|
||||||
|
_DATASET_ARGS = [{
|
||||||
|
'data_source':
|
||||||
|
dict(
|
||||||
|
type='HandCocoPoseTopDownSource',
|
||||||
|
data_cfg=_DATA_CFG,
|
||||||
|
ann_file=
|
||||||
|
f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
|
||||||
|
img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/'),
|
||||||
|
'pipeline': [
|
||||||
|
dict(type='TopDownRandomFlip', flip_prob=0.5),
|
||||||
|
dict(type='TopDownAffine'),
|
||||||
|
dict(type='MMToTensor'),
|
||||||
|
dict(type='TopDownGenerateTarget', sigma=3),
|
||||||
|
dict(
|
||||||
|
type='PoseCollect',
|
||||||
|
keys=['img', 'target', 'target_weight'],
|
||||||
|
meta_keys=[
|
||||||
|
'image_file', 'joints_3d', 'flip_pairs', 'joints_3d_visible',
|
||||||
|
'center', 'scale', 'rotation', 'bbox_score'
|
||||||
|
])
|
||||||
|
]
|
||||||
|
}, {}]
|
||||||
|
|
||||||
|
|
||||||
|
class PoseTopDownDatasetTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_dataset(index):
|
||||||
|
dataset = HandCocoWholeBodyDataset(
|
||||||
|
data_source=_DATASET_ARGS[index].get('data_source', None),
|
||||||
|
pipeline=_DATASET_ARGS[index].get('pipeline', None))
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def test_0(self, index=0):
|
||||||
|
dataset = self.build_dataset(index)
|
||||||
|
ann_info = dataset.data_source.ann_info
|
||||||
|
|
||||||
|
self.assertEqual(len(dataset), 4)
|
||||||
|
for i, batch in enumerate(dataset):
|
||||||
|
self.assertEqual(
|
||||||
|
batch['img'].shape,
|
||||||
|
torch.Size([3] + list(ann_info['image_size'][::-1])))
|
||||||
|
self.assertEqual(batch['target'].shape,
|
||||||
|
(ann_info['num_joints'], ) +
|
||||||
|
tuple(ann_info['heatmap_size'][::-1]))
|
||||||
|
self.assertEqual(batch['img_metas'].data['joints_3d'].shape,
|
||||||
|
(ann_info['num_joints'], 3))
|
||||||
|
self.assertIn('center', batch['img_metas'].data)
|
||||||
|
self.assertIn('scale', batch['img_metas'].data)
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -118,3 +118,4 @@ PRETRAINED_MODEL_SEGFORMER = os.path.join(
|
|||||||
)
|
)
|
||||||
MODEL_CONFIG_SEGFORMER = (
|
MODEL_CONFIG_SEGFORMER = (
|
||||||
'./configs/segmentation/segformer/segformer_b0_coco.py')
|
'./configs/segmentation/segformer/segformer_b0_coco.py')
|
||||||
|
SMALL_COCO_WHOLE_BODY_HAND_ROOT = 'data/test/pose/hand/small_whole_body_hand_coco'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user