tuofeilun 74cde39e66
Support Single-lens MOT (#258)
Support Single-lens MOT
2023-01-03 16:40:50 +08:00

113 lines
4.7 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import tempfile
import unittest
import numpy as np
from numpy.testing import assert_array_almost_equal
from easycv.predictors.detector import DetectionPredictor
class FCOSTest(unittest.TestCase):
    """Regression tests for the FCOS detector and a ByteTrack tracker fed by it.

    Both tests build a ``DetectionPredictor`` from a remote checkpoint URL and
    run it on a remote demo image. They presumably need network access and
    the ``easycv`` package to run.
    """

    # Fixtures shared by every test: pretrained checkpoint, model config and
    # the demo image that the golden values below were recorded against.
    _MODEL_PATH = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/fcos_epoch_12.pth'
    _CONFIG_PATH = 'configs/detection/fcos/fcos_r50_torch_1x_coco.py'
    _IMG_URL = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'

    def setUp(self):
        """Print the fully-qualified name of the test about to run."""
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def _predict_demo(self):
        """Build the FCOS predictor and run it on the demo image.

        Returns:
            tuple: ``(predictor, result)``, where ``result`` is the first
            element of the predictor's output for the demo image.
        """
        predictor = DetectionPredictor(self._MODEL_PATH, self._CONFIG_PATH)
        result = predictor(self._IMG_URL)[0]
        return predictor, result

    def test_fcos(self):
        """Check FCOS detections on the demo image against golden values."""
        predictor, result = self._predict_demo()

        # Smoke-test visualisation by rendering to a throwaway .jpg path.
        with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp_file:
            predictor.visualize(self._IMG_URL, result, out_file=tmp_file.name)

        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'img_metas'):
            self.assertIn(key, result)
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes'):
            self.assertEqual(len(result[key]), 7)

        # Every detection on the demo image is expected to be class id 2.
        expected_classes = np.full(7, 2, dtype=np.int32)
        assert_array_almost_equal(result['detection_classes'].tolist(),
                                  expected_classes.tolist())

        expected_scores = np.array([
            0.7142099, 0.61647004, 0.5857586, 0.5839255, 0.5378273,
            0.5127002, 0.5077106
        ], dtype=np.float32)
        assert_array_almost_equal(
            result['detection_scores'], expected_scores, decimal=2)

        expected_boxes = np.array([
            [294.96497, 116.47906, 378.7294, 149.90738],
            [480.34415, 110.31671, 523.0271, 130.33409],
            [398.22247, 110.64816, 433.01566, 133.1527],
            [608.2505, 111.9937, 636.7885, 137.0966],
            [591.46234, 109.84667, 619.6144, 126.97513],
            [431.47202, 104.88086, 482.88544, 131.95964],
            [189.96198, 108.948654, 297.10025, 154.80592],
        ])
        assert_array_almost_equal(
            result['detection_boxes'], expected_boxes, decimal=1)

    @unittest.skip('skip bytetrack unittest')
    def test_bytetrack(self):
        """Feed one frame of FCOS detections to BYTETracker and check tracks.

        Unconditionally skipped via ``unittest.skip``.
        """
        # Imported lazily so the module loads even without the MOT extras.
        from easycv.thirdparty.mot.bytetrack.byte_tracker import BYTETracker
        _, result = self._predict_demo()

        tracker = BYTETracker(
            det_high_thresh=0.2,
            det_low_thresh=0.05,
            match_thresh=1.0,
            match_thresh_second=1.0,
            match_thresh_init=1.0,
            track_buffer=2,
            frame_rate=25)
        track_result = tracker.update(result['detection_boxes'],
                                      result['detection_scores'],
                                      result['detection_classes'])

        # Each row appears to be [track_id, x1, y1, x2, y2, score]; the box
        # coordinates look integer-truncated versions of the FCOS boxes.
        # NOTE(review): layout inferred from the golden values - confirm.
        expected_tracks = np.array([
            [1.0, 294.0, 116.0, 378.0, 149.0, 0.714209914],
            [2.0, 480.0, 110.0, 523.0, 130.0, 0.616470039],
            [3.0, 398.0, 110.0, 433.0, 133.0, 0.585758626],
            [4.0, 608.0, 111.0, 636.0, 137.0, 0.583925486],
            [5.0, 591.0, 109.0, 619.0, 126.0, 0.537827313],
            [6.0, 431.0, 104.0, 482.0, 131.0, 0.512700200],
            [7.0, 189.0, 108.0, 297.0, 154.0, 0.507710576],
        ])
        assert_array_almost_equal(
            track_result['track_bboxes'], expected_tracks, decimal=1)
if __name__ == '__main__':
    # Executing this file directly hands control to unittest's CLI runner,
    # which discovers and runs the TestCase classes defined above.
    unittest.main()