mirror of https://github.com/alibaba/EasyCV.git
96 lines
3.5 KiB
Python
96 lines
3.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import random
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from tests.ut_config import DET_DATA_SMALL_VOC_LOCAL, VOC_CLASSES
|
|
|
|
from easycv.datasets.detection.data_sources.voc import DetSourceVOC
|
|
from easycv.file import io
|
|
|
|
|
|
class DetSourceVOCTest(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
data_root = DET_DATA_SMALL_VOC_LOCAL
|
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
|
for cache_file in io.glob(
|
|
os.path.join(data_root, 'ImageSets/Main/*.cache')):
|
|
io.remove(cache_file)
|
|
|
|
def _base_test(self, data_source):
|
|
index_list = random.choices(list(range(20)), k=3)
|
|
for idx in index_list:
|
|
data = data_source.get_sample(idx)
|
|
self.assertIn('img_shape', data)
|
|
self.assertIn('ori_img_shape', data)
|
|
self.assertIn('filename', data)
|
|
self.assertEqual(len(data['img_shape']), 3)
|
|
self.assertEqual(data['img_fields'], ['img'])
|
|
self.assertEqual(data['bbox_fields'], ['gt_bboxes'])
|
|
self.assertEqual(data['gt_bboxes'].shape[-1], 4)
|
|
self.assertGreaterEqual(len(data['gt_labels']), 1)
|
|
self.assertEqual(data['img'].shape[-1], 3)
|
|
|
|
length = data_source.get_length()
|
|
self.assertEqual(length, 20)
|
|
|
|
exists = False
|
|
for idx in range(length):
|
|
result = data_source.get_sample(idx)
|
|
file_name = result.get('filename', '')
|
|
if file_name.endswith('000032.jpg'):
|
|
exists = True
|
|
self.assertEqual(result['img_shape'], (281, 500, 3))
|
|
self.assertEqual(
|
|
result['gt_labels'].tolist(),
|
|
np.array([0, 0, 14, 14], dtype=np.int32).tolist())
|
|
self.assertEqual(
|
|
result['gt_bboxes'].astype(np.int32).tolist(),
|
|
np.array(
|
|
[[104., 78., 375., 183.], [133., 88., 197., 123.],
|
|
[195., 180., 213., 229.], [26., 189., 44., 238.]],
|
|
dtype=np.int32).tolist())
|
|
self.assertTrue(exists)
|
|
|
|
def test_default(self):
|
|
data_root = DET_DATA_SMALL_VOC_LOCAL
|
|
data_source = DetSourceVOC(
|
|
path=os.path.join(data_root, 'ImageSets/Main/train_20.txt'),
|
|
classes=VOC_CLASSES)
|
|
self._base_test(data_source)
|
|
|
|
def test_cache_on_the_fly(self):
|
|
data_root = DET_DATA_SMALL_VOC_LOCAL
|
|
data_source = DetSourceVOC(
|
|
path=os.path.join(data_root, 'ImageSets/Main/train_20.txt'),
|
|
classes=VOC_CLASSES,
|
|
cache_at_init=True,
|
|
cache_on_the_fly=False)
|
|
self._base_test(data_source)
|
|
|
|
def test_cache_at_init(self):
|
|
data_root = DET_DATA_SMALL_VOC_LOCAL
|
|
data_source = DetSourceVOC(
|
|
path=os.path.join(data_root, 'ImageSets/Main/train_20.txt'),
|
|
classes=VOC_CLASSES,
|
|
cache_at_init=False,
|
|
cache_on_the_fly=True)
|
|
self._base_test(data_source)
|
|
|
|
def test_image_root_and_label_root(self):
|
|
data_root = DET_DATA_SMALL_VOC_LOCAL
|
|
data_source = DetSourceVOC(
|
|
path=os.path.join(data_root, 'ImageSets/Main/train_20.txt'),
|
|
classes=VOC_CLASSES,
|
|
img_root_path=os.path.join(data_root, 'JPEGImages'),
|
|
label_root_path=os.path.join(data_root, 'Annotations'),
|
|
cache_at_init=False,
|
|
cache_on_the_fly=True)
|
|
self._base_test(data_source)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|