# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import copy
import unittest

import numpy as np
import torch
from mmcv.parallel import collate, scatter
from numpy.testing import assert_array_almost_equal
from torchvision.transforms import Compose

from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor
from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.registry import build_from_cfg
|
|
|
|
|
|
|
|
|
|
|
|
class DETRTest(unittest.TestCase):
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
|
|
|
|
|
|
|
def init_detr(self, model_path, config_path):
    """Build a detector from ``config_path`` and load ``model_path`` weights.

    Populates ``self.model_path``, ``self.cfg``, ``self.device``,
    ``self.model`` (moved to ``self.device`` and switched to eval mode),
    ``self.ckpt`` and ``self.CLASSES``.

    Args:
        model_path (str): checkpoint path or URL given to ``load_checkpoint``.
        config_path (str): config file read with ``mmcv_config_fromfile``.
    """
    self.model_path = model_path
    self.cfg = mmcv_config_fromfile(config_path)

    # Heads whose config carries a truthy ``num_select`` get it overridden
    # to 10 (presumably the number of returned detections -- the DAB/DN
    # tests expect 10 outputs). Heads without the key are left untouched.
    head_cfg = self.cfg.model.head
    if head_cfg.get('num_select', None):
        head_cfg.num_select = 10

    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.model = build_model(self.cfg.model)

    # Map checkpoint tensors straight onto the device the model will use.
    self.ckpt = load_checkpoint(
        self.model, self.model_path, map_location=self.device)
    self.model.to(self.device)
    self.model.eval()

    self.CLASSES = self.cfg.CLASSES
