# Mirror of https://github.com/alibaba/EasyCV.git
# (synced 2025-06-03 14:49:00 +08:00) -- 81 lines, 2.4 KiB, Python.
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import torch
|
|
from tests.ut_config import POSE_DATA_SMALL_COCO_LOCAL
|
|
|
|
from easycv.datasets.pose import PoseTopDownDataset
|
|
|
|
# Data config handed to the COCO top-down data source below: 17 joints,
# and every joint index 0..16 appears in both dataset_channel and
# inference_channel.
_DATA_CFG = {
    # Presumably (width, height): the test compares the reversed value with
    # the image tensor's trailing dims. TODO confirm against the source.
    'image_size': [288, 384],
    # Same ordering convention as image_size (also reversed by the test).
    'heatmap_size': [72, 96],
    'num_output_channels': 17,
    'num_joints': 17,
    'dataset_channel': [list(range(17))],
    'inference_channel': list(range(17)),
    'soft_nms': False,
    'nms_thr': 1.0,
    'oks_thr': 0.9,
    'vis_thr': 0.2,
    'use_gt_bbox': False,
    'det_bbox_thr': 0.0,
}
|
|
|
|
# Constructor arguments for PoseTopDownDataset, selected by position in
# build_dataset(index).
_DATASET_ARGS = [
    # Index 0: small local COCO keypoint subset plus a top-down pipeline
    # whose last stage (PoseCollect) selects the sample keys and meta keys.
    dict(
        data_source=dict(
            type='PoseTopDownSourceCoco',
            data_cfg=_DATA_CFG,
            ann_file=f'{POSE_DATA_SMALL_COCO_LOCAL}/train_200.json',
            img_prefix=f'{POSE_DATA_SMALL_COCO_LOCAL}/images/'),
        pipeline=[
            dict(type='TopDownRandomFlip', flip_prob=0.5),
            dict(type='TopDownAffine'),
            dict(type='MMToTensor'),
            dict(type='TopDownGenerateTarget', sigma=3),
            dict(
                type='PoseCollect',
                keys=['img', 'target', 'target_weight'],
                meta_keys=[
                    'image_file', 'joints_3d', 'flip_pairs',
                    'joints_3d_visible', 'center', 'scale', 'rotation',
                    'bbox_score'
                ]),
        ]),
    # Index 1: empty placeholder. build_dataset(1) would pass
    # data_source=None and pipeline=None; no visible test uses it.
    dict(),
]
|
|
|
|
|
|
class PoseTopDownDatasetTest(unittest.TestCase):
    """Smoke test for ``PoseTopDownDataset`` built from ``_DATASET_ARGS``."""

    def setUp(self):
        # Announce the running test, e.g. 'Testing PoseTopDownDatasetTest.test_0'.
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    @staticmethod
    def build_dataset(index):
        """Construct a ``PoseTopDownDataset`` from ``_DATASET_ARGS[index]``.

        Args:
            index: Position of the argument dict in ``_DATASET_ARGS``.

        Returns:
            The constructed ``PoseTopDownDataset``. A missing
            ``data_source`` or ``pipeline`` key is passed through as ``None``.
        """
        args = _DATASET_ARGS[index]
        return PoseTopDownDataset(
            data_source=args.get('data_source', None),
            pipeline=args.get('pipeline', None))

    def test_0(self, index=0):
        """Check the dataset length and the layout of its first sample.

        ``index`` selects the ``_DATASET_ARGS`` entry; unittest invokes test
        methods without arguments, so it is always the default of 0 there.
        """
        dataset = self.build_dataset(index)
        ann_info = dataset.data_source.ann_info

        self.assertEqual(len(dataset), 420)

        # Only the first sample is inspected, so index it directly instead of
        # iterating and breaking after one step (same element either way).
        batch = dataset[0]

        # image_size / heatmap_size are reversed before comparing with the
        # tensors' trailing dims -- presumably stored as (w, h); TODO confirm.
        self.assertEqual(
            batch['img'].shape,
            torch.Size([3] + list(ann_info['image_size'][::-1])))
        self.assertEqual(
            batch['target'].shape,
            (ann_info['num_joints'], ) + tuple(ann_info['heatmap_size'][::-1]))

        img_metas = batch['img_metas'].data
        self.assertEqual(img_metas['joints_3d'].shape,
                         (ann_info['num_joints'], 3))
        self.assertIn('center', img_metas)
        self.assertIn('scale', img_metas)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the TestCase classes defined in this module when executed as a script.
    unittest.main()
|