mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import random
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from tests.ut_config import SMALL_COCO_WHOLE_BODY_HAND_ROOT
|
|
|
|
from easycv.datasets.pose.data_sources import HandCocoPoseTopDownSource
|
|
|
|
_DATA_CFG = dict(
|
|
image_size=[256, 256],
|
|
heatmap_size=[64, 64],
|
|
num_output_channels=21,
|
|
num_joints=21,
|
|
dataset_channel=[list(range(21))],
|
|
inference_channel=list(range(21)),
|
|
)
|
|
|
|
|
|
class HandCocoPoseSourceCocoTest(unittest.TestCase):
    """Tests for ``HandCocoPoseTopDownSource`` on the small whole-body hand fixture."""

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def test_top_down_source_coco(self):
        """Every sample exposes the expected keys, shapes and ann_info values."""
        data_source = HandCocoPoseTopDownSource(
            data_cfg=_DATA_CFG,
            ann_file=
            f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/annotations/small_whole_body_hand_coco.json',
            img_prefix=f'{SMALL_COCO_WHOLE_BODY_HAND_ROOT}/train2017/')

        # Check the fixture size first so the loop below covers a known,
        # valid index range.
        self.assertEqual(len(data_source), 4)

        # Check every sample (not a random subset) so the test is
        # deterministic and covers the whole fixture.
        for idx in range(len(data_source)):
            with self.subTest(idx=idx):
                data = data_source.get_sample(idx)

                for key in ('image_file', 'image_id', 'bbox_score', 'bbox_id'):
                    self.assertIn(key, data)

                self.assertEqual(data['center'].shape, (2, ))
                self.assertEqual(data['scale'].shape, (2, ))
                self.assertEqual(len(data['bbox']), 4)
                self.assertEqual(data['joints_3d'].shape, (21, 3))
                self.assertEqual(data['joints_3d_visible'].shape, (21, 3))
                self.assertEqual(data['img'].shape[-1], 3)

                ann_info = data['ann_info']
                # Compare element-wise: `a.all() == b.all()` would only
                # compare two truthiness flags and pass for any non-zero size.
                np.testing.assert_array_equal(ann_info['image_size'],
                                              [256, 256])
                np.testing.assert_array_equal(ann_info['heatmap_size'],
                                              [64, 64])
                self.assertEqual(ann_info['num_joints'], 21)
                self.assertEqual(len(ann_info['inference_channel']), 21)
                self.assertEqual(ann_info['num_output_channels'], 21)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main(module='__main__')
|