mirror of https://github.com/open-mmlab/mmyolo.git
84 lines
2.8 KiB
Python
84 lines
2.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import unittest
|
|
|
|
from mmengine.dataset import ConcatDataset
|
|
|
|
from mmyolo.datasets import YOLOv5VOCDataset
|
|
from mmyolo.utils import register_all_modules
|
|
|
|
# Run mmyolo's module registration at import time. Presumably this is what
# lets string-typed configs used below (e.g. type='YOLOv5VOCDataset' and
# type='BatchShapePolicy') be resolved -- confirm against mmyolo.utils.
register_all_modules()
|
|
|
|
|
|
class TestYOLOv5VocDataset(unittest.TestCase):
    """Tests for ``YOLOv5VOCDataset`` built on the VOC fixtures.

    All datasets are constructed from ``tests/data/VOCdevkit/`` with an empty
    pipeline, so the checks look at the raw sample dicts the dataset yields.
    Assertions go through ``unittest`` helpers rather than bare ``assert`` so
    they are not stripped under ``python -O`` and report both operands on
    failure.
    """

    def test_batch_shapes_cfg(self):
        """Check image ids and batch shapes when a batch-shape policy is set."""
        batch_shapes_cfg = dict(
            type='BatchShapePolicy',
            batch_size=2,
            img_size=640,
            size_divisor=32,
            extra_pad_ratio=0.5)

        # serialize_data is not passed, so the dataset's default applies
        # (presumably True -- confirm against the base dataset class).
        dataset = YOLOv5VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            test_mode=True,
            pipeline=[],
            batch_shapes_cfg=batch_shapes_cfg,
        )

        expected_img_ids = ['000001']
        expected_batch_shapes = [[672, 480]]
        for i, data in enumerate(dataset):
            self.assertEqual(data['img_id'], expected_img_ids[i])
            self.assertEqual(data['batch_shape'].tolist(),
                             expected_batch_shapes[i])

    def test_prepare_data(self):
        """Check the ``'dataset'`` key: present by default, absent in test mode."""
        # test_mode is not passed here, so the dataset's default mode is used.
        dataset = YOLOv5VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            filter_cfg=dict(filter_empty_gt=False, min_size=0),
            pipeline=[],
            serialize_data=True,
            batch_shapes_cfg=None,
        )
        for data in dataset:
            self.assertIn('dataset', data)

        # test with test_mode = True
        dataset = YOLOv5VOCDataset(
            data_root='tests/data/VOCdevkit/',
            ann_file='VOC2007/ImageSets/Main/trainval.txt',
            data_prefix=dict(sub_data_root='VOC2007/'),
            filter_cfg=dict(
                filter_empty_gt=True, min_size=32, bbox_min_size=None),
            pipeline=[],
            test_mode=True,
            batch_shapes_cfg=None)

        for data in dataset:
            self.assertNotIn('dataset', data)

    def test_concat_dataset(self):
        """Check that concatenating the VOC2007 and VOC2012 sets gives 2 items."""
        dataset = ConcatDataset(datasets=[
            dict(
                type='YOLOv5VOCDataset',
                data_root='tests/data/VOCdevkit/',
                ann_file='VOC2007/ImageSets/Main/trainval.txt',
                data_prefix=dict(sub_data_root='VOC2007/'),
                filter_cfg=dict(filter_empty_gt=False, min_size=32),
                pipeline=[]),
            dict(
                type='YOLOv5VOCDataset',
                data_root='tests/data/VOCdevkit/',
                ann_file='VOC2012/ImageSets/Main/trainval.txt',
                data_prefix=dict(sub_data_root='VOC2012/'),
                filter_cfg=dict(filter_empty_gt=False, min_size=32),
                pipeline=[])
        ])

        # Explicit initialisation before the length is queried.
        dataset.full_init()
        self.assertEqual(len(dataset), 2)