mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
82 lines
2.9 KiB
Python
82 lines
2.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import random
|
|
import unittest
|
|
|
|
from tests.ut_config import (DET_DATA_RAW_LOCAL, DET_DATA_SMALL_VOC_LOCAL,
|
|
VOC_CLASSES)
|
|
|
|
from easycv.datasets.builder import build_datasource
|
|
|
|
|
|
class DetSourceCocoTest(unittest.TestCase):
    """Tests for building and reading a ``SourceConcat`` data source.

    Each test builds a ``SourceConcat`` config from two child data sources
    through ``build_datasource``, reads a few randomly chosen samples and
    checks the sample fields as well as the total length.

    NOTE(review): the class name mentions "Coco", but both tests exercise
    ``SourceConcat``; the name is kept so existing test selectors still work.
    """

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    @staticmethod
    def _raw_source_cfg(data_root):
        """Return a ``DetSourceRaw`` config rooted at ``data_root``.

        Args:
            data_root: Directory containing the ``images`` and ``labels``
                sub-directories.

        Returns:
            dict: Config with ``type``, ``img_root_path`` and
            ``label_root_path`` entries.
        """
        return dict(
            type='DetSourceRaw',
            img_root_path=os.path.join(data_root, 'images'),
            label_root_path=os.path.join(data_root, 'labels'))

    def _check_random_samples(self, data_source, index_range=20,
                              num_samples=3):
        """Read random samples and assert that the expected fields exist.

        Args:
            data_source: Built data source exposing ``get_sample(idx)``.
            index_range: Indices are drawn from ``range(index_range)``.
            num_samples: Number of indices drawn (with replacement).

        NOTE(review): with the default ``index_range`` only indices below 20
        are read, although the concatenated sources report a larger length;
        confirm whether reading from the tail of the concatenation should
        also be covered.
        """
        # ``range`` is a sequence, so no intermediate list is needed.
        for idx in random.choices(range(index_range), k=num_samples):
            data = data_source.get_sample(idx)
            self.assertIn('filename', data)
            self.assertIn('gt_labels', data)
            self.assertEqual(data['img_fields'], ['img'])
            self.assertEqual(len(data['img_shape']), 3)
            self.assertEqual(data['img'].shape[-1], 3)
            self.assertEqual(data['gt_bboxes'].shape[-1], 4)

    def test_concat_source(self):
        """Concatenate the same raw detection source with itself."""
        data_root = DET_DATA_RAW_LOCAL
        data_source = build_datasource(
            dict(
                type='SourceConcat',
                data_source_list=[
                    self._raw_source_cfg(data_root),
                    self._raw_source_cfg(data_root),
                ]))

        self._check_random_samples(data_source)

        # Expected total number of samples for the local test fixture data.
        self.assertEqual(data_source.get_length(), 252)

    def test_concat_diff_source(self):
        """Concatenate a VOC source with a raw detection source."""
        raw_data_root = DET_DATA_RAW_LOCAL
        voc_data_root = DET_DATA_SMALL_VOC_LOCAL
        data_source = build_datasource(
            dict(
                type='SourceConcat',
                data_source_list=[
                    dict(
                        type='DetSourceVOC',
                        path=os.path.join(voc_data_root,
                                          'ImageSets/Main/train_20.txt'),
                        classes=VOC_CLASSES),
                    self._raw_source_cfg(raw_data_root),
                ]))

        self._check_random_samples(data_source)

        # Expected total number of samples for the local test fixture data.
        self.assertEqual(data_source.get_length(), 146)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the test cases of this module when it is executed as a script.
    unittest.main()
|