|
|
|
|
|
|
|
|
def predict(self, imgs):
    """Run the loaded detector on one or more images.

    Uses ``self.model`` and ``self.cfg`` (both set by ``init_detr``). The
    validation pipeline ``self.cfg.data.val.pipeline`` is prefixed with an
    image loader chosen from the type of the first input. ``self.cfg`` itself
    is never modified, so repeated calls build identical pipelines.

    Args:
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Image file paths/URLs or already-loaded image arrays. A single
            item is wrapped into a one-element batch.

    Returns:
        The output of ``self.model(mode='test', ...)`` for the whole batch,
        returned unchanged whether or not ``imgs`` was a list/tuple.

    Raises:
        ValueError: if ``imgs`` is an empty list or tuple.
        TypeError: if a pipeline entry is neither a dict nor callable.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
    if not imgs:
        raise ValueError('imgs must contain at least one image')

    device = next(self.model.parameters()).device  # model device

    # Build a private copy of the pipeline config. A shallow ``cfg.copy()``
    # would share the pipeline list with ``self.cfg``; inserting the loader
    # into it would then prepend one more loader on every call.
    if isinstance(imgs[0], np.ndarray):
        loader_type = 'LoadImageFromWebcam'
    else:
        loader_type = 'LoadImageFromFile'
    pipeline_cfg = [
        dict(type=loader_type, file_client_args=dict(backend='http'))
    ]
    pipeline_cfg.extend(copy.deepcopy(list(self.cfg.data.val.pipeline)))
    pipeline_cfg = replace_ImageToTensor(pipeline_cfg)

    transforms = []
    for transform in pipeline_cfg:
        if isinstance(transform, dict):
            # Convert ``img_scale`` to a tuple; presumably config values may
            # arrive as lists while the transforms expect tuples.
            if 'img_scale' in transform:
                transform['img_scale'] = tuple(transform['img_scale'])
            transforms.append(build_from_cfg(transform, PIPELINES))
        elif callable(transform):
            transforms.append(transform)
        else:
            raise TypeError('transform must be callable or a dict')
    test_pipeline = Compose(transforms)

    datas = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            # already-decoded image: hand it to the pipeline directly
            data = dict(img=img)
        else:
            # path/URL: let the loader transform read it
            data = dict(img_info=dict(filename=img), img_prefix=None)
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [
        img_metas.data[0] for img_metas in data['img_metas']
    ]
    data['img'] = [img.data[0] for img in data['img']]
    if device.type == 'cuda':
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    with torch.no_grad():
        results = self.model(mode='test', **data)
    return results
|
|
|
|
|
|
|
|
def test_detr(self):
    """Compare DETR (R50) demo-image predictions with recorded references."""
    model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/epoch_150.pth'
    config_path = 'configs/detection/detr/detr_r50_8x2_150e_coco.py'
    img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
    self.init_detr(model_path, config_path)
    output = self.predict(img)

    for key in ('detection_boxes', 'detection_scores', 'detection_classes',
                'img_metas'):
        self.assertIn(key, output)
    # 100 results per image: presumably this head has no ``num_select`` to
    # override, so all queries are returned -- confirm against the config.
    for key in ('detection_boxes', 'detection_scores', 'detection_classes'):
        self.assertEqual(len(output[key][0]), 100)

    expected_classes = [
        2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 7, 2, 13, 2, 0, 13, 2,
        0, 2, 56, 2, 7, 2, 2, 2, 2, 2, 2, 2, 7, 56, 2, 2, 7, 7, 2, 7,
        2, 2, 56, 2, 7, 11, 2, 2, 2, 0, 7, 2, 2, 2, 2, 2, 7, 2, 2, 7,
        2, 2, 2, 2, 13, 2, 2, 2, 13, 2, 2, 56, 2, 56, 2, 7, 56, 13, 7,
        56, 2, 0, 2, 7, 2, 7, 2, 56, 2, 2, 2, 7, 56, 2, 2, 7, 2, 0, 2,
        2
    ]
    self.assertListEqual(output['detection_classes'][0].tolist(),
                         expected_classes)

    expected_scores = np.array([
        0.07836595922708511, 0.219977006316185, 0.5831383466720581,
        0.4256463646888733, 0.9853266477584839, 0.24607707560062408,
        0.28005731105804443, 0.500579833984375, 0.09835881739854813,
        0.05178987979888916, 0.07282666862010956, 0.9802166819572449,
        0.04826607555150986, 0.06967002153396606, 0.9541336894035339,
        0.36800140142440796, 0.2821184992790222, 0.009824174456298351,
        0.981455385684967, 0.9890823364257812, 0.11702633649110794,
        0.3397829532623291, 0.03982163220643997, 0.06306332349777222,
        0.07951728254556656, 0.949343204498291, 0.1537322700023651,
        0.3483341634273529, 0.044335901737213135, 0.03239326551556587,
        0.11274639517068863, 0.462695449590683, 0.03906852751970291,
        0.006577627267688513, 0.651928722858429, 0.13711832463741302,
        0.15317879617214203, 0.5399832129478455, 0.08868053555488586,
        0.026992695406079292, 0.0887782946228981, 0.081451416015625,
        0.5485899448394775, 0.1959853619337082, 0.20348815619945526,
        0.1804366111755371, 0.04546552523970604, 0.4005874693393707,
        0.4241448938846588, 0.20359477400779724, 0.18858052790164948,
        0.5971255898475647, 0.6823391914367676, 0.09363959729671478,
        0.9855959415435791, 0.5903261303901672, 0.0731084868311882,
        0.9813686609268188, 0.9814890027046204, 0.11285952478647232,
        0.46758928894996643, 0.5004158616065979, 0.5852540731430054,
        0.1944422572851181, 0.04896926134824753, 0.17205820977687836,
        0.188123881816864, 0.43242165446281433, 0.3784835636615753,
        0.06754120439291, 0.8264386057853699, 0.054902296513319016,
        0.05457871034741402, 0.05988362058997154, 0.054624997079372406,
        0.37744957208633423, 0.08150151371955872, 0.015097505412995815,
        0.1074686348438263, 0.004187499638646841, 0.9733405709266663,
        0.15225540101528168, 0.711842954158783, 0.06490222364664078,
        0.9512462615966797, 0.03674759343266487, 0.09688679873943329,
        0.02119528315961361, 0.9736435413360596, 0.9338251948356628,
        0.09611554443836212, 0.09142979979515076, 0.01647237129509449,
        0.9805111289024353, 0.3779929280281067, 0.09553579986095428,
        0.11411411315202713, 0.0063759335316717625, 0.2972108721733093,
        0.07761078327894211
    ],
                               dtype=np.float32)
    assert_array_almost_equal(
        output['detection_scores'][0], expected_scores, decimal=2)

    # One row per detection, four coordinates each.
    expected_boxes = np.array([
        [131.10389709472656, 90.93302154541016, 148.95504760742188, 101.69216918945312],
        [239.10910034179688, 113.36551666259766, 256.0523376464844, 125.22894287109375],
        [132.1316375732422, 90.8366470336914, 151.00839233398438, 101.83119201660156],
        [579.37646484375, 108.26667785644531, 605.0717163085938, 124.79525756835938],
        [189.69073486328125, 108.04875946044922, 296.8011779785156, 154.44204711914062],
        [588.5413208007812, 107.89535522460938, 615.6463012695312, 124.41362762451172],
        [57.38536071777344, 89.7335433959961, 79.20274353027344, 102.61941528320312],
        [163.97628784179688, 92.95049285888672, 180.87033081054688, 102.6163330078125],
        [127.82454681396484, 90.27918243408203, 144.6781768798828, 99.71304321289062],
        [438.4545593261719, 103.00477600097656, 480.4275817871094, 121.69993591308594],
        [23.33353042602539, 113.83180236816406, 58.04608154296875, 141.02174377441406],
        [165.77127075195312, 108.23619842529297, 208.61306762695312, 136.4344482421875],
        [402.93792724609375, 110.03141784667969, 437.510009765625, 132.2660369873047],
        [571.1912841796875, 108.44816589355469, 596.1377563476562, 125.31720733642578],
        [564.8051147460938, 108.13015747070312, 593.9150390625, 126.26905822753906],
        [79.12519836425781, 105.58458709716797, 111.23175048828125, 119.07565307617188],
        [549.2890625, 110.77001190185547, 563.7803955078125, 122.94622039794922],
        [384.75567626953125, 118.3704605102539, 422.6398010253906, 138.06492614746094],
        [218.92532348632812, 177.14031982421875, 459.10760498046875, 381.1133728027344],
        [397.3675231933594, 110.41165161132812, 436.5208740234375, 133.16848754882812],
        [239.1329803466797, 114.10742950439453, 255.9464874267578, 126.29271697998047],
        [214.19888305664062, 95.95294952392578, 233.9381103515625, 104.9795913696289],
        [395.9658508300781, 148.40223693847656, 434.2126770019531, 182.97567749023438],
        [269.0050964355469, 104.73245239257812, 320.3499755859375, 123.80233001708984],
        [483.2522888183594, 107.74896240234375, 522.5078125, 128.4845428466797],
        [57.623504638671875, 90.50349426269531, 82.20435333251953, 103.5736312866211],
        [375.03070068359375, 117.73094940185547, 385.3664855957031, 132.47479248046875],
        [555.9656372070312, 102.33853912353516, 568.1951293945312, 113.82102966308594],
        [388.681640625, 107.18630981445312, 415.21905517578125, 121.56800842285156],
        [510.5431823730469, 107.02037811279297, 533.8384399414062, 122.61180114746094],
        [187.148193359375, 100.21558380126953, 253.47540283203125, 123.25538635253906],
        [552.3801879882812, 103.33021545410156, 564.61865234375, 115.99454498291016],
        [425.8926086425781, 104.35319519042969, 477.8686218261719, 130.82357788085938],
        [222.24378967285156, 176.4434051513672, 456.42266845703125, 312.5479431152344],
        [227.29019165039062, 98.5999755859375, 250.33477783203125, 107.1373291015625],
        [165.81600952148438, 107.86138916015625, 202.11196899414062, 134.08160400390625],
        [175.83389282226562, 89.20259857177734, 224.58187866210938, 105.41484832763672],
        [186.885986328125, 107.32003021240234, 300.068115234375, 152.51370239257812],
        [165.5398712158203, 107.98709106445312, 202.61941528320312, 134.54295349121094],
        [611.3699951171875, 110.01651000976562, 639.626953125, 133.97329711914062],
        [550.5084838867188, 104.33821105957031, 562.5010986328125, 115.29130554199219],
        [59.68817901611328, 97.86705017089844, 76.69844055175781, 110.52032470703125],
        [372.9800720214844, 135.38967895507812, 435.77044677734375, 187.3105010986328],
        [233.5430908203125, 99.0816421508789, 255.1123809814453, 109.31993103027344],
        [57.52912521362305, 90.85675048828125, 81.06048583984375, 104.28386688232422],
        [566.7468872070312, 81.79965209960938, 581.5723266601562, 92.56966400146484],
        [166.86119079589844, 108.41394805908203, 205.91815185546875, 135.9706268310547],
        [87.78324890136719, 89.354736328125, 108.06605529785156, 99.89331817626953],
        [-0.0010925531387329102, 111.63217163085938, 13.02929401397705, 123.92339324951172],
        [235.19432067871094, 114.09554290771484, 251.94717407226562, 126.1430892944336],
        [268.6304626464844, 104.09455108642578, 328.7283935546875, 124.19095611572266],
        [599.0895385742188, 105.6767807006836, 627.2431640625, 121.63172912597656],
        [80.78914642333984, 88.88621520996094, 103.19034576416016, 99.85254669189453],
        [620.0072021484375, 109.53975677490234, 640.0657958984375, 133.46539306640625],
        [611.6638793945312, 109.55789947509766, 640.0369873046875, 135.73045349121094],
        [220.8399200439453, 96.41732025146484, 244.06399536132812, 105.75860595703125],
        [434.6024169921875, 105.29331970214844, 482.67218017578125, 130.61903381347656],
        [482.16302490234375, 108.22539520263672, 523.8209228515625, 128.8394317626953],
        [294.1483154296875, 114.8856201171875, 377.6090087890625, 148.9021453857422],
        [197.5174560546875, 91.83529663085938, 224.16571044921875, 104.21776580810547],
        [167.1007080078125, 94.26935577392578, 185.2867431640625, 103.67475128173828],
        [77.7591552734375, 88.3407974243164, 98.73750305175781, 98.35700988769531],
        [374.9325866699219, 117.9875259399414, 385.34991455078125, 132.2331085205078],
        [167.0509490966797, 108.91958618164062, 205.8968505859375, 136.49874877929688],
        [51.54475784301758, 104.56134796142578, 73.30614471435547, 120.99707794189453],
        [274.94195556640625, 101.94827270507812, 323.97650146484375, 115.7809829711914],
        [236.10757446289062, 97.75923919677734, 254.6915283203125, 107.15653228759766],
        [609.9969482421875, 100.36739349365234, 638.1365966796875, 115.2613754272461],
        [74.30718994140625, 103.82154083251953, 108.9682388305664, 118.71984100341797],
        [367.69842529296875, 118.2647933959961, 380.7364501953125, 132.5258331298828],
        [97.63047790527344, 89.68116760253906, 117.95793151855469, 100.69364166259766],
        [373.5556640625, 135.07652282714844, 434.33709716796875, 185.4060516357422],
        [54.454566955566406, 104.98534393310547, 71.63267517089844, 119.91905212402344],
        [375.6446838378906, 135.28944396972656, 435.38543701171875, 185.16932678222656],
        [614.5422973632812, 107.62055969238281, 640.0924072265625, 132.63742065429688],
        [433.181396484375, 103.48396301269531, 484.539794921875, 131.49851989746094],
        [375.8929443359375, 142.8272247314453, 434.06439208984375, 186.07620239257812],
        [331.8176574707031, 131.65638732910156, 437.4866027832031, 185.01356506347656],
        [0.02536296844482422, 109.71121978759766, 59.71095657348633, 143.3699951171875],
        [225.39547729492188, 179.25990295410156, 453.9075012207031, 378.88177490234375],
        [-0.048813819885253906, 109.48809051513672, 61.9161376953125, 143.69325256347656],
        [230.8810272216797, 113.71331787109375, 246.87635803222656, 125.97274780273438],
        [90.28150939941406, 88.90176391601562, 112.26461791992188, 99.73652648925781],
        [398.63372802734375, 109.19901275634766, 437.08404541015625, 131.7988739013672],
        [591.0872802734375, 108.76611328125, 618.3915405273438, 124.87841796875],
        [294.1384582519531, 114.44403076171875, 380.4442138671875, 149.1887664794922],
        [203.1075439453125, 109.37834167480469, 235.15982055664062, 125.3521957397461],
        [223.9816436767578, 177.92266845703125, 456.14263916015625, 357.9170837402344],
        [267.4551086425781, 105.07503509521484, 325.762939453125, 128.3075408935547],
        [135.58905029296875, 91.94464111328125, 159.6639404296875, 103.34713745117188],
        [540.2098388671875, 103.45868682861328, 555.146240234375, 116.43596649169922],
        [293.979736328125, 114.60274505615234, 380.0762939453125, 149.8252716064453],
        [220.00042724609375, 175.55282592773438, 452.283447265625, 327.056640625],
        [433.3155822753906, 103.43656158447266, 485.6103515625, 130.87503051757812],
        [553.2203369140625, 101.77803802490234, 564.004150390625, 112.206298828125],
        [567.9517822265625, 107.1722640991211, 595.299560546875, 124.94164276123047],
        [555.1934814453125, 109.10118103027344, 572.1039428710938, 122.53047180175781],
        [77.2689208984375, 90.0588607788086, 501.6678466796875, 346.69378662109375],
        [552.4683227539062, 111.0732650756836, 567.466064453125, 123.53128814697266],
        [79.25263977050781, 89.3648452758789, 111.41500854492188, 101.7647476196289],
    ])
    assert_array_almost_equal(
        output['detection_boxes'][0], expected_boxes, decimal=1)
|
|
|
|
|
|
|
|
def test_dab_detr(self):
    """Compare DAB-DETR (R50) demo-image predictions with recorded references.

    ``init_detr`` overrides the head's ``num_select`` to 10, and 10 results
    per output are expected below.
    """
    model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth'
    config_path = 'configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py'
    img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
    self.init_detr(model_path, config_path)
    output = self.predict(img)

    for key in ('detection_boxes', 'detection_scores', 'detection_classes',
                'img_metas'):
        self.assertIn(key, output)
    for key in ('detection_boxes', 'detection_scores', 'detection_classes'):
        self.assertEqual(len(output[key][0]), 10)

    expected_classes = [2, 2, 13, 2, 2, 2, 2, 2, 2, 2]
    self.assertListEqual(output['detection_classes'][0].tolist(),
                         expected_classes)

    expected_scores = np.array([
        0.7688284516334534, 0.7646799683570862, 0.7159939408302307,
        0.6902833580970764, 0.6633996367454529, 0.6523147821426392,
        0.633848249912262, 0.6229104995727539, 0.611840009689331,
        0.5631589293479919
    ],
                               dtype=np.float32)
    assert_array_almost_equal(
        output['detection_scores'][0], expected_scores, decimal=2)

    expected_boxes = np.array([
        [294.2984313964844, 116.07160949707031, 380.4406433105469, 149.6365509033203],
        [480.0610656738281, 109.99347686767578, 523.2314453125, 130.26318359375],
        [220.32269287109375, 176.51010131835938, 456.51715087890625, 386.30767822265625],
        [167.6925506591797, 108.25935363769531, 214.93780517578125, 138.94424438476562],
        [398.1152648925781, 111.34457397460938, 433.72052001953125, 133.36280822753906],
        [430.48736572265625, 104.4018325805664, 484.1470947265625, 132.18893432617188],
        [607.396728515625, 111.72560119628906, 637.2987670898438, 136.2375946044922],
        [267.43353271484375, 105.93965911865234, 327.1937561035156, 130.18527221679688],
        [589.790771484375, 110.36975860595703, 618.8001098632812, 126.1950454711914],
        [0.3374290466308594, 110.91182708740234, 63.00359344482422, 146.0926971435547],
    ])
    assert_array_almost_equal(
        output['detection_boxes'][0], expected_boxes, decimal=1)
|
|
|
def test_dn_detr(self):
    """Compare DN-DETR (R50) demo-image predictions with recorded references.

    ``init_detr`` overrides the head's ``num_select`` to 10, and 10 results
    per output are expected below.
    """
    model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth'
    config_path = 'configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py'
    img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
    self.init_detr(model_path, config_path)
    output = self.predict(img)

    for key in ('detection_boxes', 'detection_scores', 'detection_classes',
                'img_metas'):
        self.assertIn(key, output)
    for key in ('detection_boxes', 'detection_scores', 'detection_classes'):
        self.assertEqual(len(output[key][0]), 10)

    expected_classes = [2, 13, 2, 2, 2, 2, 2, 2, 2, 2]
    self.assertListEqual(output['detection_classes'][0].tolist(),
                         expected_classes)

    expected_scores = np.array([
        0.8800525665283203, 0.866659939289093, 0.8665854930877686,
        0.8030595183372498, 0.7642921209335327, 0.7375038862228394,
        0.7270554304122925, 0.6710091233253479, 0.6316548585891724,
        0.6164721846580505
    ],
                               dtype=np.float32)
    assert_array_almost_equal(
        output['detection_scores'][0], expected_scores, decimal=2)

    expected_boxes = np.array([
        [294.9338073730469, 115.7542495727539, 377.5517578125, 150.59274291992188],
        [220.57424926757812, 175.97023010253906, 456.9001770019531, 383.2597351074219],
        [479.5928649902344, 109.94012451171875, 523.7343139648438, 130.80604553222656],
        [398.6956787109375, 111.45973205566406, 434.0437316894531, 134.1909637451172],
        [166.98208618164062, 109.44792938232422, 210.35342407226562, 139.9746856689453],
        [609.432373046875, 113.08062744140625, 635.9082641601562, 136.74383544921875],
        [268.0716552734375, 105.00788879394531, 327.4037170410156, 128.01449584960938],
        [190.77467346191406, 107.42850494384766, 298.35760498046875, 156.2850341796875],
        [591.0296020507812, 110.53913116455078, 620.702880859375, 127.42123413085938],
        [431.6607971191406, 105.04813385009766, 484.4869689941406, 132.45864868164062],
    ])
    assert_array_almost_equal(
        output['detection_boxes'][0], expected_boxes, decimal=1)
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
|