mirror of https://github.com/open-mmlab/mmyolo.git
139 lines
4.8 KiB
Python
139 lines
4.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from mmdet.structures import DetDataSample
|
|
from mmdet.structures.bbox import HorizontalBoxes
|
|
from mmengine.structures import InstanceData
|
|
|
|
from mmyolo.datasets import BatchShapePolicy, yolov5_collate
|
|
|
|
|
|
def _rand_bboxes(rng, num_boxes, w, h):
    """Sample random ``(x1, y1, x2, y2)`` boxes clipped to a w x h canvas.

    Args:
        rng (np.random.RandomState): Random generator; ``rng.rand`` is
            called exactly once with shape ``(num_boxes, 4)``.
        num_boxes (int): Number of boxes to generate.
        w (int | float): Canvas width; scales and clips the x coordinates.
        h (int | float): Canvas height; scales and clips the y coordinates.

    Returns:
        np.ndarray: Array of shape ``(num_boxes, 4)`` holding the clipped
        top-left and bottom-right corner of every box.
    """
    # Every drawn row is (cx, cy, bw, bh) as fractions of the canvas size.
    rel_cx, rel_cy, rel_w, rel_h = rng.rand(num_boxes, 4).T

    center_x = rel_cx * w
    center_y = rel_cy * h
    half_w = w * rel_w / 2
    half_h = h * rel_h / 2

    # Corners that land outside the canvas are clamped onto its border.
    x1 = np.clip(center_x - half_w, 0, w)
    y1 = np.clip(center_y - half_h, 0, h)
    x2 = np.clip(center_x + half_w, 0, w)
    y2 = np.clip(center_y + half_h, 0, h)

    return np.stack((x1, y1, x2, y2), axis=1)
|
|
|
|
|
|
class TestYOLOv5Collate(unittest.TestCase):
    """Tests for ``yolov5_collate`` with and without ``use_ms_training``."""

    def _build_sample(self):
        """Build one dataset item with a 3x10x10 image and 4 GT boxes.

        Returns:
            dict: ``inputs`` is a random ``(3, 10, 10)`` tensor and
            ``data_samples`` is a ``DetDataSample`` whose ``gt_instances``
            carries 4 boxes and 4 labels.
        """
        rng = np.random.RandomState(0)

        gt_instances = InstanceData()
        bboxes = _rand_bboxes(rng, 4, 6, 8)
        gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
        # randint's upper bound is exclusive, so every label equals 1.
        labels = rng.randint(1, 2, size=len(bboxes))
        gt_instances.labels = torch.LongTensor(labels)

        data_samples = DetDataSample()
        data_samples.gt_instances = gt_instances
        return dict(inputs=torch.randn((3, 10, 10)), data_samples=data_samples)

    def test_yolov5_collate(self):
        """Default mode: images are returned as one batched tensor."""
        sample = self._build_sample()

        for batch_size in (1, 2):
            out = yolov5_collate([sample] * batch_size)

            self.assertIsInstance(out, dict)
            self.assertEqual(
                tuple(out['inputs'].shape), (batch_size, 3, 10, 10))
            # assertTrue(x, dict) would use ``dict`` as the failure message
            # and only test truthiness; assertIsInstance checks the type.
            self.assertIsInstance(out['data_samples'], dict)
            # 4 boxes per sample, 6 columns per row.
            self.assertEqual(
                tuple(out['data_samples']['bboxes_labels'].shape),
                (4 * batch_size, 6))

    def test_yolov5_collate_with_multi_scale(self):
        """Multi-scale mode: images stay a list of per-sample tensors."""
        sample = self._build_sample()

        for batch_size in (1, 2):
            out = yolov5_collate([sample] * batch_size, use_ms_training=True)

            self.assertIsInstance(out, dict)
            self.assertIsInstance(out['inputs'], list)
            self.assertEqual(tuple(out['inputs'][0].shape), (3, 10, 10))
            # See test_yolov5_collate: this must be a real type check.
            self.assertIsInstance(out['data_samples'], dict)

            bboxes_labels = out['data_samples']['bboxes_labels']
            self.assertIsInstance(bboxes_labels, torch.Tensor)
            self.assertEqual(
                tuple(bboxes_labels.shape), (4 * batch_size, 6))
|
|
|
|
|
|
class TestBatchShapePolicy(unittest.TestCase):
    """Checks the order and batch shapes returned by ``BatchShapePolicy``."""

    def test_batch_shape_policy(self):
        """Compare the policy output with hand-written expectations.

        The expected list is in ascending height order, and every entry
        carries the ``batch_shape`` expected for it with ``batch_size=2``.
        """
        src_data_infos = [{
            'height': height,
            'width': 100,
        } for height in (20, 11, 21, 30, 10)]

        expected_data_infos = [{
            'height': height,
            'width': 100,
            'batch_shape': np.array(batch_shape),
        } for height, batch_shape in (
            (10, [96, 672]),
            (11, [96, 672]),
            (20, [160, 672]),
            (21, [160, 672]),
            (30, [224, 672]),
        )]

        policy = BatchShapePolicy(batch_size=2)
        out_data_infos = policy(src_data_infos)

        # Index into the output so a too-short result raises, not passes.
        for idx, expected in enumerate(expected_data_infos):
            actual = out_data_infos[idx]
            self.assertEqual(
                (expected['height'], expected['width']),
                (actual['height'], actual['width']))
            self.assertTrue(
                np.allclose(expected['batch_shape'], actual['batch_shape']))
|