EasyCV/easycv/datasets/pose/hand_coco_wholebody_dataset.py

71 lines
2.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Adapt from
# https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/hand/hand_coco_wholebody_dataset.py
from easycv.core.evaluation.keypoint_eval import KeyPointEvaluator
from easycv.datasets.pose.data_sources.coco import PoseTopDownSource
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
@DATASETS.register_module()
class HandCocoWholeBodyDataset(BaseDataset):
"""CocoWholeBodyDataset for top-down hand pose estimation.
Args:
data_source: Data_source config dict
pipeline: Pipeline config list
profiling: If set True, will print pipeline time
"""
def __init__(self, data_source, pipeline, profiling=False):
super(HandCocoWholeBodyDataset, self).__init__(data_source, pipeline,
profiling)
if not isinstance(self.data_source, PoseTopDownSource):
raise ValueError('Only support `PoseTopDownSource`, but get %s' %
self.data_source)
def evaluate(self, outputs, evaluators, **kwargs):
if len(evaluators) > 1 or not isinstance(evaluators[0],
KeyPointEvaluator):
raise ValueError(
'HandCocoWholeBodyDataset only support one `KeyPointEvaluator` now, '
'but get %s' % evaluators)
evaluator = evaluators[0]
image_ids = outputs['image_ids']
preds = outputs['preds']
boxes = outputs['boxes']
bbox_ids = outputs['bbox_ids']
kpts = []
for i, image_id in enumerate(image_ids):
kpts.append({
'keypoints': preds[i],
'center': boxes[i][0:2],
'scale': boxes[i][2:4],
'area': boxes[i][4],
'score': boxes[i][5],
'image_id': image_id,
'bbox_id': bbox_ids[i]
})
kpts = self._sort_and_unique_bboxes(kpts)
eval_res = evaluator.evaluate(kpts, self.data_source.db)
return eval_res
def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
"""sort kpts and remove the repeated ones."""
kpts = sorted(kpts, key=lambda x: x[key])
num = len(kpts)
for i in range(num - 1, 0, -1):
if kpts[i][key] == kpts[i - 1][key]:
del kpts[i]
return kpts
def __getitem__(self, idx):
"""Get the sample given index."""
results = self.data_source[idx]
return self.pipeline(results)