# mirror of https://github.com/alibaba/EasyCV.git
# 90 lines, 2.9 KiB, Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import time
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from tests.ut_config import DET_DATA_RAW_LOCAL, IMG_NORM_CFG_255
|
|
|
|
from easycv.datasets.detection import DetDataset
|
|
|
|
|
|
class DetDatasetTest(unittest.TestCase):
    """Tests for ``DetDataset`` built on top of a ``DetSourceRaw`` source."""

    def setUp(self):
        """Print the name of the test case that is about to run."""
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _get_dataset(self):
        """Build the ``DetDataset`` shared by the tests in this class.

        Images and labels are read from the ``images/train2017`` and
        ``labels/train2017`` sub-directories of ``DET_DATA_RAW_LOCAL``.

        Returns:
            The constructed ``DetDataset`` instance.
        """
        img_scale = (640, 640)
        data_source_cfg = dict(
            type='DetSourceRaw',
            img_root_path=os.path.join(DET_DATA_RAW_LOCAL, 'images/train2017'),
            label_root_path=os.path.join(DET_DATA_RAW_LOCAL,
                                         'labels/train2017'))

        pipeline = [
            dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
            dict(
                type='MMPad',
                pad_to_square=True,
                pad_val=(114.0, 114.0, 114.0)),
            dict(type='MMNormalize', **IMG_NORM_CFG_255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ]

        dataset = DetDataset(data_source=data_source_cfg, pipeline=pipeline)

        return dataset

    def test_load(self):
        """Iterate over the whole dataset and check the last sample's keys."""
        dataset = self._get_dataset()
        data_num = len(dataset)
        # Without this guard an empty dataset would surface as a
        # ZeroDivisionError in the timing print below (and `data` would be
        # unbound) instead of a readable test failure.
        self.assertGreater(data_num, 0, 'dataset contains no samples')

        s = time.time()
        for data in dataset:
            pass
        t = time.time()
        print(f'read data done {(t-s)/data_num}s per sample')

        # Only the last sample produced by the loop is inspected.
        for key in ('img', 'gt_labels', 'gt_bboxes', 'img_metas'):
            self.assertIn(key, data)

        img_metas = data['img_metas'].data
        self.assertIn('img_shape', img_metas)
        self.assertIn('ori_img_shape', img_metas)

    def test_visualize(self):
        """Feed ground-truth boxes to ``visualize`` as fake detections."""
        dataset = self._get_dataset()
        count = 5
        detection_boxes = []
        detection_classes = []
        img_metas = []
        for i, data in enumerate(dataset):
            detection_boxes.append(
                data['gt_bboxes'].data.cpu().detach().numpy())
            detection_classes.append(
                data['gt_labels'].data.cpu().detach().numpy())
            img_metas.append(data['img_metas'].data)
            # The check runs after the appends, so at most ``count + 2``
            # samples are collected.
            if i > count:
                break

        # Synthetic scores 0.0, 0.1, 0.2, ... - one per ground-truth label.
        detection_scores = [
            0.1 * np.arange(len(classes)) for classes in detection_classes
        ]

        results = {
            'detection_boxes': detection_boxes,
            'detection_scores': detection_scores,
            'detection_classes': detection_classes,
            'img_metas': img_metas
        }
        output = dataset.visualize(results, vis_num=2, score_thr=0.1)
        self.assertEqual(len(output['images']), 2)
        self.assertEqual(len(output['img_metas']), 2)
        self.assertEqual(len(output['images'][0].shape), 3)
        self.assertIn('filename', output['img_metas'][0])
|
|
|
|
|
|
if __name__ == '__main__':
    # Run every test case defined in this module when executed as a script.
    unittest.main()
|