EasyCV/tests/test_predictors/test_mot_predictor.py

88 lines
4.3 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
"""
isort:skip_file
"""
import json
import os
import unittest
import numpy as np
import time
import cv2
import torch
import scipy.io
from easycv.predictors.mot_predictor import MOTPredictor
from tests.ut_config import TEST_MOT_DIR
from numpy.testing import assert_array_almost_equal
class MOTPredictorTest(unittest.TestCase):
    """Regression test pinning ``MOTPredictor`` tracking output on TEST_MOT_DIR."""

    def setUp(self):
        """Print the name of the test case and method about to run."""
        case_name = type(self).__name__
        method_name = self._testMethodName
        print('Testing %s.%s' % (case_name, method_name))

    @unittest.skip('skip mot unittest')
    def test(self):
        """Track objects in TEST_MOT_DIR and compare one entry to golden values.

        The same exported checkpoint is passed both as the tracker's
        ``model_path`` and as the detection predictor's ``model_path``.
        Entry ``[0][10]`` of the returned results is checked for its
        ``track_bboxes`` (to 1 decimal) and its ``timestamp``.
        """
        checkpoint = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/tracking/fcos_r50_epoch_12_export.pt'
        # Where results are written: an mp4 video file, a folder, or None.
        save_path = None
        detector_cfg = dict(
            type='DetectionPredictor',
            model_path=checkpoint,
            config_file=None,
            score_threshold=0.2)

        predictor = MOTPredictor(
            model_path=checkpoint,
            detection_predictor_config=detector_cfg,
            save_path=save_path)
        # TEST_MOT_DIR may be a video file or a folder of frames.
        results = predictor(TEST_MOT_DIR)
        # NOTE(review): index 10 presumably selects a single frame's result.
        # The timestamp 0.41666... == 10 / 24 suggests a 24 fps input.
        # Confirm against MOTPredictor.
        frame_result = results[0][10]

        # Each row presumably reads [track_id, x1, y1, x2, y2, score]
        # (TODO confirm against MOTPredictor). The comparison is element-wise,
        # so row order matters: this list mirrors the predictor's output
        # order, which is why ids 26/27 come after 54.
        expected_bboxes = [
            [1.0, 1442.0, 646.0, 1518.0, 844.0, 0.6654601693153381],
            [2.0, 222.0, 812.0, 363.0, 1079.0, 0.8128248453140259],
            [3.0, 422.0, 780.0, 521.0, 1033.0, 0.7565178871154785],
            [4.0, 662.0, 692.0, 762.0, 917.0, 0.6685569882392883],
            [5.0, 1551.0, 671.0, 1655.0, 906.0, 0.7031927108764648],
            [6.0, 373.0, 822.0, 496.0, 1078.0, 0.6881393790245056],
            [7.0, 752.0, 710.0, 849.0, 922.0, 0.7925270199775696],
            [8.0, 1694.0, 834.0, 1789.0, 1078.0, 0.7597178816795349],
            [9.0, 1075.0, 521.0, 1130.0, 660.0, 0.7152606844902039],
            [10.0, 1015.0, 522.0, 1074.0, 655.0, 0.6544228196144104],
            [11.0, 874.0, 543.0, 933.0, 701.0, 0.6027799844741821],
            [12.0, 789.0, 491.0, 839.0, 630.0, 0.5886006951332092],
            [13.0, 921.0, 551.0, 985.0, 720.0, 0.478473424911499],
            [14.0, 1613.0, 670.0, 1709.0, 890.0, 0.6661025285720825],
            [15.0, 977.0, 612.0, 1050.0, 818.0, 0.5041629672050476],
            [16.0, 962.0, 530.0, 1018.0, 662.0, 0.5144294500350952],
            [17.0, 1258.0, 449.0, 1307.0, 546.0, 0.5149790048599243],
            [18.0, 1230.0, 447.0, 1279.0, 550.0, 0.6033780574798584],
            [19.0, 793.0, 456.0, 836.0, 569.0, 0.4767516255378723],
            [20.0, 879.0, 478.0, 924.0, 581.0, 0.5247951149940491],
            [21.0, 1192.0, 624.0, 1255.0, 812.0, 0.41572073101997375],
            [22.0, 813.0, 540.0, 872.0, 706.0, 0.4569028317928314],
            [23.0, 1455.0, 605.0, 1520.0, 772.0, 0.4564402997493744],
            [24.0, 1011.0, 581.0, 1082.0, 761.0, 0.5536622405052185],
            [25.0, 1050.0, 622.0, 1129.0, 817.0, 0.5953567624092102],
            [28.0, 754.0, 477.0, 800.0, 585.0, 0.5676276683807373],
            [29.0, 1097.0, 551.0, 1166.0, 723.0, 0.49641332030296326],
            [30.0, 841.0, 459.0, 888.0, 581.0, 0.5733954906463623],
            [31.0, 908.0, 456.0, 961.0, 581.0, 0.5193840861320496],
            [33.0, 945.0, 579.0, 1011.0, 728.0, 0.37510085105895996],
            [35.0, 1515.0, 521.0, 1555.0, 614.0, 0.3778478503227234],
            [36.0, 1403.0, 578.0, 1471.0, 766.0, 0.5565927028656006],
            [37.0, 1230.0, 573.0, 1289.0, 764.0, 0.4239506125450134],
            [41.0, 1448.0, 515.0, 1496.0, 626.0, 0.37311825156211853],
            [45.0, 1361.0, 495.0, 1402.0, 592.0, 0.22757700085639954],
            [46.0, 1083.0, 617.0, 1147.0, 768.0, 0.418988972902298],
            [50.0, 1323.0, 591.0, 1386.0, 768.0, 0.24450582265853882],
            [52.0, 1194.0, 572.0, 1257.0, 754.0, 0.333666056394577],
            [53.0, 1267.0, 589.0, 1331.0, 766.0, 0.34295716881752014],
            [54.0, 821.0, 481.0, 871.0, 610.0, 0.47558629512786865],
            [26.0, 1037.0, 585.0, 1124.0, 783.0, 0.35130345821380615],
            [27.0, 830.0, 477.0, 877.0, 595.0, 0.21788115799427032],
            [56.0, 886.0, 479.0, 941.0, 589.0, 0.21801409125328064],
        ]
        # decimal=1 sets a coarse tolerance, so tiny numeric drift in the
        # scores or boxes does not fail the test.
        assert_array_almost_equal(
            frame_result['track_bboxes'].tolist(), expected_bboxes, decimal=1)
        assert_array_almost_equal(frame_result['timestamp'],
                                  0.41666666666666663)