EasyCV/tests/predictors/test_detector.py

305 lines
13 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
"""
isort:skip_file
"""
import os
import unittest
import tempfile
import numpy as np
from PIL import Image
from easycv.predictors.detector import DetectionPredictor, YoloXPredictor, TorchYoloXPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT,
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT_B2,
DET_DATA_SMALL_COCO_LOCAL)
from numpy.testing import assert_array_almost_equal
class YoloXPredictorTest(unittest.TestCase):
    """Regression tests for ``YoloXPredictor`` and ``TorchYoloXPredictor``.

    Every test runs a YOLOX-S model (the raw export or one of the TorchScript
    exports from ``tests.ut_config``) on the same COCO val2017 image with
    ``score_thresh=0.5`` and compares each result with fixed reference values
    in ``_assert_results``.
    """

    img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg')

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _load_img_array(self):
        """Decode ``self.img`` into a numpy array.

        The ``with`` block closes the underlying file handle deterministically
        instead of leaving it to the garbage collector.
        """
        with Image.open(self.img) as pil_img:
            return np.asarray(pil_img)

    def _assert_results(self, results):
        """Check one per-image result dict against the reference detections.

        Args:
            results: one element of the list returned by the predictor. Only
                the keys ``ori_img_shape``, ``detection_classes``,
                ``detection_class_names``, ``detection_scores`` and
                ``detection_boxes`` are inspected.
        """
        self.assertEqual(results['ori_img_shape'], [480, 640])
        self.assertListEqual(results['detection_classes'].tolist(),
                             [13, 8, 8, 8])
        self.assertListEqual(results['detection_class_names'],
                             ['bench', 'boat', 'boat', 'boat'])
        assert_array_almost_equal(
            results['detection_scores'],
            np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793],
                     dtype=np.float32),
            decimal=2)
        assert_array_almost_equal(
            results['detection_boxes'],
            np.array([[408.1708, 285.11456, 561.84924, 356.42285],
                      [438.88098, 264.46606, 467.07275, 271.76355],
                      [510.19467, 268.46664, 528.26935, 273.37192],
                      [480.9472, 269.74115, 502.00842, 274.85553]]),
            decimal=1)

    def _base_test_single(self, model_path, inputs):
        """Run one inference and expect exactly one matching result."""
        predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5)
        outputs = predictor(inputs)
        self.assertEqual(len(outputs), 1)
        self._assert_results(outputs[0])

    def _base_test_batch(self, model_path, inputs, num_samples, batch_size):
        """Replicate one input ``num_samples`` times and check every result.

        Args:
            model_path: model or TorchScript export to load.
            inputs: list holding exactly one input (array or image path).
            num_samples: how many copies of the input to feed the predictor.
            batch_size: batch size passed to ``YoloXPredictor``.
        """
        # Use unittest assertions rather than ``assert``: bare asserts are
        # stripped under ``python -O``, which would silently skip the check.
        self.assertIsInstance(inputs, list)
        self.assertEqual(len(inputs), 1)
        predictor = YoloXPredictor(
            model_path=model_path, score_thresh=0.5, batch_size=batch_size)
        outputs = predictor(inputs * num_samples)
        self.assertEqual(len(outputs), num_samples)
        for output in outputs:
            self._assert_results(output)

    def test_single_raw(self):
        model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        inputs = [self._load_img_array()]
        self._base_test_single(model_path, inputs)

    def test_batch_raw(self):
        model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        inputs = [self._load_img_array()]
        self._base_test_batch(model_path, inputs, num_samples=3, batch_size=2)

    def test_single_jit_nopre_notrt(self):
        jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
        # NOTE(review): passes a bare path rather than a list, presumably to
        # exercise the predictor's non-list input handling. Confirm this is
        # intentional.
        self._base_test_single(jit_path, self.img)

    def test_batch_jit_nopre_notrt(self):
        jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
        self._base_test_batch(
            jit_path, [self.img], num_samples=2, batch_size=1)

    def test_single_jit_pre_trt(self):
        # Despite the method name, this loads the *PRE_NOTRT* export. The name
        # is kept unchanged so existing test IDs and filters still match.
        jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
        self._base_test_single(jit_path, [self.img])

    def test_batch_jit_pre_trt(self):
        # Same naming caveat as ``test_single_jit_pre_trt`` (PRE_NOTRT, B2).
        jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT_B2
        self._base_test_batch(
            jit_path, [self.img], num_samples=4, batch_size=2)

    def test_single_raw_TorchYoloXPredictor(self):
        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        input_data_list = [self._load_img_array()]
        predictor = TorchYoloXPredictor(
            model_path=detection_model_path, score_thresh=0.5)
        outputs = predictor(input_data_list)
        # Same output-count check as ``_base_test_single``.
        self.assertEqual(len(outputs), 1)
        self._assert_results(outputs[0])
