# Copyright (c) Alibaba, Inc. and its affiliates.
"""
isort:skip_file
"""
import os
import unittest

import numpy as np
from numpy.testing import assert_array_almost_equal
from PIL import Image

from easycv.predictors.detector import TorchYoloXPredictor, TorchViTDetPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
                             PRETRAINED_MODEL_YOLOXS_EXPORT_JIT,
                             PRETRAINED_MODEL_YOLOXS_END2END_JIT,
                             DET_DATA_SMALL_COCO_LOCAL)


class DetectorTest(unittest.TestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
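
    # The three YOLOX tests below run the same COCO image
    # (val2017/000000037777.jpg) through different exports of one yolox_s
    # checkpoint, so they all assert the same expected classes, scores and
    # boxes.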
    def test_yolox_detector(self):
        """Raw (non-JIT) YOLOX-S export: check keys, classes, scores and boxes."""
        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
        img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
                           'val2017/000000037777.jpg')
        input_data_list = [np.asarray(Image.open(img))]

        predictor = TorchYoloXPredictor(
            model_path=detection_model_path, score_thresh=0.5)
        output = predictor.predict(input_data_list)[0]

        self.assertIn('detection_boxes', output)
        self.assertIn('detection_scores', output)
        self.assertIn('detection_classes', output)
        self.assertIn('detection_class_names', output)
        self.assertIn('ori_img_shape', output)

        self.assertEqual(len(output['detection_boxes']), 9)
        self.assertEqual(output['ori_img_shape'], [230, 352])

        self.assertListEqual(
            output['detection_classes'].tolist(),
            np.array([72, 69, 60, 56, 49, 49, 72, 46, 49],
                     dtype=np.int32).tolist())
        self.assertListEqual(output['detection_class_names'], [
            'refrigerator', 'oven', 'dining table', 'chair', 'orange',
            'orange', 'refrigerator', 'banana', 'orange'
        ])

        assert_array_almost_equal(
            output['detection_scores'],
            np.array([
                0.93252, 0.88439, 0.75048, 0.74093, 0.67255, 0.65550, 0.63942,
                0.60507, 0.56973
            ],
                     dtype=np.float32),
            decimal=2)
        assert_array_almost_equal(
            output['detection_boxes'],
            np.array([[298.28256, 76.26037, 352., 228.91579],
                      [137.9849, 124.92237, 196.49876, 193.12375],
                      [76.42237, 170.30052, 292.4093, 227.32962],
                      [117.317, 188.9916, 165.43694, 212.3457],
                      [231.36719, 199.89865, 248.27888, 217.50288],
                      [217.1154, 200.18729, 232.20607, 214.38866],
                      [121.948105, 90.01667, 193.17673, 194.04584],
                      [240.4494, 188.07112, 258.7406, 206.78226],
                      [204.21452, 187.11292, 220.3842, 207.25877]]),
            decimal=1)
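
    # Same image and threshold as above, but loading the end2end JIT export
    # (presumably with pre/post-processing traced into the graph); results
    # should match the raw export.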
    def test_yolox_detector_jit_end2end(self):
        """End2end JIT YOLOX-S export: outputs must match the raw export."""
        img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
                           'val2017/000000037777.jpg')
        input_data_list = [np.asarray(Image.open(img))]

        jit_path = PRETRAINED_MODEL_YOLOXS_END2END_JIT
        predictor_jit = TorchYoloXPredictor(
            model_path=jit_path, score_thresh=0.5)
        output_jit = predictor_jit.predict(input_data_list)[0]

        self.assertIn('detection_boxes', output_jit)
        self.assertIn('detection_scores', output_jit)
        self.assertIn('detection_classes', output_jit)
        self.assertIn('detection_class_names', output_jit)

        self.assertListEqual(
            output_jit['detection_classes'].tolist(),
            np.array([72, 69, 60, 56, 49, 49, 72, 46, 49],
                     dtype=np.int32).tolist())
        self.assertListEqual(output_jit['detection_class_names'], [
            'refrigerator', 'oven', 'dining table', 'chair', 'orange',
            'orange', 'refrigerator', 'banana', 'orange'
        ])

        assert_array_almost_equal(
            output_jit['detection_scores'],
            np.array([
                0.93252, 0.88439, 0.75048, 0.74093, 0.67255, 0.65550, 0.63942,
                0.60507, 0.56973
            ],
                     dtype=np.float32),
            decimal=2)
        assert_array_almost_equal(
            output_jit['detection_boxes'],
            np.array([[298.28256, 76.26037, 352., 228.91579],
                      [137.9849, 124.92237, 196.49876, 193.12375],
                      [76.42237, 170.30052, 292.4093, 227.32962],
                      [117.317, 188.9916, 165.43694, 212.3457],
                      [231.36719, 199.89865, 248.27888, 217.50288],
                      [217.1154, 200.18729, 232.20607, 214.38866],
                      [121.948105, 90.01667, 193.17673, 194.04584],
                      [240.4494, 188.07112, 258.7406, 206.78226],
                      [204.21452, 187.11292, 220.3842, 207.25877]]),
            decimal=1)
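
    # Plain (non-end2end) JIT export of the same checkpoint; outputs should
    # again be identical within the asserted decimals.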
    def test_yolox_detector_jit(self):
        """Plain JIT YOLOX-S export: outputs must match the raw export."""
        img = os.path.join(DET_DATA_SMALL_COCO_LOCAL,
                           'val2017/000000037777.jpg')
        input_data_list = [np.asarray(Image.open(img))]

        jit_path = PRETRAINED_MODEL_YOLOXS_EXPORT_JIT
        predictor_jit = TorchYoloXPredictor(
            model_path=jit_path, score_thresh=0.5)
        output_jit = predictor_jit.predict(input_data_list)[0]

        self.assertIn('detection_boxes', output_jit)
        self.assertIn('detection_scores', output_jit)
        self.assertIn('detection_classes', output_jit)
        self.assertIn('detection_class_names', output_jit)

        self.assertListEqual(
            output_jit['detection_classes'].tolist(),
            np.array([72, 69, 60, 56, 49, 49, 72, 46, 49],
                     dtype=np.int32).tolist())
        self.assertListEqual(output_jit['detection_class_names'], [
            'refrigerator', 'oven', 'dining table', 'chair', 'orange',
            'orange', 'refrigerator', 'banana', 'orange'
        ])

        assert_array_almost_equal(
            output_jit['detection_scores'],
            np.array([
                0.93252, 0.88439, 0.75048, 0.74093, 0.67255, 0.65550, 0.63942,
                0.60507, 0.56973
            ],
                     dtype=np.float32),
            decimal=2)
        assert_array_almost_equal(
            output_jit['detection_boxes'],
            np.array([[298.28256, 76.26037, 352., 228.91579],
                      [137.9849, 124.92237, 196.49876, 193.12375],
                      [76.42237, 170.30052, 292.4093, 227.32962],
                      [117.317, 188.9916, 165.43694, 212.3457],
                      [231.36719, 199.89865, 248.27888, 217.50288],
                      [217.1154, 200.18729, 232.20607, 214.38866],
                      [121.948105, 90.01667, 193.17673, 194.04584],
                      [240.4494, 188.07112, 258.7406, 206.78226],
                      [204.21452, 187.11292, 220.3842, 207.25877]]),
            decimal=1)
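
    # Note: this test fetches both the exported ViTDet Mask R-CNN model and
    # the demo image from OSS at runtime, so it requires network access.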
    def test_vitdet_detector(self):
        """ViTDet Mask R-CNN export: check keys, classes, scores and boxes."""
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn_export.pth'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        out_file = './result.jpg'

        vitdet = TorchViTDetPredictor(model_path)
        output = vitdet.predict(img)

        self.assertIn('detection_boxes', output)
        self.assertIn('detection_scores', output)
        self.assertIn('detection_classes', output)
        self.assertIn('detection_masks', output)
        self.assertIn('img_metas', output)

        self.assertEqual(len(output['detection_boxes'][0]), 30)
        self.assertEqual(len(output['detection_scores'][0]), 30)
        self.assertEqual(len(output['detection_classes'][0]), 30)

        self.assertListEqual(
            output['detection_classes'][0].tolist(),
            np.array([
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
                2, 2, 2, 7, 7, 13, 13, 13, 56
            ],
                     dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0],
            np.array([
                0.99791867, 0.99665856, 0.99480623, 0.99060905, 0.9882515,
                0.98319584, 0.9738879, 0.97290784, 0.9514897, 0.95104814,
                0.9321701, 0.86165, 0.8228847, 0.7623552, 0.76129806,
                0.6050861, 0.44348577, 0.3452973, 0.2895671, 0.22109479,
                0.21265312, 0.17855245, 0.1205352, 0.08981906, 0.10596471,
                0.05854294, 0.99749386, 0.9472857, 0.5945908, 0.09855112
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0],
            np.array([[294.7058, 117.29371, 378.83713, 149.99928],
                      [609.05444, 112.526474, 633.2971, 136.35175],
                      [481.4165, 110.987335, 522.5531, 130.01529],
                      [167.68184, 109.89049, 215.49057, 139.86987],
                      [374.75082, 110.68697, 433.10028, 136.23654],
                      [189.54971, 110.09322, 297.6167, 155.77412],
                      [266.5185, 105.37718, 326.54385, 127.916374],
                      [556.30225, 110.43166, 592.8248, 128.03764],
                      [432.49252, 105.086464, 484.0512, 132.272],
                      [0., 110.566444, 62.01249, 146.44017],
                      [591.74664, 110.43527, 619.73816, 126.68549],
                      [99.126854, 90.947975, 118.46699, 101.11096],
                      [59.895264, 94.110054, 85.60521, 106.67633],
                      [142.95819, 96.61966, 165.96964, 104.95929],
                      [83.062515, 89.802605, 99.1546, 98.69074],
                      [226.28802, 98.32568, 249.06772, 108.86408],
                      [136.67789, 94.75706, 154.62924, 104.289536],
                      [170.42459, 98.458694, 183.16309, 106.203156],
                      [67.56731, 89.68286, 82.62955, 98.35645],
                      [222.80092, 97.828445, 239.02655, 108.29377],
                      [134.34427, 92.31653, 149.19615, 102.97457],
                      [613.5186, 102.27066, 636.0434, 112.813644],
                      [607.4787, 110.87984, 630.1123, 127.65646],
                      [135.13664, 90.989876, 155.67192, 100.18036],
                      [431.61505, 105.43844, 484.36508, 132.50078],
                      [189.92722, 110.38832, 297.74353, 155.95557],
                      [220.67035, 177.13489, 455.32092, 380.45712],
                      [372.76584, 134.33807, 432.44357, 188.51534],
                      [50.403812, 110.543495, 70.4368, 119.65186],
                      [373.50272, 134.27258, 432.18475, 187.81824]]),
            decimal=1)

        # Write a visualization of the predictions to out_file.
        vitdet.show_result_pyplot(img, output, out_file=out_file)


if __name__ == '__main__':
    unittest.main()