EasyCV/tests/models/detection/yolox/test_yolox.py

54 lines
1.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
import torch
from easycv.models.detection.detectors.yolox.yolox import YOLOX
@unittest.skipUnless(torch.cuda.is_available(),
                     'YOLOX training test requires a CUDA device')
class YOLOXTest(unittest.TestCase):
    """Smoke test: YOLOX runs one training forward pass for every size.

    Each CSPDarknet variant is built with a 2-class ``YOLOXHead``, fed a
    random batch plus random ground truth, and its training output is checked
    for the reported image size and for scalar (0-d) loss tensors.
    The whole class is skipped when no CUDA device is available, because the
    model and inputs are moved to the GPU.
    """

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_yolox(self):
        """Run a train-mode forward pass for each YOLOX model type."""
        num_classes = 2
        batch_size = 2
        num_boxes = 5
        # Every image in the batch is declared as 640x640 with 3 channels.
        img_metas = [{'img_shape': (640, 640, 3)}] * batch_size
        # np.float was removed in NumPy 1.24; use an explicit float64 dtype.
        expected_size = np.array(640, dtype=np.float64)

        for model_type in ['s', 'm', 'l', 'x', 'tiny', 'nano']:
            # subTest names the failing model_type and still checks the rest.
            with self.subTest(model_type=model_type):
                model = YOLOX(
                    test_conf=0.01,
                    nms_thre=0.65,
                    backbone='CSPDarknet',
                    model_type=model_type,
                    head=dict(
                        type='YOLOXHead',
                        model_type=model_type,
                        num_classes=num_classes),
                )
                model = model.cuda()
                model.train()

                imgs = torch.randn(batch_size, 3, 640, 640).cuda()
                gt_bboxes = torch.randint(
                    0, 600, size=(batch_size, num_boxes, 4)).cuda()
                # randint's upper bound is exclusive: sample from
                # [0, num_classes) so every class id is exercised, not only 0.
                gt_labels = torch.randint(
                    0, num_classes, size=(batch_size, num_boxes, 1)).cuda()

                output = model(
                    imgs,
                    mode='train',
                    gt_bboxes=gt_bboxes,
                    gt_labels=gt_labels,
                    img_metas=img_metas)

                self.assertEqual(output['img_h'].cpu().numpy(), expected_size)
                self.assertEqual(output['img_w'].cpu().numpy(), expected_size)
                # Each loss term must be reduced to a scalar (0-d) tensor.
                for key in ('total_loss', 'iou_l', 'conf_l', 'cls_l'):
                    self.assertEqual(output[key].shape, torch.Size([]))
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()