# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import unittest

import numpy as np
import torch
from mmcv.parallel import collate, scatter
from numpy.testing import assert_array_almost_equal
from torchvision.transforms import Compose

from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor
from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils.registry import build_from_cfg


class DETRTest(unittest.TestCase):
    """Prediction tests for the DETR and DAB-DETR detectors.

    Each test builds a model from a config file, loads a checkpoint, runs
    inference on one demo image addressed by URL and compares the returned
    classes, scores and boxes against recorded reference values.
    """

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def init_detr(self, model_path, config_path):
        """Build the detector described by ``config_path`` and load weights.

        Args:
            model_path (str): Checkpoint location passed to
                ``load_checkpoint``.
            config_path (str): Config file passed to
                ``mmcv_config_fromfile``.

        Sets ``self.model_path``, ``self.cfg``, ``self.model``,
        ``self.device``, ``self.ckpt`` and ``self.CLASSES``.
        """
        self.model_path = model_path

        self.cfg = mmcv_config_fromfile(config_path)

        # dynamic adapt mmdet models
        dynamic_adapt_for_mmlab(self.cfg)

        # modify model_config: heads that define a truthy `num_select` are
        # forced to 10; heads without it are left untouched.
        if self.cfg.model.head.get('num_select', None):
            self.cfg.model.head.num_select = 10

        # build model
        self.model = build_model(self.cfg.model)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # The checkpoint is mapped onto the same device the model will use.
        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=self.device)

        self.model.to(self.device)
        self.model.eval()
        self.CLASSES = self.cfg.CLASSES

    def _build_test_pipeline(self, load_from_array):
        """Compose the validation pipeline with a loading step prepended.

        Args:
            load_from_array (bool): If True the first step is
                ``LoadImageFromWebcam`` (input is an ``np.ndarray``),
                otherwise ``LoadImageFromFile``.

        Returns:
            torchvision.transforms.Compose: The composed transforms.

        Raises:
            TypeError: If a pipeline entry is neither a dict nor callable.
        """
        # Deep copy so that inserting the loader below never mutates the
        # pipeline list owned by ``self.cfg``; a shallow copy would make
        # repeated ``predict`` calls stack one extra loader per call.
        cfg = copy.deepcopy(self.cfg)

        # set loading pipeline type
        loader_type = ('LoadImageFromWebcam'
                       if load_from_array else 'LoadImageFromFile')
        cfg.data.val.pipeline.insert(
            0, dict(type=loader_type, file_client_args=dict(backend='http')))

        cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline)

        transforms = []
        for transform in cfg.data.val.pipeline:
            if isinstance(transform, dict):
                # Only dict entries can carry `img_scale`; testing membership
                # on an arbitrary callable would raise TypeError.
                if 'img_scale' in transform:
                    transform['img_scale'] = tuple(transform['img_scale'])
                transforms.append(build_from_cfg(transform, PIPELINES))
            elif callable(transform):
                transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')
        return Compose(transforms)

    def predict(self, imgs):
        """Inference image(s) with the detector built by ``init_detr``.

        Args:
            imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
                Either image files or loaded images. A single item is
                wrapped into a one-element batch. The loading step is chosen
                from the type of the first item.

        Returns:
            The value returned by ``self.model(mode='test', **data)`` for
            the collated batch.
        """
        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]

        test_pipeline = self._build_test_pipeline(
            load_from_array=isinstance(imgs[0], np.ndarray))

        datas = []
        for img in imgs:
            # prepare data
            if isinstance(img, np.ndarray):
                # directly add img
                data = dict(img=img)
            else:
                # add information into dict
                data = dict(img_info=dict(filename=img), img_prefix=None)
            # build the data pipeline
            datas.append(test_pipeline(data))

        data = collate(datas, samples_per_gpu=len(imgs))
        # just get the actual data from DataContainer
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [img.data[0] for img in data['img']]

        first_param = next(self.model.parameters())
        if first_param.is_cuda:
            # scatter to specified GPU
            data = scatter(data, [first_param.device])[0]

        # forward the model
        with torch.no_grad():
            results = self.model(mode='test', **data)

        return results

    def _assert_output_layout(self, output, expected_num):
        """Check the result keys and the number of detections of image 0."""
        self.assertIn('detection_boxes', output)
        self.assertIn('detection_scores', output)
        self.assertIn('detection_classes', output)
        self.assertIn('img_metas', output)
        self.assertEqual(len(output['detection_boxes'][0]), expected_num)
        self.assertEqual(len(output['detection_scores'][0]), expected_num)
        self.assertEqual(len(output['detection_classes'][0]), expected_num)

    def test_detr(self):
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/epoch_150.pth'
        config_path = 'configs/detection/detr/detr_r50_8x2_150e_coco.py'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        self.init_detr(model_path, config_path)
        output = self.predict(img)

        # NOTE(review): 100 detections are expected here, so presumably this
        # head config does not define `num_select` -- confirm against config.
        self._assert_output_layout(output, 100)

        self.assertListEqual(
            output['detection_classes'][0].tolist(),
            np.array([
                2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 7, 2, 13, 2, 0, 13, 2,
                0, 2, 56, 2, 7, 2, 2, 2, 2, 2, 2, 2, 7, 56, 2, 2, 7, 7, 2, 7,
                2, 2, 56, 2, 7, 11, 2, 2, 2, 0, 7, 2, 2, 2, 2, 2, 7, 2, 2, 7,
                2, 2, 2, 2, 13, 2, 2, 2, 13, 2, 2, 56, 2, 56, 2, 7, 56, 13, 7,
                56, 2, 0, 2, 7, 2, 7, 2, 56, 2, 2, 2, 7, 56, 2, 2, 7, 2, 0, 2,
                2
            ],
                     dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0],
            np.array([
                0.07836595922708511, 0.219977006316185, 0.5831383466720581,
                0.4256463646888733, 0.9853266477584839, 0.24607707560062408,
                0.28005731105804443, 0.500579833984375, 0.09835881739854813,
                0.05178987979888916, 0.07282666862010956, 0.9802166819572449,
                0.04826607555150986, 0.06967002153396606, 0.9541336894035339,
                0.36800140142440796, 0.2821184992790222, 0.009824174456298351,
                0.981455385684967, 0.9890823364257812, 0.11702633649110794,
                0.3397829532623291, 0.03982163220643997, 0.06306332349777222,
                0.07951728254556656, 0.949343204498291, 0.1537322700023651,
                0.3483341634273529, 0.044335901737213135, 0.03239326551556587,
                0.11274639517068863, 0.462695449590683, 0.03906852751970291,
                0.006577627267688513, 0.651928722858429, 0.13711832463741302,
                0.15317879617214203, 0.5399832129478455, 0.08868053555488586,
                0.026992695406079292, 0.0887782946228981, 0.081451416015625,
                0.5485899448394775, 0.1959853619337082, 0.20348815619945526,
                0.1804366111755371, 0.04546552523970604, 0.4005874693393707,
                0.4241448938846588, 0.20359477400779724, 0.18858052790164948,
                0.5971255898475647, 0.6823391914367676, 0.09363959729671478,
                0.9855959415435791, 0.5903261303901672, 0.0731084868311882,
                0.9813686609268188, 0.9814890027046204, 0.11285952478647232,
                0.46758928894996643, 0.5004158616065979, 0.5852540731430054,
                0.1944422572851181, 0.04896926134824753, 0.17205820977687836,
                0.188123881816864, 0.43242165446281433, 0.3784835636615753,
                0.06754120439291, 0.8264386057853699, 0.054902296513319016,
                0.05457871034741402, 0.05988362058997154, 0.054624997079372406,
                0.37744957208633423, 0.08150151371955872, 0.015097505412995815,
                0.1074686348438263, 0.004187499638646841, 0.9733405709266663,
                0.15225540101528168, 0.711842954158783, 0.06490222364664078,
                0.9512462615966797, 0.03674759343266487, 0.09688679873943329,
                0.02119528315961361, 0.9736435413360596, 0.9338251948356628,
                0.09611554443836212, 0.09142979979515076, 0.01647237129509449,
                0.9805111289024353, 0.3779929280281067, 0.09553579986095428,
                0.11411411315202713, 0.0063759335316717625, 0.2972108721733093,
                0.07761078327894211
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0],
            np.array([
                [131.10389709472656, 90.93302154541016,
                 148.95504760742188, 101.69216918945312],
                [239.10910034179688, 113.36551666259766,
                 256.0523376464844, 125.22894287109375],
                [132.1316375732422, 90.8366470336914,
                 151.00839233398438, 101.83119201660156],
                [579.37646484375, 108.26667785644531,
                 605.0717163085938, 124.79525756835938],
                [189.69073486328125, 108.04875946044922,
                 296.8011779785156, 154.44204711914062],
                [588.5413208007812, 107.89535522460938,
                 615.6463012695312, 124.41362762451172],
                [57.38536071777344, 89.7335433959961,
                 79.20274353027344, 102.61941528320312],
                [163.97628784179688, 92.95049285888672,
                 180.87033081054688, 102.6163330078125],
                [127.82454681396484, 90.27918243408203,
                 144.6781768798828, 99.71304321289062],
                [438.4545593261719, 103.00477600097656,
                 480.4275817871094, 121.69993591308594],
                [23.33353042602539, 113.83180236816406,
                 58.04608154296875, 141.02174377441406],
                [165.77127075195312, 108.23619842529297,
                 208.61306762695312, 136.4344482421875],
                [402.93792724609375, 110.03141784667969,
                 437.510009765625, 132.2660369873047],
                [571.1912841796875, 108.44816589355469,
                 596.1377563476562, 125.31720733642578],
                [564.8051147460938, 108.13015747070312,
                 593.9150390625, 126.26905822753906],
                [79.12519836425781, 105.58458709716797,
                 111.23175048828125, 119.07565307617188],
                [549.2890625, 110.77001190185547,
                 563.7803955078125, 122.94622039794922],
                [384.75567626953125, 118.3704605102539,
                 422.6398010253906, 138.06492614746094],
                [218.92532348632812, 177.14031982421875,
                 459.10760498046875, 381.1133728027344],
                [397.3675231933594, 110.41165161132812,
                 436.5208740234375, 133.16848754882812],
                [239.1329803466797, 114.10742950439453,
                 255.9464874267578, 126.29271697998047],
                [214.19888305664062, 95.95294952392578,
                 233.9381103515625, 104.9795913696289],
                [395.9658508300781, 148.40223693847656,
                 434.2126770019531, 182.97567749023438],
                [269.0050964355469, 104.73245239257812,
                 320.3499755859375, 123.80233001708984],
                [483.2522888183594, 107.74896240234375,
                 522.5078125, 128.4845428466797],
                [57.623504638671875, 90.50349426269531,
                 82.20435333251953, 103.5736312866211],
                [375.03070068359375, 117.73094940185547,
                 385.3664855957031, 132.47479248046875],
                [555.9656372070312, 102.33853912353516,
                 568.1951293945312, 113.82102966308594],
                [388.681640625, 107.18630981445312,
                 415.21905517578125, 121.56800842285156],
                [510.5431823730469, 107.02037811279297,
                 533.8384399414062, 122.61180114746094],
                [187.148193359375, 100.21558380126953,
                 253.47540283203125, 123.25538635253906],
                [552.3801879882812, 103.33021545410156,
                 564.61865234375, 115.99454498291016],
                [425.8926086425781, 104.35319519042969,
                 477.8686218261719, 130.82357788085938],
                [222.24378967285156, 176.4434051513672,
                 456.42266845703125, 312.5479431152344],
                [227.29019165039062, 98.5999755859375,
                 250.33477783203125, 107.1373291015625],
                [165.81600952148438, 107.86138916015625,
                 202.11196899414062, 134.08160400390625],
                [175.83389282226562, 89.20259857177734,
                 224.58187866210938, 105.41484832763672],
                [186.885986328125, 107.32003021240234,
                 300.068115234375, 152.51370239257812],
                [165.5398712158203, 107.98709106445312,
                 202.61941528320312, 134.54295349121094],
                [611.3699951171875, 110.01651000976562,
                 639.626953125, 133.97329711914062],
                [550.5084838867188, 104.33821105957031,
                 562.5010986328125, 115.29130554199219],
                [59.68817901611328, 97.86705017089844,
                 76.69844055175781, 110.52032470703125],
                [372.9800720214844, 135.38967895507812,
                 435.77044677734375, 187.3105010986328],
                [233.5430908203125, 99.0816421508789,
                 255.1123809814453, 109.31993103027344],
                [57.52912521362305, 90.85675048828125,
                 81.06048583984375, 104.28386688232422],
                [566.7468872070312, 81.79965209960938,
                 581.5723266601562, 92.56966400146484],
                [166.86119079589844, 108.41394805908203,
                 205.91815185546875, 135.9706268310547],
                [87.78324890136719, 89.354736328125,
                 108.06605529785156, 99.89331817626953],
                [-0.0010925531387329102, 111.63217163085938,
                 13.02929401397705, 123.92339324951172],
                [235.19432067871094, 114.09554290771484,
                 251.94717407226562, 126.1430892944336],
                [268.6304626464844, 104.09455108642578,
                 328.7283935546875, 124.19095611572266],
                [599.0895385742188, 105.6767807006836,
                 627.2431640625, 121.63172912597656],
                [80.78914642333984, 88.88621520996094,
                 103.19034576416016, 99.85254669189453],
                [620.0072021484375, 109.53975677490234,
                 640.0657958984375, 133.46539306640625],
                [611.6638793945312, 109.55789947509766,
                 640.0369873046875, 135.73045349121094],
                [220.8399200439453, 96.41732025146484,
                 244.06399536132812, 105.75860595703125],
                [434.6024169921875, 105.29331970214844,
                 482.67218017578125, 130.61903381347656],
                [482.16302490234375, 108.22539520263672,
                 523.8209228515625, 128.8394317626953],
                [294.1483154296875, 114.8856201171875,
                 377.6090087890625, 148.9021453857422],
                [197.5174560546875, 91.83529663085938,
                 224.16571044921875, 104.21776580810547],
                [167.1007080078125, 94.26935577392578,
                 185.2867431640625, 103.67475128173828],
                [77.7591552734375, 88.3407974243164,
                 98.73750305175781, 98.35700988769531],
                [374.9325866699219, 117.9875259399414,
                 385.34991455078125, 132.2331085205078],
                [167.0509490966797, 108.91958618164062,
                 205.8968505859375, 136.49874877929688],
                [51.54475784301758, 104.56134796142578,
                 73.30614471435547, 120.99707794189453],
                [274.94195556640625, 101.94827270507812,
                 323.97650146484375, 115.7809829711914],
                [236.10757446289062, 97.75923919677734,
                 254.6915283203125, 107.15653228759766],
                [609.9969482421875, 100.36739349365234,
                 638.1365966796875, 115.2613754272461],
                [74.30718994140625, 103.82154083251953,
                 108.9682388305664, 118.71984100341797],
                [367.69842529296875, 118.2647933959961,
                 380.7364501953125, 132.5258331298828],
                [97.63047790527344, 89.68116760253906,
                 117.95793151855469, 100.69364166259766],
                [373.5556640625, 135.07652282714844,
                 434.33709716796875, 185.4060516357422],
                [54.454566955566406, 104.98534393310547,
                 71.63267517089844, 119.91905212402344],
                [375.6446838378906, 135.28944396972656,
                 435.38543701171875, 185.16932678222656],
                [614.5422973632812, 107.62055969238281,
                 640.0924072265625, 132.63742065429688],
                [433.181396484375, 103.48396301269531,
                 484.539794921875, 131.49851989746094],
                [375.8929443359375, 142.8272247314453,
                 434.06439208984375, 186.07620239257812],
                [331.8176574707031, 131.65638732910156,
                 437.4866027832031, 185.01356506347656],
                [0.02536296844482422, 109.71121978759766,
                 59.71095657348633, 143.3699951171875],
                [225.39547729492188, 179.25990295410156,
                 453.9075012207031, 378.88177490234375],
                [-0.048813819885253906, 109.48809051513672,
                 61.9161376953125, 143.69325256347656],
                [230.8810272216797, 113.71331787109375,
                 246.87635803222656, 125.97274780273438],
                [90.28150939941406, 88.90176391601562,
                 112.26461791992188, 99.73652648925781],
                [398.63372802734375, 109.19901275634766,
                 437.08404541015625, 131.7988739013672],
                [591.0872802734375, 108.76611328125,
                 618.3915405273438, 124.87841796875],
                [294.1384582519531, 114.44403076171875,
                 380.4442138671875, 149.1887664794922],
                [203.1075439453125, 109.37834167480469,
                 235.15982055664062, 125.3521957397461],
                [223.9816436767578, 177.92266845703125,
                 456.14263916015625, 357.9170837402344],
                [267.4551086425781, 105.07503509521484,
                 325.762939453125, 128.3075408935547],
                [135.58905029296875, 91.94464111328125,
                 159.6639404296875, 103.34713745117188],
                [540.2098388671875, 103.45868682861328,
                 555.146240234375, 116.43596649169922],
                [293.979736328125, 114.60274505615234,
                 380.0762939453125, 149.8252716064453],
                [220.00042724609375, 175.55282592773438,
                 452.283447265625, 327.056640625],
                [433.3155822753906, 103.43656158447266,
                 485.6103515625, 130.87503051757812],
                [553.2203369140625, 101.77803802490234,
                 564.004150390625, 112.206298828125],
                [567.9517822265625, 107.1722640991211,
                 595.299560546875, 124.94164276123047],
                [555.1934814453125, 109.10118103027344,
                 572.1039428710938, 122.53047180175781],
                [77.2689208984375, 90.0588607788086,
                 501.6678466796875, 346.69378662109375],
                [552.4683227539062, 111.0732650756836,
                 567.466064453125, 123.53128814697266],
                [79.25263977050781, 89.3648452758789,
                 111.41500854492188, 101.7647476196289],
            ]),
            decimal=1)

    def test_dab_detr(self):
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/epoch_50.pth'
        config_path = 'configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        self.init_detr(model_path, config_path)
        output = self.predict(img)

        self._assert_output_layout(output, 10)

        self.assertListEqual(
            output['detection_classes'][0].tolist(),
            np.array([2, 2, 13, 2, 2, 2, 2, 2, 2, 2], dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0],
            np.array([
                0.7688284516334534, 0.7646799683570862, 0.7159939408302307,
                0.6902833580970764, 0.6633996367454529, 0.6523147821426392,
                0.633848249912262, 0.6229104995727539, 0.611840009689331,
                0.5631589293479919
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0],
            np.array([
                [294.2984313964844, 116.07160949707031,
                 380.4406433105469, 149.6365509033203],
                [480.0610656738281, 109.99347686767578,
                 523.2314453125, 130.26318359375],
                [220.32269287109375, 176.51010131835938,
                 456.51715087890625, 386.30767822265625],
                [167.6925506591797, 108.25935363769531,
                 214.93780517578125, 138.94424438476562],
                [398.1152648925781, 111.34457397460938,
                 433.72052001953125, 133.36280822753906],
                [430.48736572265625, 104.4018325805664,
                 484.1470947265625, 132.18893432617188],
                [607.396728515625, 111.72560119628906,
                 637.2987670898438, 136.2375946044922],
                [267.43353271484375, 105.93965911865234,
                 327.1937561035156, 130.18527221679688],
                [589.790771484375, 110.36975860595703,
                 618.8001098632812, 126.1950454711914],
                [0.3374290466308594, 110.91182708740234,
                 63.00359344482422, 146.0926971435547],
            ]),
            decimal=1)


if __name__ == '__main__':
    unittest.main()