mirror of https://github.com/alibaba/EasyCV.git
88 lines
4.3 KiB
Python
88 lines
4.3 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
"""
|
|
isort:skip_file
|
|
"""
|
|
import json
|
|
import os
|
|
import unittest
|
|
import numpy as np
|
|
import time
|
|
import cv2
|
|
import torch
|
|
import scipy.io
|
|
from easycv.predictors.mot_predictor import MOTPredictor
|
|
from tests.ut_config import TEST_MOT_DIR
|
|
from numpy.testing import assert_array_almost_equal
|
|
|
|
|
|
class MOTPredictorTest(unittest.TestCase):
    """Regression test for ``MOTPredictor`` tracking output.

    The only test case is unconditionally skipped. When enabled, it builds a
    ``MOTPredictor`` from an exported checkpoint URL, runs it on
    ``TEST_MOT_DIR`` and compares one tracked frame against reference values.
    """

    def setUp(self):
        """Print the class and method name of the test about to run."""
        qualified = (type(self).__name__, self._testMethodName)
        print('Testing %s.%s' % qualified)

    @unittest.skip('skip mot unittest')
    def test(self):
        """Track ``TEST_MOT_DIR`` and check the boxes/timestamp of one frame."""
        checkpoint = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/tracking/fcos_r50_epoch_12_export.pt'
        # Where tracking output is written: mp4 video file, folder, or None.
        save_to = None
        # Input to track: a video file or a folder of frames.
        source = TEST_MOT_DIR

        # The same checkpoint is passed both to the tracker and to its
        # internal detection predictor.
        detector_cfg = dict(
            type='DetectionPredictor',
            model_path=checkpoint,
            config_file=None,
            score_threshold=0.2)
        predictor = MOTPredictor(
            model_path=checkpoint,
            detection_predictor_config=detector_cfg,
            save_path=save_to)

        results = predictor(source)
        # NOTE(review): [0][10] presumably selects frame index 10 of the
        # first input sequence -- confirm against MOTPredictor.__call__.
        frame = results[0][10]

        # Reference rows; each appears to be
        # [track_id, x1, y1, x2, y2, score] -- assumption, verify against
        # the tracker's output format. Row order is the tracker's own order.
        expected_tracks = [
            [1.0, 1442.0, 646.0, 1518.0, 844.0, 0.6654601693153381],
            [2.0, 222.0, 812.0, 363.0, 1079.0, 0.8128248453140259],
            [3.0, 422.0, 780.0, 521.0, 1033.0, 0.7565178871154785],
            [4.0, 662.0, 692.0, 762.0, 917.0, 0.6685569882392883],
            [5.0, 1551.0, 671.0, 1655.0, 906.0, 0.7031927108764648],
            [6.0, 373.0, 822.0, 496.0, 1078.0, 0.6881393790245056],
            [7.0, 752.0, 710.0, 849.0, 922.0, 0.7925270199775696],
            [8.0, 1694.0, 834.0, 1789.0, 1078.0, 0.7597178816795349],
            [9.0, 1075.0, 521.0, 1130.0, 660.0, 0.7152606844902039],
            [10.0, 1015.0, 522.0, 1074.0, 655.0, 0.6544228196144104],
            [11.0, 874.0, 543.0, 933.0, 701.0, 0.6027799844741821],
            [12.0, 789.0, 491.0, 839.0, 630.0, 0.5886006951332092],
            [13.0, 921.0, 551.0, 985.0, 720.0, 0.478473424911499],
            [14.0, 1613.0, 670.0, 1709.0, 890.0, 0.6661025285720825],
            [15.0, 977.0, 612.0, 1050.0, 818.0, 0.5041629672050476],
            [16.0, 962.0, 530.0, 1018.0, 662.0, 0.5144294500350952],
            [17.0, 1258.0, 449.0, 1307.0, 546.0, 0.5149790048599243],
            [18.0, 1230.0, 447.0, 1279.0, 550.0, 0.6033780574798584],
            [19.0, 793.0, 456.0, 836.0, 569.0, 0.4767516255378723],
            [20.0, 879.0, 478.0, 924.0, 581.0, 0.5247951149940491],
            [21.0, 1192.0, 624.0, 1255.0, 812.0, 0.41572073101997375],
            [22.0, 813.0, 540.0, 872.0, 706.0, 0.4569028317928314],
            [23.0, 1455.0, 605.0, 1520.0, 772.0, 0.4564402997493744],
            [24.0, 1011.0, 581.0, 1082.0, 761.0, 0.5536622405052185],
            [25.0, 1050.0, 622.0, 1129.0, 817.0, 0.5953567624092102],
            [28.0, 754.0, 477.0, 800.0, 585.0, 0.5676276683807373],
            [29.0, 1097.0, 551.0, 1166.0, 723.0, 0.49641332030296326],
            [30.0, 841.0, 459.0, 888.0, 581.0, 0.5733954906463623],
            [31.0, 908.0, 456.0, 961.0, 581.0, 0.5193840861320496],
            [33.0, 945.0, 579.0, 1011.0, 728.0, 0.37510085105895996],
            [35.0, 1515.0, 521.0, 1555.0, 614.0, 0.3778478503227234],
            [36.0, 1403.0, 578.0, 1471.0, 766.0, 0.5565927028656006],
            [37.0, 1230.0, 573.0, 1289.0, 764.0, 0.4239506125450134],
            [41.0, 1448.0, 515.0, 1496.0, 626.0, 0.37311825156211853],
            [45.0, 1361.0, 495.0, 1402.0, 592.0, 0.22757700085639954],
            [46.0, 1083.0, 617.0, 1147.0, 768.0, 0.418988972902298],
            [50.0, 1323.0, 591.0, 1386.0, 768.0, 0.24450582265853882],
            [52.0, 1194.0, 572.0, 1257.0, 754.0, 0.333666056394577],
            [53.0, 1267.0, 589.0, 1331.0, 766.0, 0.34295716881752014],
            [54.0, 821.0, 481.0, 871.0, 610.0, 0.47558629512786865],
            [26.0, 1037.0, 585.0, 1124.0, 783.0, 0.35130345821380615],
            [27.0, 830.0, 477.0, 877.0, 595.0, 0.21788115799427032],
            [56.0, 886.0, 479.0, 941.0, 589.0, 0.21801409125328064],
        ]
        # Coarse tolerance (decimal=1): boxes are integral pixel values and
        # scores only need to match to about one decimal place.
        assert_array_almost_equal(
            frame['track_bboxes'].tolist(), expected_tracks, decimal=1)

        # NOTE(review): 0.41666... is close to 10 / 24, i.e. frame 10 at
        # 24 fps -- an assumption, not established by this file.
        assert_array_almost_equal(frame['timestamp'], 0.41666666666666663)
|