# Copyright (c) Alibaba, Inc. and its affiliates.
"""
isort:skip_file
"""
import os
import unittest
import tempfile

import numpy as np
from PIL import Image

from easycv.predictors.detector import DetectionPredictor, YoloXPredictor, TorchYoloXPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
                             PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
                             PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT,
                             DET_DATA_SMALL_COCO_LOCAL)
from numpy.testing import assert_array_almost_equal


class YoloXPredictorTest(unittest.TestCase):
    """Regression tests for the YOLOX predictors on one fixed COCO image."""

    # The image whose expected detections are pinned in `_assert_results`.
    img = os.path.join(DET_DATA_SMALL_COCO_LOCAL, 'val2017/000000522713.jpg')

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _assert_results(self, results):
        """Check a single predictor output dict against pinned expectations.

        Args:
            results: one element of a predictor's output list; must provide
                'ori_img_shape', 'detection_classes', 'detection_class_names',
                'detection_scores' and 'detection_boxes'.
        """
        self.assertEqual(results['ori_img_shape'], [480, 640])
        self.assertListEqual(
            results['detection_classes'].tolist(),
            np.array([13, 8, 8, 8], dtype=np.int32).tolist())
        self.assertListEqual(results['detection_class_names'],
                             ['bench', 'boat', 'boat', 'boat'])
        # Loose tolerances: scores to 2 decimals, boxes to 1 decimal.
        assert_array_almost_equal(
            results['detection_scores'],
            np.array([0.92335737, 0.59416807, 0.5567955, 0.55368793],
                     dtype=np.float32),
            decimal=2)
        assert_array_almost_equal(
            results['detection_boxes'],
            np.array([[408.1708, 285.11456, 561.84924, 356.42285],
                      [438.88098, 264.46606, 467.07275, 271.76355],
                      [510.19467, 268.46664, 528.26935, 273.37192],
                      [480.9472, 269.74115, 502.00842, 274.85553]]),
            decimal=1)

    def _base_test_single(self, model_path, inputs):
        """Run YoloXPredictor on `inputs` and expect exactly one result."""
        predictor = YoloXPredictor(model_path=model_path, score_thresh=0.5)
        outputs = predictor(inputs)
        self.assertEqual(len(outputs), 1)
        self._assert_results(outputs[0])

    def _base_test_batch(self, model_path, inputs, num_samples, batch_size):
        """Replicate a one-element `inputs` list and run it in batches.

        Args:
            model_path: model file handed to YoloXPredictor.
            inputs: list holding exactly one input (path or array).
            num_samples: how many copies of that input to predict on.
            batch_size: batch size passed to the predictor.
        """
        self.assertIsInstance(inputs, list)
        self.assertEqual(len(inputs), 1)
        predictor = YoloXPredictor(
            model_path=model_path, score_thresh=0.5, batch_size=batch_size)
        outputs = predictor(inputs * num_samples)
        self.assertEqual(len(outputs), num_samples)
        for output in outputs:
            self._assert_results(output)

    def test_single_raw(self):
        model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        inputs = [np.asarray(Image.open(self.img))]
        self._base_test_single(model_path, inputs)

    def test_batch_raw(self):
        model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        inputs = [np.asarray(Image.open(self.img))]
        self._base_test_batch(model_path, inputs, num_samples=3, batch_size=2)

    def test_single_jit_nopre_notrt(self):
        jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
        # NOTE(review): unlike its siblings this passes a bare path rather
        # than a list, presumably to exercise single-input handling; confirm.
        self._base_test_single(jit_path, self.img)

    def test_batch_jit_nopre_notrt(self):
        jit_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT
        self._base_test_batch(
            jit_path, [self.img], num_samples=2, batch_size=1)

    def test_single_jit_pre_trt(self):
        # NOTE(review): despite the `_trt` suffix in the test name, this uses
        # the PRE_NOTRT JIT export. The name is kept so test IDs stay stable.
        jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
        self._base_test_single(jit_path, [self.img])

    def test_batch_jit_pre_trt(self):
        # NOTE(review): uses the PRE_NOTRT JIT export; see test_single_jit_pre_trt.
        jit_path = PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT
        self._base_test_batch(
            jit_path, [self.img], num_samples=4, batch_size=2)

    def test_single_raw_TorchYoloXPredictor(self):
        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        input_data_list = [np.asarray(Image.open(self.img))]
        predictor = TorchYoloXPredictor(
            model_path=detection_model_path, score_thresh=0.5)
        output = predictor(input_data_list)[0]
        self._assert_results(output)


