EasyCV/tests/datasets/detection/data_sources/test_coco.py

65 lines
2.3 KiB
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import unittest
import numpy as np
from tests.ut_config import COCO_CLASSES, DET_DATA_SMALL_COCO_LOCAL
from easycv.datasets.detection.data_sources.coco import DetSourceCoco
class DetSourceCocoTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_det_source_coco(self):
data_root = DET_DATA_SMALL_COCO_LOCAL
data_source = DetSourceCoco(
ann_file=os.path.join(data_root, 'instances_train2017_20.json'),
img_prefix=os.path.join(data_root, 'train2017'),
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
],
classes=COCO_CLASSES,
filter_empty_gt=False,
iscrowd=False)
index_list = random.choices(list(range(20)), k=3)
for idx in index_list:
data = data_source.get_sample(idx)
self.assertIn('ann_info', data)
self.assertIn('img_info', data)
self.assertIn('filename', data)
self.assertEqual(data['img'].shape[-1], 3)
self.assertEqual(len(data['img_shape']), 3)
self.assertEqual(data['img_fields'], ['img'])
self.assertEqual(data['gt_bboxes'].shape[-1], 4)
self.assertGreater(len(data['gt_labels']), 1)
length = data_source.get_length()
self.assertEqual(length, 20)
exists = False
for idx in range(length):
result = data_source.get_sample(idx)
file_name = result.get('filename', '')
if file_name.endswith('000000224736.jpg'):
exists = True
self.assertEqual(result['img_shape'], (427, 640, 3))
self.assertEqual(result['gt_labels'].tolist(),
np.array([61, 71], dtype=np.int32).tolist())
self.assertEqual(
result['gt_bboxes'].tolist(),
np.array([[148.1, 297.65, 270.24, 383.24],
[470.09, 148.13, 552.07, 207.29]],
dtype=np.float32).tolist())
self.assertTrue(exists)
if __name__ == '__main__':
unittest.main()