mirror of https://github.com/alibaba/EasyCV.git
134 lines
4.9 KiB
Python
134 lines
4.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import time
|
|
import unittest
|
|
|
|
import torch
|
|
from tests.ut_config import (COCO_CLASSES, DET_DATA_SMALL_COCO_LOCAL,
|
|
IMG_NORM_CFG_255)
|
|
|
|
from easycv.datasets.builder import build_dataset
|
|
|
|
|
|
class DetImagesMixDatasetTest(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
|
|
|
def test_load_train(self):
|
|
img_scale = (640, 640)
|
|
scale_ratio = (0.1, 2)
|
|
train_pipeline = [
|
|
dict(type='MMMosaic', img_scale=img_scale, pad_val=114.0),
|
|
dict(
|
|
type='MMRandomAffine',
|
|
scaling_ratio_range=scale_ratio,
|
|
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
|
|
dict(
|
|
type='MMMixUp', # s m x l; tiny nano will detele
|
|
img_scale=img_scale,
|
|
ratio_range=(0.8, 1.6),
|
|
pad_val=114.0),
|
|
dict(
|
|
type='MMPhotoMetricDistortion',
|
|
brightness_delta=32,
|
|
contrast_range=(0.5, 1.5),
|
|
saturation_range=(0.5, 1.5),
|
|
hue_delta=18),
|
|
dict(type='MMRandomFlip', flip_ratio=0.5),
|
|
dict(type='MMResize', keep_ratio=True),
|
|
dict(
|
|
type='MMPad',
|
|
pad_to_square=True,
|
|
pad_val=(114.0, 114.0, 114.0)),
|
|
dict(type='MMNormalize', **IMG_NORM_CFG_255),
|
|
dict(type='DefaultFormatBundle'),
|
|
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
|
|
]
|
|
train_dataset = dict(
|
|
type='DetImagesMixDataset',
|
|
data_source=dict(
|
|
type='DetSourceCoco',
|
|
ann_file=os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
|
'instances_train2017_20.json'),
|
|
img_prefix=os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
|
'train2017'),
|
|
pipeline=[
|
|
dict(type='LoadImageFromFile', to_float32=True),
|
|
dict(type='LoadAnnotations', with_bbox=True)
|
|
],
|
|
classes=COCO_CLASSES,
|
|
filter_empty_gt=False,
|
|
iscrowd=False),
|
|
pipeline=train_pipeline,
|
|
dynamic_scale=img_scale)
|
|
|
|
dataset = build_dataset(train_dataset)
|
|
data_num = len(dataset)
|
|
s = time.time()
|
|
for data in dataset:
|
|
pass
|
|
t = time.time()
|
|
print(f'read data done {(t-s)/data_num}s per sample')
|
|
|
|
self.assertEqual(data['img'].data.shape, torch.Size([3, 640, 640]))
|
|
img_metas = data['img_metas'].data
|
|
self.assertIn('flip', img_metas)
|
|
self.assertIn('filename', img_metas)
|
|
self.assertIn('img_shape', img_metas)
|
|
self.assertIn('ori_img_shape', img_metas)
|
|
self.assertEqual(data['gt_bboxes'].shape, torch.Size([120, 4]))
|
|
self.assertEqual(data['gt_labels'].shape, torch.Size([120, 1]))
|
|
self.assertEqual(data_num, 20)
|
|
|
|
def test_load_test(self):
|
|
img_scale = (640, 640)
|
|
test_pipeline = [
|
|
dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
|
|
dict(
|
|
type='MMPad',
|
|
pad_to_square=True,
|
|
pad_val=(114.0, 114.0, 114.0)),
|
|
dict(type='MMNormalize', **IMG_NORM_CFG_255),
|
|
dict(type='DefaultFormatBundle'),
|
|
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
|
|
]
|
|
val_dataset = dict(
|
|
type='DetImagesMixDataset',
|
|
data_source=dict(
|
|
type='DetSourceCoco',
|
|
ann_file=os.path.join(DET_DATA_SMALL_COCO_LOCAL,
|
|
'instances_val2017_20.json'),
|
|
img_prefix=os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017'),
|
|
pipeline=[
|
|
dict(type='LoadImageFromFile', to_float32=True),
|
|
dict(type='LoadAnnotations', with_bbox=True)
|
|
],
|
|
classes=COCO_CLASSES,
|
|
filter_empty_gt=False,
|
|
iscrowd=True),
|
|
pipeline=test_pipeline,
|
|
dynamic_scale=None,
|
|
label_padding=False)
|
|
|
|
dataset = build_dataset(val_dataset)
|
|
data_num = len(dataset)
|
|
s = time.time()
|
|
for data in dataset:
|
|
pass
|
|
t = time.time()
|
|
print(f'read data done {(t-s)/data_num}s per sample')
|
|
|
|
self.assertEqual(data['img'].data.shape, torch.Size([3, 640, 640]))
|
|
img_metas = data['img_metas'].data
|
|
self.assertIn('filename', img_metas)
|
|
self.assertIn('img_shape', img_metas)
|
|
self.assertIn('ori_img_shape', img_metas)
|
|
self.assertEqual(data['gt_bboxes'].data.shape, torch.Size([16, 4]))
|
|
self.assertEqual(data['gt_labels'].data.shape, torch.Size([16]))
|
|
self.assertEqual(data_num, 20)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|