class DetectionPredictorTest(unittest.TestCase):
    """Regression tests for DetectionPredictor with a remote ViTDet export."""

    # Remote model and demo image shared by every test in this class.
    model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
    img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _visualize(self, predictor, output):
        """Render `output` on the demo image into a throwaway .jpg file."""
        # Write into a fresh temporary directory instead of the path of an
        # open NamedTemporaryFile: reopening such a file by name while it is
        # still open is not supported on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_save_path = os.path.join(tmp_dir, 'vis.jpg')
            predictor.visualize(self.img, output, out_file=tmp_save_path)

    def _detection_detector_assert(self, output):
        """Check a single output dict against the pinned ViTDet results.

        Verifies the expected keys are present, that exactly 33 detections
        are returned, and that classes, scores (2 decimals) and boxes
        (1 decimal) match the recorded values.
        """
        self.assertIn('detection_boxes', output)
        self.assertIn('detection_scores', output)
        self.assertIn('detection_classes', output)
        self.assertIn('detection_masks', output)
        self.assertIn('img_metas', output)
        self.assertEqual(len(output['detection_boxes']), 33)
        self.assertEqual(len(output['detection_scores']), 33)
        self.assertEqual(len(output['detection_classes']), 33)

        # 27 detections of class 2, then 2 of class 7, 3 of class 13, 1 of 56.
        self.assertListEqual(
            output['detection_classes'].tolist(),
            np.array([2] * 27 + [7, 7, 13, 13, 13, 56],
                     dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'],
            np.array([
                0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
                0.9833580851554871, 0.983080267906189, 0.970454752445221,
                0.9701289534568787, 0.9649872183799744, 0.9642795324325562,
                0.9642238020896912, 0.9529680609703064, 0.9403366446495056,
                0.9391788244247437, 0.8941807150840759, 0.8178097009658813,
                0.8013413548469543, 0.6677654385566711, 0.3952914774417877,
                0.33463895320892334, 0.32501447200775146, 0.27323535084724426,
                0.20197080075740814, 0.15607696771621704, 0.1068163588643074,
                0.10183875262737274, 0.09735643863677979, 0.06559795141220093,
                0.08890066295862198, 0.076363705098629, 0.9954648613929749,
                0.9212945699691772, 0.5224372148513794, 0.20555885136127472
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'],
            np.array([
                [294.22674560546875, 116.6078109741211, 379.4328918457031, 150.14097595214844],
                [482.6017761230469, 110.75955963134766, 522.8798828125, 129.71286010742188],
                [167.06460571289062, 109.95974731445312, 212.83975219726562, 140.16102600097656],
                [609.2930908203125, 113.13909149169922, 637.3115844726562, 136.4690704345703],
                [191.185791015625, 111.1408920288086, 301.31689453125, 155.7731170654297],
                [431.2244873046875, 106.19962310791016, 483.860595703125, 132.21627807617188],
                [267.48358154296875, 105.5920639038086, 325.2832336425781, 127.11176300048828],
                [591.2138671875, 110.29329681396484, 619.8524169921875, 126.1990966796875],
                [0.0, 110.7026596069336, 61.487945556640625, 146.33018493652344],
                [555.9155883789062, 110.03486633300781, 591.7050170898438, 127.06097412109375],
                [60.24559783935547, 94.12760162353516, 85.63741302490234, 106.66705322265625],
                [99.02665710449219, 90.53657531738281, 118.83953094482422, 101.18717956542969],
                [396.30438232421875, 111.59194946289062, 431.559814453125, 133.96914672851562],
                [83.81543731689453, 89.65665435791016, 99.9166259765625, 98.25627899169922],
                [139.29647827148438, 96.68000793457031, 165.22410583496094, 105.60000610351562],
                [67.27152252197266, 89.42798614501953, 83.25617980957031, 98.0460205078125],
                [223.74176025390625, 98.68321990966797, 250.42506408691406, 109.32588958740234],
                [136.7582244873047, 96.51412963867188, 152.51190185546875, 104.73160552978516],
                [221.71812438964844, 97.86445617675781, 238.9705810546875, 106.96803283691406],
                [135.06964111328125, 91.80916595458984, 155.24609375, 102.20686340332031],
                [169.11180114746094, 97.53628540039062, 182.88504028320312, 105.95404815673828],
                [133.8811798095703, 91.00375366210938, 145.35507202148438, 102.3780288696289],
                [614.2507934570312, 102.19828796386719, 636.5692749023438, 112.59198760986328],
                [35.94759750366211, 91.7213363647461, 70.38274383544922, 117.19855499267578],
                [554.6401977539062, 115.18976593017578, 562.0255737304688, 127.4429931640625],
                [39.07550811767578, 92.73261260986328, 85.36636352539062, 106.73953247070312],
                [200.85513305664062, 93.00469970703125, 219.73086547851562, 107.99642181396484],
                [0.0, 111.18904876708984, 61.7393684387207, 146.72547912597656],
                [191.88568115234375, 111.09577178955078, 299.4097900390625, 155.14639282226562],
                [221.06834411621094, 176.6427001953125, 458.3475341796875, 378.89300537109375],
                [372.7131652832031, 135.51429748535156, 433.2494201660156, 188.0106658935547],
                [52.19819641113281, 110.3646011352539, 70.95110321044922, 120.10567474365234],
                [376.1671447753906, 133.6930694580078, 432.2721862792969, 187.99481201171875],
            ]),
            decimal=1)

    def test_detection_detector_single(self):
        vitdet = DetectionPredictor(self.model_path, score_threshold=0.0)
        output = vitdet(self.img)[0]
        self._visualize(vitdet, output)
        self._detection_detector_assert(output)

    def test_detection_detector_batch(self):
        vitdet = DetectionPredictor(
            self.model_path, score_threshold=0.0, batch_size=2)
        num_samples = 3
        outputs = vitdet([self.img] * num_samples)
        self.assertEqual(len(outputs), num_samples)
        for output in outputs:
            self._visualize(vitdet, output)
            self._detection_detector_assert(output)


if __name__ == '__main__':
    unittest.main()