1
0
mirror of https://github.com/alibaba/EasyCV.git synced 2025-06-03 14:49:00 +08:00
2022-04-02 20:01:06 +08:00

82 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import unittest
from tests.ut_config import (DET_DATA_RAW_LOCAL, DET_DATA_SMALL_VOC_LOCAL,
VOC_CLASSES)
from easycv.datasets.builder import build_datasource
class DetSourceCocoTest(unittest.TestCase):
    """Tests for the ``SourceConcat`` detection data source.

    NOTE(review): despite ``Coco`` in its name, this case exercises
    ``SourceConcat`` built from ``DetSourceRaw`` / ``DetSourceVOC`` sources;
    the name is kept so existing test selectors keep working.
    """

    # How many extra seeded-random indices to check per test, on top of the
    # first and last index of the concatenated source.
    _NUM_RANDOM_SAMPLES = 3

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _sample_indices(self, length):
        """Return sorted, de-duplicated sample indices within ``[0, length)``.

        Index 0 and ``length - 1`` are always included so the tail of the
        concatenation is exercised too (assuming ``SourceConcat`` maps indices
        to sub-sources in list order). A fixed seed keeps failures
        reproducible.

        Args:
            length: total number of samples; callers assert it is positive.

        Returns:
            A sorted list of ints.
        """
        rng = random.Random(0)
        picks = rng.sample(
            range(length), min(self._NUM_RANDOM_SAMPLES, length))
        return sorted({0, length - 1, *picks})

    def _check_sample(self, data):
        """Assert that ``data`` has the fields of a loaded detection sample."""
        self.assertEqual(len(data['img_shape']), 3)
        self.assertEqual(data['img_fields'], ['img'])
        # One box per row with 4 coordinates.
        self.assertEqual(data['gt_bboxes'].shape[-1], 4)
        self.assertIn('filename', data)
        self.assertIn('gt_labels', data)
        # Image must have a trailing dimension of 3 (presumably channels).
        self.assertEqual(data['img'].shape[-1], 3)

    def test_concat_source(self):
        data_root = DET_DATA_RAW_LOCAL
        raw_source = dict(
            type='DetSourceRaw',
            img_root_path=os.path.join(data_root, 'images'),
            label_root_path=os.path.join(data_root, 'labels'))
        # Pass independent copies in case the builder mutates its configs.
        data_source = build_datasource(
            dict(
                type='SourceConcat',
                data_source_list=[dict(raw_source),
                                  dict(raw_source)]))

        length = data_source.get_length()
        # Two copies of the same raw source (126 samples each).
        self.assertEqual(length, 252)
        # Sample across the whole range: drawing only from the first 20
        # indices would never reach the second copy.
        for idx in self._sample_indices(length):
            self._check_sample(data_source.get_sample(idx))

    def test_concat_diff_source(self):
        raw_data_root = DET_DATA_RAW_LOCAL
        voc_data_root = DET_DATA_SMALL_VOC_LOCAL
        data_source = dict(
            type='SourceConcat',
            data_source_list=[
                dict(
                    type='DetSourceVOC',
                    path=os.path.join(voc_data_root,
                                      'ImageSets/Main/train_20.txt'),
                    classes=VOC_CLASSES),
                dict(
                    type='DetSourceRaw',
                    img_root_path=os.path.join(raw_data_root, 'images'),
                    label_root_path=os.path.join(raw_data_root, 'labels'))
            ])
        data_source = build_datasource(data_source)

        length = data_source.get_length()
        # Presumably 20 VOC samples followed by 126 raw samples.
        self.assertEqual(length, 146)
        # Sample across the whole range so the second (raw) source is hit
        # as well; indices < 20 only ever reach the VOC source.
        for idx in self._sample_indices(length):
            self._check_sample(data_source.get_sample(idx))
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when it is
    # executed directly as a script.
    unittest.main(module='__main__')