EasyCV/tests/datasets/pose/test_pose_top_down_dataset.py
Cathy0908 a11f200ec3
refactor get_sample to __getitem__ for all datasources api (#156)
* refactor `get_sample` to `__getitem__` for all datasources api

* fix MMpad config
2022-08-17 14:24:17 +08:00

81 lines
2.4 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from tests.ut_config import POSE_DATA_SMALL_COCO_LOCAL
from easycv.datasets.pose import PoseTopDownDataset
# Data configuration handed to the top-down pose data source under test:
# 17 joints, with every channel index 0..16 used both as the (single)
# dataset channel group and as the inference channels.
_DATA_CFG = {
    'image_size': [288, 384],
    'heatmap_size': [72, 96],
    'num_output_channels': 17,
    'num_joints': 17,
    'dataset_channel': [list(range(17))],
    'inference_channel': list(range(17)),
    'soft_nms': False,
    'nms_thr': 1.0,
    'oks_thr': 0.9,
    'vis_thr': 0.2,
    'use_gt_bbox': False,
    'det_bbox_thr': 0.0,
}
# Constructor kwargs for PoseTopDownDataset, looked up by position.
# Entry 0 reads the COCO-format pose sample under POSE_DATA_SMALL_COCO_LOCAL
# and runs a flip -> affine -> tensor -> heatmap-target -> collect pipeline.
# Entry 1 is an empty placeholder: building from it passes
# data_source=None and pipeline=None.
_DATASET_ARGS = [
    {
        'data_source': {
            'type': 'PoseTopDownSourceCoco',
            'data_cfg': _DATA_CFG,
            'ann_file': f'{POSE_DATA_SMALL_COCO_LOCAL}/train_200.json',
            'img_prefix': f'{POSE_DATA_SMALL_COCO_LOCAL}/images/',
        },
        'pipeline': [
            {'type': 'TopDownRandomFlip', 'flip_prob': 0.5},
            {'type': 'TopDownAffine'},
            {'type': 'MMToTensor'},
            {'type': 'TopDownGenerateTarget', 'sigma': 3},
            {
                'type': 'PoseCollect',
                'keys': ['img', 'target', 'target_weight'],
                'meta_keys': [
                    'image_file', 'joints_3d', 'flip_pairs',
                    'joints_3d_visible', 'center', 'scale', 'rotation',
                    'bbox_score',
                ],
            },
        ],
    },
    {},
]
class PoseTopDownDatasetTest(unittest.TestCase):
    """Smoke tests for ``PoseTopDownDataset`` built from ``_DATASET_ARGS``."""

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    @staticmethod
    def build_dataset(index):
        """Build a ``PoseTopDownDataset`` from ``_DATASET_ARGS[index]``.

        Args:
            index: position in ``_DATASET_ARGS``; a missing ``data_source``
                or ``pipeline`` key is passed through as ``None``.

        Returns:
            The constructed ``PoseTopDownDataset``.
        """
        args = _DATASET_ARGS[index]
        return PoseTopDownDataset(
            data_source=args.get('data_source', None),
            pipeline=args.get('pipeline', None))

    def test_0(self, index=0):
        """Check the dataset length and the shapes of the first sample.

        Args:
            index: which ``_DATASET_ARGS`` entry to build (default 0).
        """
        dataset = self.build_dataset(index)
        ann_info = dataset.data_source.ann_info
        self.assertEqual(len(dataset), 420)

        # Fetch sample 0 directly rather than via ``for ...: break``. If the
        # dataset defines no ``__iter__``, a for-loop falls back to the legacy
        # ``__getitem__`` protocol. That protocol treats an ``IndexError``
        # raised while loading sample 0 as "end of iteration", which would
        # silently skip every assertion below.
        batch = dataset[0]

        # The config lists sizes in the reverse order of the tensors' trailing
        # (spatial) dims (presumably [w, h] vs. (..., h, w)), hence [::-1].
        self.assertEqual(
            batch['img'].shape,
            torch.Size([3] + list(ann_info['image_size'][::-1])))
        self.assertEqual(
            batch['target'].shape,
            (ann_info['num_joints'], ) +
            tuple(ann_info['heatmap_size'][::-1]))

        meta = batch['img_metas'].data
        self.assertEqual(meta['joints_3d'].shape, (ann_info['num_joints'], 3))
        self.assertIn('center', meta)
        self.assertIn('scale', meta)
if __name__ == '__main__':
    # Allow running this module directly, outside an external test runner.
    unittest.main()