class DetectionPredictorTest(unittest.TestCase):
    """Regression tests for ``DetectionPredictor`` with a ViTDet (ViT-Base) export.

    Both tests load the exported checkpoint and a demo image from OSS and run
    inference with ``score_threshold=0.0``. Each result is rendered with
    ``visualize`` into a throw-away location and then compared with the fixed
    reference values stored on the class.
    """

    # Shared by the single- and batch-inference tests.
    MODEL_PATH = (
        'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/'
        'EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth')
    IMG = ('https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/'
           'data/demo/demo.jpg')

    # Reference detections for IMG. All three sequences are index-aligned;
    # each box row has 4 coordinates.
    EXPECTED_CLASSES = (
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56)
    EXPECTED_SCORES = (
        0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
        0.9833580851554871, 0.983080267906189, 0.970454752445221,
        0.9701289534568787, 0.9649872183799744, 0.9642795324325562,
        0.9642238020896912, 0.9529680609703064, 0.9403366446495056,
        0.9391788244247437, 0.8941807150840759, 0.8178097009658813,
        0.8013413548469543, 0.6677654385566711, 0.3952914774417877,
        0.33463895320892334, 0.32501447200775146, 0.27323535084724426,
        0.20197080075740814, 0.15607696771621704, 0.1068163588643074,
        0.10183875262737274, 0.09735643863677979, 0.06559795141220093,
        0.08890066295862198, 0.076363705098629, 0.9954648613929749,
        0.9212945699691772, 0.5224372148513794, 0.20555885136127472)
    EXPECTED_BOXES = (
        (294.22674560546875, 116.6078109741211, 379.4328918457031,
         150.14097595214844),
        (482.6017761230469, 110.75955963134766, 522.8798828125,
         129.71286010742188),
        (167.06460571289062, 109.95974731445312, 212.83975219726562,
         140.16102600097656),
        (609.2930908203125, 113.13909149169922, 637.3115844726562,
         136.4690704345703),
        (191.185791015625, 111.1408920288086, 301.31689453125,
         155.7731170654297),
        (431.2244873046875, 106.19962310791016, 483.860595703125,
         132.21627807617188),
        (267.48358154296875, 105.5920639038086, 325.2832336425781,
         127.11176300048828),
        (591.2138671875, 110.29329681396484, 619.8524169921875,
         126.1990966796875),
        (0.0, 110.7026596069336, 61.487945556640625, 146.33018493652344),
        (555.9155883789062, 110.03486633300781, 591.7050170898438,
         127.06097412109375),
        (60.24559783935547, 94.12760162353516, 85.63741302490234,
         106.66705322265625),
        (99.02665710449219, 90.53657531738281, 118.83953094482422,
         101.18717956542969),
        (396.30438232421875, 111.59194946289062, 431.559814453125,
         133.96914672851562),
        (83.81543731689453, 89.65665435791016, 99.9166259765625,
         98.25627899169922),
        (139.29647827148438, 96.68000793457031, 165.22410583496094,
         105.60000610351562),
        (67.27152252197266, 89.42798614501953, 83.25617980957031,
         98.0460205078125),
        (223.74176025390625, 98.68321990966797, 250.42506408691406,
         109.32588958740234),
        (136.7582244873047, 96.51412963867188, 152.51190185546875,
         104.73160552978516),
        (221.71812438964844, 97.86445617675781, 238.9705810546875,
         106.96803283691406),
        (135.06964111328125, 91.80916595458984, 155.24609375,
         102.20686340332031),
        (169.11180114746094, 97.53628540039062, 182.88504028320312,
         105.95404815673828),
        (133.8811798095703, 91.00375366210938, 145.35507202148438,
         102.3780288696289),
        (614.2507934570312, 102.19828796386719, 636.5692749023438,
         112.59198760986328),
        (35.94759750366211, 91.7213363647461, 70.38274383544922,
         117.19855499267578),
        (554.6401977539062, 115.18976593017578, 562.0255737304688,
         127.4429931640625),
        (39.07550811767578, 92.73261260986328, 85.36636352539062,
         106.73953247070312),
        (200.85513305664062, 93.00469970703125, 219.73086547851562,
         107.99642181396484),
        (0.0, 111.18904876708984, 61.7393684387207, 146.72547912597656),
        (191.88568115234375, 111.09577178955078, 299.4097900390625,
         155.14639282226562),
        (221.06834411621094, 176.6427001953125, 458.3475341796875,
         378.89300537109375),
        (372.7131652832031, 135.51429748535156, 433.2494201660156,
         188.0106658935547),
        (52.19819641113281, 110.3646011352539, 70.95110321044922,
         120.10567474365234),
        (376.1671447753906, 133.6930694580078, 432.2721862792969,
         187.99481201171875),
    )

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _detection_detector_assert(self, output):
        """Check one image's predictor output against the class references.

        Args:
            output: one element of the list returned by ``DetectionPredictor``.
                It must contain ``detection_boxes``, ``detection_scores``,
                ``detection_classes``, ``detection_masks`` and ``img_metas``.
                ``detection_classes`` must support ``.tolist()``.
        """
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'detection_masks', 'img_metas'):
            self.assertIn(key, output)

        num_expected = len(self.EXPECTED_CLASSES)
        self.assertEqual(len(output['detection_boxes']), num_expected)
        self.assertEqual(len(output['detection_scores']), num_expected)
        self.assertEqual(len(output['detection_classes']), num_expected)

        self.assertListEqual(output['detection_classes'].tolist(),
                             list(self.EXPECTED_CLASSES))
        assert_array_almost_equal(
            output['detection_scores'],
            np.array(self.EXPECTED_SCORES, dtype=np.float32),
            decimal=2)
        assert_array_almost_equal(
            output['detection_boxes'],
            np.array(self.EXPECTED_BOXES),
            decimal=1)

    def _visualize_to_tmp(self, predictor, img, output):
        """Render ``output`` over ``img`` into a file that is deleted afterwards.

        A fresh ``TemporaryDirectory`` is used instead of a
        ``NamedTemporaryFile``. ``visualize`` receives only a path (presumably
        it opens ``out_file`` itself), and reopening a still-open
        ``NamedTemporaryFile`` by name fails on Windows. The directory and
        everything written into it are removed when the ``with`` block exits.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_file = os.path.join(tmp_dir, 'vis.jpg')
            predictor.visualize(img, output, out_file=out_file)

    def test_detection_detector_single(self):
        vitdet = DetectionPredictor(self.MODEL_PATH, score_threshold=0.0)
        outputs = vitdet(self.IMG)
        # One input image must yield exactly one result before indexing it.
        self.assertEqual(len(outputs), 1)
        output = outputs[0]
        self._visualize_to_tmp(vitdet, self.IMG, output)
        self._detection_detector_assert(output)

    def test_detection_detector_batch(self):
        vitdet = DetectionPredictor(
            self.MODEL_PATH, score_threshold=0.0, batch_size=2)
        # 3 samples with batch_size=2 also covers a final partial batch.
        num_samples = 3
        outputs = vitdet([self.IMG] * num_samples)
        self.assertEqual(len(outputs), num_samples)
        for output in outputs:
            self._visualize_to_tmp(vitdet, self.IMG, output)
            self._detection_detector_assert(output)
if __name__ == '__main__':
    # Script entry point: lets the file be run directly and hands control to
    # unittest's command-line runner, which discovers the TestCase classes
    # defined in this module.
    unittest.main()