# Mirror of https://github.com/alibaba/EasyCV.git (318 lines, 13 KiB, Python)
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from numpy.testing import assert_array_almost_equal
|
|
|
|
from easycv.predictors.detector import DetrPredictor
|
|
|
|
|
|
class DETRTest(unittest.TestCase):
    """Regression tests for ``DetrPredictor`` on several pretrained models.

    Every test builds a predictor from a checkpoint URL and a config path,
    runs it on the same demo image, and compares the first predictions with
    previously recorded reference values.  The tests need network access to
    fetch the checkpoints and the demo image.
    """

    # Demo image used as input by every test case.
    IMG_URL = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'

    def setUp(self):
        """Print the name of the test that is about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _predict(self, model_path, config_path):
        """Build a predictor, run it on the demo image and return the output.

        Args:
            model_path: Location (URL) of the checkpoint to load.
            config_path: Path of the model config file.

        Returns:
            Whatever ``DetrPredictor.predict`` returns for the demo image.
        """
        predictor = DetrPredictor(model_path, config_path)
        output = predictor.predict(self.IMG_URL)
        # Also exercise the visualization code path; nothing is written
        # because no output file is requested.
        predictor.visualize(self.IMG_URL, output, out_file=None)
        return output

    def _check_output(self, output, num_boxes, classes, scores, boxes):
        """Compare a prediction result with recorded reference values.

        Args:
            output: Result returned by ``DetrPredictor.predict``.
            num_boxes: Expected number of entries in the first element of
                ``detection_boxes``, ``detection_scores`` and
                ``detection_classes``.
            classes: Expected leading class ids (list of ints).
            scores: Expected leading scores, compared as float32 with
                ``decimal=2``.
            boxes: Expected leading boxes (four numbers each), compared
                with ``decimal=1``.
        """
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'img_metas'):
            self.assertIn(key, output)

        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes'):
            self.assertEqual(len(output[key][0]), num_boxes)

        # Only the leading predictions are pinned; their count follows the
        # length of the reference list (10 in every test of this class).
        top_k = len(classes)

        self.assertListEqual(
            output['detection_classes'][0][:top_k].tolist(),
            np.array(classes, dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0][:top_k],
            np.array(scores, dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0][:top_k], np.array(boxes), decimal=1)

    def test_detr(self):
        """Check DETR (R50, 150 epochs) predictions on the demo image."""
        output = self._predict(
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/epoch_150.pth',
            'configs/detection/detr/detr_r50_8x2_150e_coco.py')

        self._check_output(
            output,
            num_boxes=100,
            classes=[2, 0, 2, 2, 2, 2, 2, 2, 2, 2],
            scores=[
                0.07836595922708511, 0.219977006316185, 0.5831383466720581,
                0.4256463646888733, 0.9853266477584839, 0.24607707560062408,
                0.28005731105804443, 0.500579833984375, 0.09835881739854813,
                0.05178987979888916
            ],
            boxes=[
                [
                    131.10389709472656, 90.93302154541016,
                    148.95504760742188, 101.69216918945312
                ],
                [
                    239.10910034179688, 113.36551666259766,
                    256.0523376464844, 125.22894287109375
                ],
                [
                    132.1316375732422, 90.8366470336914,
                    151.00839233398438, 101.83119201660156
                ],
                [
                    579.37646484375, 108.26667785644531,
                    605.0717163085938, 124.79525756835938
                ],
                [
                    189.69073486328125, 108.04875946044922,
                    296.8011779785156, 154.44204711914062
                ],
                [
                    588.5413208007812, 107.89535522460938,
                    615.6463012695312, 124.41362762451172
                ],
                [
                    57.38536071777344, 89.7335433959961,
                    79.20274353027344, 102.61941528320312
                ],
                [
                    163.97628784179688, 92.95049285888672,
                    180.87033081054688, 102.6163330078125
                ],
                [
                    127.82454681396484, 90.27918243408203,
                    144.6781768798828, 99.71304321289062
                ],
                [
                    438.4545593261719, 103.00477600097656,
                    480.4275817871094, 121.69993591308594
                ],
            ])

    def test_dab_detr(self):
        """Check DAB-DETR (R50, 50 epochs) predictions on the demo image."""
        output = self._predict(
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth',
            'configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py')

        self._check_output(
            output,
            num_boxes=300,
            classes=[2, 2, 13, 2, 2, 2, 2, 2, 2, 2],
            scores=[
                0.7688284516334534, 0.7646799683570862, 0.7159939408302307,
                0.6902833580970764, 0.6633996367454529, 0.6523147821426392,
                0.633848249912262, 0.6229104995727539, 0.611840009689331,
                0.5631589293479919
            ],
            boxes=[
                [
                    294.2984313964844, 116.07160949707031,
                    380.4406433105469, 149.6365509033203
                ],
                [
                    480.0610656738281, 109.99347686767578,
                    523.2314453125, 130.26318359375
                ],
                [
                    220.32269287109375, 176.51010131835938,
                    456.51715087890625, 386.30767822265625
                ],
                [
                    167.6925506591797, 108.25935363769531,
                    214.93780517578125, 138.94424438476562
                ],
                [
                    398.1152648925781, 111.34457397460938,
                    433.72052001953125, 133.36280822753906
                ],
                [
                    430.48736572265625, 104.4018325805664,
                    484.1470947265625, 132.18893432617188
                ],
                [
                    607.396728515625, 111.72560119628906,
                    637.2987670898438, 136.2375946044922
                ],
                [
                    267.43353271484375, 105.93965911865234,
                    327.1937561035156, 130.18527221679688
                ],
                [
                    589.790771484375, 110.36975860595703,
                    618.8001098632812, 126.1950454711914
                ],
                [
                    0.3374290466308594, 110.91182708740234,
                    63.00359344482422, 146.0926971435547
                ],
            ])

    def test_dn_detr(self):
        """Check DN-DETR (R50, 50 epochs) predictions on the demo image."""
        output = self._predict(
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth',
            'configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py')

        self._check_output(
            output,
            num_boxes=300,
            classes=[2, 13, 2, 2, 2, 2, 2, 2, 2, 2],
            scores=[
                0.8800525665283203, 0.866659939289093, 0.8665854930877686,
                0.8030595183372498, 0.7642921209335327, 0.7375038862228394,
                0.7270554304122925, 0.6710091233253479, 0.6316548585891724,
                0.6164721846580505
            ],
            boxes=[
                [
                    294.9338073730469, 115.7542495727539, 377.5517578125,
                    150.59274291992188
                ],
                [
                    220.57424926757812, 175.97023010253906,
                    456.9001770019531, 383.2597351074219
                ],
                [
                    479.5928649902344, 109.94012451171875,
                    523.7343139648438, 130.80604553222656
                ],
                [
                    398.6956787109375, 111.45973205566406,
                    434.0437316894531, 134.1909637451172
                ],
                [
                    166.98208618164062, 109.44792938232422,
                    210.35342407226562, 139.9746856689453
                ],
                [
                    609.432373046875, 113.08062744140625,
                    635.9082641601562, 136.74383544921875
                ],
                [
                    268.0716552734375, 105.00788879394531,
                    327.4037170410156, 128.01449584960938
                ],
                [
                    190.77467346191406, 107.42850494384766,
                    298.35760498046875, 156.2850341796875
                ],
                [
                    591.0296020507812, 110.53913116455078,
                    620.702880859375, 127.42123413085938
                ],
                [
                    431.6607971191406, 105.04813385009766,
                    484.4869689941406, 132.45864868164062
                ],
            ])

    def test_dino(self):
        """Check DINO (4-scale R50, 36e config) predictions on the demo image."""
        output = self._predict(
            'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/epoch_29.pth',
            'configs/detection/dino/dino_4sc_r50_36e_coco.py')

        self._check_output(
            output,
            num_boxes=300,
            classes=[13, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            scores=[
                0.8808171153068542, 0.8584598898887634, 0.8214247226715088,
                0.8156911134719849, 0.7707086801528931, 0.6717984080314636,
                0.6578451991081238, 0.6269607543945312, 0.6063129901885986,
                0.5223093628883362
            ],
            boxes=[
                [
                    222.15492248535156, 175.9025421142578, 456.3177490234375,
                    382.48211669921875
                ],
                [
                    295.12115478515625, 115.97019958496094, 378.97119140625,
                    150.2149658203125
                ],
                [
                    190.94241333007812, 108.94568634033203,
                    298.280517578125, 155.6221160888672
                ],
                [
                    167.8346405029297, 109.49150085449219, 211.50537109375,
                    140.08895874023438
                ],
                [
                    482.0719909667969, 110.47320556640625,
                    523.1851806640625, 130.19410705566406
                ],
                [
                    609.3395385742188, 113.26068115234375,
                    635.8460083007812, 136.93771362304688
                ],
                [
                    266.5657958984375, 105.04171752929688,
                    326.9735107421875, 127.39012145996094
                ],
                [
                    431.43096923828125, 105.18028259277344,
                    484.13787841796875, 131.9821319580078
                ],
                [
                    60.43342971801758, 94.02497100830078, 86.346435546875,
                    106.31623840332031
                ],
                [
                    139.32015991210938, 96.0668716430664, 167.1505126953125,
                    105.44377899169922
                ],
            ])
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest runner, which
    # discovers and runs the test cases defined in this module.
    unittest.main()
|