# EasyCV/tests/datasets/detection/test_raw.py
#
# 53 lines
# 1.7 KiB
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import time
import unittest
from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255
from easycv.datasets.detection import DetDataset
class DetDatasetTest(unittest.TestCase):
    """Smoke test for ``DetDataset`` built on a ``DetSourceRaw`` source."""

    def setUp(self):
        """Print the name of the test case that is about to run."""
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def test_load(self):
        """Iterate the whole dataset and check the keys of the last sample.

        The dataset is built from the images and labels found under
        ``DET_DATA_RAW_LOCAL`` and run through a resize / pad / normalize /
        format / collect pipeline. The average wall-clock time per sample
        is printed, then the final sample is checked for the expected keys.
        """
        img_scale = (640, 640)
        data_source_cfg = dict(
            type='DetSourceRaw',
            img_root_path=os.path.join(DET_DATA_RAW_LOCAL,
                                       'images/train2017'),
            label_root_path=os.path.join(DET_DATA_RAW_LOCAL,
                                         'labels/train2017'))
        pipeline = [
            dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
            dict(
                type='MMPad',
                pad_to_square=True,
                pad_val=(114.0, 114.0, 114.0)),
            dict(type='MMNormalize', **IMG_NORM_CFG_255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ]

        dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)
        data_num = len(dataset)
        # Fail with a clear message instead of a ZeroDivisionError in the
        # timing report below when the data directory is empty.
        self.assertGreater(data_num, 0, 'dataset contains no samples')

        # Only the last sample is inspected; the loop exists to exercise
        # loading of every sample and to time it.
        data = None
        s = time.time()
        for data in dataset:
            pass
        t = time.time()
        print(f'read data done {(t-s)/data_num}s per sample')

        self.assertIsNotNone(data, 'iterating the dataset yielded nothing')
        self.assertIn('img', data)
        self.assertIn('gt_labels', data)
        self.assertIn('gt_bboxes', data)
        # NOTE(review): 'img_metas' is not listed in the Collect keys above
        # yet is expected in the output - presumably Collect adds it itself;
        # confirm against the Collect implementation.
        self.assertIn('img_metas', data)

        img_metas = data['img_metas'].data
        self.assertIn('img_shape', img_metas)
        self.assertIn('ori_img_shape', img_metas)
# Run this module's tests through unittest's command-line entry point when
# the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()