# EasyCV/tests/datasets/pose/data_sources/test_coco_hand.py
#
# 60 lines
# 2.1 KiB
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import unittest
import numpy as np
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT
from easycv.datasets.pose.data_sources import HandCocoPoseTopDownSource
_DATA_CFG = dict(
image_size=[256, 256],
heatmap_size=[64, 64],
num_output_channels=21,
num_joints=21,
dataset_channel=[list(range(21))],
inference_channel=list(range(21)),
)
class HandCocoPoseSourceCocoTest(unittest.TestCase):
    """Checks the samples returned by ``HandCocoPoseTopDownSource``.

    The data source is built from the annotation file and image directory
    located under ``SMALL_COCO_WHOLE_BODY_HAND_ROOT``.
    """

    def setUp(self):
        """Print the name of the test that is about to run."""
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def test_top_down_source_coco(self):
        """Validate keys, shapes and ``ann_info`` of one randomly drawn sample.

        Also checks that the data source reports a length of 4.
        """
        data_source = HandCocoPoseTopDownSource(
            data_cfg=_DATA_CFG,
            ann_file=(f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}'
                      '/annotations/small_whole_body_hand_coco.json'),
            img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/')

        index_list = random.choices(list(range(4)), k=3)
        for idx in index_list:
            data = data_source.get_sample(idx)

            # Keys every sample is expected to carry.
            for key in ('image_file', 'image_id', 'bbox_score', 'bbox_id'):
                self.assertIn(key, data)

            self.assertEqual(data['center'].shape, (2, ))
            self.assertEqual(data['scale'].shape, (2, ))
            self.assertEqual(len(data['bbox']), 4)
            self.assertEqual(data['joints_3d'].shape, (21, 3))
            self.assertEqual(data['joints_3d_visible'].shape, (21, 3))
            self.assertEqual(data['img'].shape[-1], 3)

            ann_info = data['ann_info']
            # Compare element-wise: comparing the results of ``.all()`` would
            # only compare two truth values and pass for any non-zero sizes.
            np.testing.assert_array_equal(ann_info['image_size'], [256, 256])
            np.testing.assert_array_equal(ann_info['heatmap_size'], [64, 64])
            self.assertEqual(ann_info['num_joints'], 21)
            self.assertEqual(len(ann_info['inference_channel']), 21)
            self.assertEqual(ann_info['num_output_channels'], 21)

            # Only the first drawn index is checked; the loop exits here.
            # NOTE(review): drop this ``break`` if all three drawn samples
            # are meant to be validated.
            break

        self.assertEqual(len(data_source), 4)
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()