mirror of https://github.com/alibaba/EasyCV.git
95 lines
3.6 KiB
Python
95 lines
3.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import random
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from tests.ut_config import DET_DATA_RAW_LOCAL
|
|
|
|
from easycv.datasets.detection.data_sources.raw import DetSourceRaw
|
|
from easycv.file import io
|
|
|
|
|
|
class DetSourceRawTest(unittest.TestCase):
    """Tests for ``DetSourceRaw`` using the data at ``DET_DATA_RAW_LOCAL``."""

    def setUp(self):
        """Print the running test's name and delete ``labels/*.cache`` files."""
        banner = 'Testing %s.%s' % (type(self).__name__, self._testMethodName)
        print(banner)

        cache_pattern = os.path.join(DET_DATA_RAW_LOCAL, 'labels/*.cache')
        for stale_cache in io.glob(cache_pattern):
            io.remove(stale_cache)

    def _build_source(self, **cache_options):
        """Create a ``DetSourceRaw`` over the ``train2017`` images and labels.

        Args:
            cache_options: keyword arguments forwarded unchanged to
                ``DetSourceRaw`` (the tests below pass ``cache_on_the_fly``
                and ``cache_at_init``, or nothing at all).

        Returns:
            The constructed ``DetSourceRaw`` instance.
        """
        images_dir = os.path.join(DET_DATA_RAW_LOCAL, 'images/train2017')
        labels_dir = os.path.join(DET_DATA_RAW_LOCAL, 'labels/train2017')
        return DetSourceRaw(
            img_root_path=images_dir,
            label_root_path=labels_dir,
            **cache_options)

    def _assert_sample_layout(self, sample):
        """Assert the keys and trailing dimensions expected of one sample."""
        for required_key in ('img_shape', 'ori_img_shape'):
            self.assertIn(required_key, sample)
        self.assertEqual(len(sample['img_shape']), 3)
        self.assertEqual(sample['img_fields'], ['img'])
        self.assertEqual(sample['bbox_fields'], ['gt_bboxes'])
        self.assertEqual(sample['gt_bboxes'].shape[-1], 4)
        self.assertGreaterEqual(len(sample['gt_labels']), 1)
        self.assertIn('filename', sample)
        self.assertEqual(sample['img'].shape[-1], 3)

    def _assert_reference_sample(self, sample):
        """Assert the pinned shape, labels and boxes of ``000000000086.jpg``."""
        self.assertEqual(sample['img_shape'], (640, 512, 3))

        actual_labels = sample['gt_labels'].tolist()
        expected_labels = np.array([0, 3, 26], dtype=np.int32).tolist()
        self.assertEqual(actual_labels, expected_labels)

        # Both sides are cast to int32, so only the integer part of each
        # coordinate is compared.
        actual_bboxes = sample['gt_bboxes'].astype(np.int32).tolist()
        expected_bboxes = np.array(
            [[152.11008, 183.09982, 281.26004, 559.0697],
             [130.13991, 346.51968, 424.40015, 635.15967],
             [265.10004, 294.5901, 399.7002, 404.7699]],
            dtype=np.int32).tolist()
        self.assertEqual(actual_bboxes, expected_bboxes)

    def _base_test(self, data_source):
        """Run the checks shared by every test case on ``data_source``.

        Steps, in order:
          1. Load three randomly chosen samples (drawn with replacement from
             indices 0-19) and check their layout.
          2. For the first three indices in 0-19 that were not drawn, check
             that ``samples_list`` holds ``'img'`` when ``cache_at_init`` is
             set and does not hold it when ``cache_on_the_fly`` is set.
          3. Check that the source reports 128 samples.
          4. Load every sample and verify the pinned values of the one whose
             filename ends with ``000000000086.jpg``; fail if none is found.
        """
        candidates = range(20)
        picked = random.choices(candidates, k=3)
        for sample_idx in picked:
            self._assert_sample_layout(data_source.get_sample(sample_idx))

        not_picked = [i for i in candidates if i not in picked]
        for sample_idx in not_picked[:3]:
            if data_source.cache_at_init:
                self.assertIn('img', data_source.samples_list[sample_idx])
            if data_source.cache_on_the_fly:
                self.assertNotIn('img', data_source.samples_list[sample_idx])

        total = data_source.get_length()
        self.assertEqual(total, 128)

        found_reference = False
        for sample_idx in range(total):
            sample = data_source.get_sample(sample_idx)
            if not sample.get('filename', '').endswith('000000000086.jpg'):
                continue
            found_reference = True
            self._assert_reference_sample(sample)
        self.assertTrue(found_reference)

    def test_default(self):
        """Run the shared checks with no cache arguments passed."""
        self._base_test(self._build_source())

    def test_cache_on_the_fly(self):
        """Run the shared checks with ``cache_on_the_fly`` enabled."""
        source = self._build_source(cache_on_the_fly=True, cache_at_init=False)
        self._base_test(source)

    def test_cache_at_init(self):
        """Run the shared checks with ``cache_at_init`` enabled."""
        source = self._build_source(cache_on_the_fly=False, cache_at_init=True)
        self._base_test(source)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
|