mirror of https://github.com/alibaba/EasyCV.git
parent
4fa81beb21
commit
74cde39e66
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Yifu Zhang
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class TrackState(object):
    """Integer codes describing where a track is in its lifecycle.

    ``New`` (0) is the default ``BaseTrack.state``; ``Tracked`` (1) is set
    when a track is activated or matched; ``Lost`` (2) and ``Removed`` (3)
    are set by ``BaseTrack.mark_lost`` / ``BaseTrack.mark_removed``.
    """

    New, Tracked, Lost, Removed = range(4)
|
||||
|
||||
|
||||
class BaseTrack(object):
    """Base class holding the bookkeeping fields shared by all tracks.

    Subclasses are expected to implement ``activate``, ``predict`` and
    ``update``; the base versions raise ``NotImplementedError``.
    """

    # Process-wide counter behind next_id(). It is always read and written
    # through ``BaseTrack`` itself, so every subclass shares one ID sequence.
    _count = 0

    track_id = 0
    is_activated = False
    state = TrackState.New

    # NOTE(review): ``history`` and ``features`` are class-level objects, so
    # every instance that does not rebind them sees the same containers.
    history = OrderedDict()
    features = []
    curr_feature = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    @property
    def end_frame(self):
        """Frame index at which this track was last updated (``frame_id``)."""
        return self.frame_id

    @staticmethod
    def next_id():
        """Increment the shared counter and return the newly issued ID."""
        issued = BaseTrack._count + 1
        BaseTrack._count = issued
        return issued

    def activate(self, *args):
        """Start the track; must be provided by subclasses."""
        raise NotImplementedError()

    def predict(self):
        """Advance the track state by one step; must be provided by subclasses."""
        raise NotImplementedError()

    def update(self, *args, **kwargs):
        """Incorporate a new observation; must be provided by subclasses."""
        raise NotImplementedError()

    def mark_lost(self):
        """Switch the track to ``TrackState.Lost``."""
        self.state = TrackState.Lost

    def mark_removed(self):
        """Switch the track to ``TrackState.Removed``."""
        self.state = TrackState.Removed
|
|
@ -0,0 +1,363 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Yifu Zhang
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
import os
|
||||
import os.path as osp
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from easycv.thirdparty.mot.bytetrack.kalman_filter import KalmanFilter
|
||||
from easycv.thirdparty.mot.bytetrack.basetrack import BaseTrack, TrackState
|
||||
from easycv.thirdparty.mot.bytetrack import matching
|
||||
|
||||
|
||||
class STrack(BaseTrack):
    """Single-object track whose box is estimated with a Kalman filter.

    Until the track is activated the raw detection box is kept in ``_tlwh``.
    After activation the box is derived from the filter state ``mean``,
    whose layout is ``(cx, cy, aspect, height, vx, vy, va, vh)``.
    """

    # Filter instance used by the batched ``multi_predict``.
    shared_kalman = KalmanFilter()

    def __init__(self, tlwh, score):
        """Create a not-yet-activated track from one detection.

        Args:
            tlwh: box as ``(top left x, top left y, width, height)``.
            score: detection confidence stored on the track.
        """
        # wait activate
        # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
        # dtype that alias used to resolve to.
        self._tlwh = np.asarray(tlwh, dtype=float)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0

    def predict(self):
        """Run one Kalman prediction step on this track in place."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # Zero the height velocity for tracks that are not currently
            # tracked before propagating the state.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(
            mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        """Run the Kalman prediction step for a list of tracks in one batch.

        Each track's ``mean`` and ``covariance`` are replaced with the
        predicted values. An empty list is a no-op.
        """
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    # Same height-velocity reset as in ``predict``.
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(
                multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet.

        Assigns a fresh track ID, initialises the filter state from the
        stored detection box and marks the track as ``Tracked``.
        """
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(
            self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Only tracks born on the very first frame are confirmed right away;
        # later ones become activated once ``update`` matches them again.
        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """Resume this track using the box and score of ``new_track``.

        Args:
            new_track: detection-like track providing ``tlwh`` and ``score``.
            frame_id: current frame index.
            new_id: when true, a new track ID is issued.
        """
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score

    def update(self, new_track, frame_id):
        """Update a matched track with a new observation.

        Args:
            new_track: detection-like track providing ``tlwh`` and ``score``.
            frame_id: current frame index.
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.

        Falls back to the stored detection box while no filter state exists.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]  # aspect * height -> width
        ret[:2] -= ret[2:] / 2  # centre -> top-left corner
        return ret

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        """Return the current box as `(center x, center y, aspect, height)`."""
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        """Convert `(min x, min y, max x, max y)` to `(x, y, width, height)`."""
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        """Convert `(x, y, width, height)` to `(min x, min y, max x, max y)`."""
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame,
                                      self.end_frame)
|
||||
|
||||
|
||||
class BYTETracker(object):
    """Multi-object tracker that associates detections in two passes.

    High-score detections are matched first against tracked and lost tracks;
    low-score detections are then matched against the tracks that are still
    unmatched. Matching uses IoU distance and linear assignment.
    """

    def __init__(self,
                 det_high_thresh=0.7,
                 det_low_thresh=0.1,
                 match_thresh=0.8,
                 match_thresh_second=1.0,
                 match_thresh_init=1.0,
                 track_buffer=5,
                 frame_rate=25):
        """Configure thresholds and create the Kalman filter.

        Args:
            det_high_thresh: score above which a detection takes part in the
                first association and may start a new track.
            det_low_thresh: lower score bound for the second association.
            match_thresh: cost limit for the first association.
            match_thresh_second: cost limit for the second association.
            match_thresh_init: cost limit when matching unconfirmed tracks.
            track_buffer: scaled by ``frame_rate / 30`` to give the number of
                frames a lost track is kept before removal.
            frame_rate: frame rate used for that scaling.
        """
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.track_thresh = det_high_thresh
        self.match_thresh_second = match_thresh_second
        self.match_thresh_init = match_thresh_init
        self.det_thresh = det_high_thresh
        self.match_thresh = match_thresh
        self.low_thresh = det_low_thresh

        self.buffer_size = int(frame_rate / 30 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()

    def update(self, bboxes, scores, classes):
        """Advance the tracker by one frame.

        Args:
            bboxes: detection boxes, one row per detection in
                ``(min x, min y, max x, max y)`` form; indexed with boolean
                masks derived from ``scores``.
            scores: detection confidences aligned with ``bboxes``.
            classes: accepted for interface compatibility; not used here.

        Returns:
            dict with key ``'track_bboxes'``: an array with one row
            ``[track_id, x1, y1, x2, y2, score]`` per activated track
            (coordinates truncated to integers), of shape ``(N, 6)`` —
            ``(0, 6)`` when there are no activated tracks.
        """
        self.frame_id += 1
        activated_stracks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        # Both comparisons are strict, so a score exactly equal to
        # ``track_thresh`` falls into neither detection group.
        remain_inds = scores > self.track_thresh
        inds_low = scores > self.low_thresh
        inds_high = scores < self.track_thresh

        inds_second = np.logical_and(inds_low, inds_high)
        dets_second = bboxes[inds_second]
        dets = bboxes[remain_inds]
        scores_keep = scores[remain_inds]
        scores_second = scores[inds_second]

        # High-score detections wrapped as (not yet activated) tracks.
        detections = [
            STrack(STrack.tlbr_to_tlwh(tlbr), s)
            for (tlbr, s) in zip(dets, scores_keep)
        ]

        # Split existing tracks into confirmed and unconfirmed ones.
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        # Step 2: first association, with high score detection boxes.
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = matching.iou_distance(strack_pool, detections)
        dists = matching.fuse_score(dists, detections)
        matches, u_track, u_detection = matching.linear_assignment(
            dists, thresh=self.match_thresh)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        # Step 3: second association, with low score detection boxes, for
        # the still-unmatched tracks that are currently in Tracked state.
        detections_second = [
            STrack(STrack.tlbr_to_tlwh(tlbr), s)
            for (tlbr, s) in zip(dets_second, scores_second)
        ]
        r_tracked_stracks = [
            strack_pool[i] for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = matching.iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = matching.linear_assignment(
            dists, thresh=self.match_thresh_second)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        # Deal with unconfirmed tracks, usually tracks with only one
        # beginning frame; they compete for the leftover high-score boxes.
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        dists = matching.fuse_score(dists, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(
            dists, thresh=self.match_thresh_init)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        # Step 4: init new stracks from the remaining detections.
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)

        # Step 5: update state; drop tracks that have been lost too long.
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        self.tracked_stracks = [
            t for t in self.tracked_stracks if t.state == TrackState.Tracked
        ]
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             activated_stracks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        # The subtraction uses the removed list as it stood before this
        # frame's removals are appended on the next line.
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)

        output_stracks = [
            track for track in self.tracked_stracks if track.is_activated
        ]

        # output rows: [id, x1, y1, x2, y2, score]
        fres = []
        for t in output_stracks:
            tlbr = [int(v) for v in t.tlwh_to_tlbr(t.tlwh)]
            fres.append(np.array([t.track_id] + tlbr + [t.score]))

        # reshape keeps the result two-dimensional when there are no tracks.
        return {'track_bboxes': np.array(fres).reshape(-1, 6)}
|
||||
|
||||
|
||||
def joint_stracks(tlista, tlistb):
    """Concatenate two track lists, skipping IDs that are already present.

    Every element of ``tlista`` is kept in order. An element of ``tlistb``
    is appended only if its ``track_id`` has not been seen in ``tlista`` or
    earlier in ``tlistb``.
    """
    merged = list(tlista)
    seen_ids = {t.track_id for t in merged}
    for candidate in tlistb:
        if candidate.track_id not in seen_ids:
            seen_ids.add(candidate.track_id)
            merged.append(candidate)
    return merged
|
||||
|
||||
|
||||
def sub_stracks(tlista, tlistb):
    """Return the tracks of ``tlista`` whose ``track_id`` is not in ``tlistb``.

    Tracks are keyed by ``track_id``, so if ``tlista`` holds several tracks
    with one ID only the last of them is kept.
    """
    remaining = {t.track_id: t for t in tlista}
    for other in tlistb:
        if remaining.get(other.track_id, 0):
            del remaining[other.track_id]
    return list(remaining.values())
|
||||
|
||||
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
    """Drop one track from every cross-list pair that overlaps heavily.

    A pair counts as a duplicate when its IoU distance is below 0.15. The
    track with the longer span (``frame_id - start_frame``) is kept; on a
    tie the track from ``stracksb`` is kept.

    Returns:
        ``(resa, resb)``: the two input lists with the duplicates removed.
    """
    pdist = matching.iou_distance(stracksa, stracksb)
    drop_a, drop_b = set(), set()
    for ia, ib in zip(*np.where(pdist < 0.15)):
        span_a = stracksa[ia].frame_id - stracksa[ia].start_frame
        span_b = stracksb[ib].frame_id - stracksb[ib].start_frame
        if span_a > span_b:
            drop_b.add(ib)
        else:
            drop_a.add(ia)
    resa = [t for i, t in enumerate(stracksa) if i not in drop_a]
    resb = [t for i, t in enumerate(stracksb) if i not in drop_b]
    return resa, resb
|
|
@ -0,0 +1,291 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Yifu Zhang
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
|
||||
# Table for the 0.95 quantile of the chi-square distribution with N degrees of
# freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
# function and used as Mahalanobis gating threshold.
chi2inv95 = {
    1: 3.8415, 2: 5.9915, 3: 7.8147,
    4: 9.4877, 5: 11.070, 6: 12.592,
    7: 14.067, 8: 15.507, 9: 16.919,
}
|
||||
|
||||
|
||||
class KalmanFilter(object):
    """Constant-velocity Kalman filter for boxes in image space.

    The state has eight components,

        x, y, a, h, vx, vy, va, vh

    i.e. box centre (x, y), aspect ratio a, height h and the velocity of
    each. The box (x, y, a, h) is observed directly (linear observation
    model).
    """

    def __init__(self):
        ndim, dt = 4, 1.

        # State transition: identity plus ``dt`` coupling each position
        # component to its velocity.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        # Observation matrix: picks the first ``ndim`` state components.
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Noise levels are expressed relative to the current box height;
        # these weights set how large they are.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create the state of a new track from a single measurement.

        Parameters
        ----------
        measurement : ndarray
            Box as (x, y, a, h): centre, aspect ratio and height.

        Returns
        -------
        (ndarray, ndarray)
            8-dimensional mean (velocities set to zero) and the 8x8
            diagonal covariance of the new track.
        """
        height = measurement[3]
        pos_std = 2 * self._std_weight_position * height
        vel_std = 10 * self._std_weight_velocity * height

        mean = np.r_[measurement, np.zeros_like(measurement)]
        std = [
            pos_std, pos_std, 1e-2, pos_std,
            vel_std, vel_std, 1e-5, vel_std,
        ]
        return mean, np.diag(np.square(std))

    def predict(self, mean, covariance):
        """Run the Kalman prediction step for one state.

        Parameters
        ----------
        mean : ndarray
            8-dimensional mean at the previous time step.
        covariance : ndarray
            8x8 covariance at the previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Mean and covariance of the predicted state.
        """
        height = mean[3]
        pos_std = self._std_weight_position * height
        vel_std = self._std_weight_velocity * height
        std_pos = [pos_std, pos_std, 1e-2, pos_std]
        std_vel = [vel_std, vel_std, 1e-5, vel_std]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        predicted_mean = np.dot(mean, self._motion_mat.T)
        predicted_cov = np.linalg.multi_dot(
            (self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
        return predicted_mean, predicted_cov

    def project(self, mean, covariance):
        """Project a state distribution into measurement space.

        Parameters
        ----------
        mean : ndarray
            8-dimensional state mean.
        covariance : ndarray
            8x8 state covariance.

        Returns
        -------
        (ndarray, ndarray)
            Projected mean and covariance, the latter including the
            measurement noise.
        """
        pos_std = self._std_weight_position * mean[3]
        innovation_cov = np.diag(
            np.square([pos_std, pos_std, 1e-1, pos_std]))

        projected_mean = np.dot(self._update_mat, mean)
        projected_cov = np.linalg.multi_dot(
            (self._update_mat, covariance, self._update_mat.T))
        return projected_mean, projected_cov + innovation_cov

    def multi_predict(self, mean, covariance):
        """Run the Kalman prediction step for N states at once.

        Parameters
        ----------
        mean : ndarray
            Nx8 matrix of state means at the previous time step.
        covariance : ndarray
            Nx8x8 stack of state covariances at the previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Predicted means (Nx8) and covariances (Nx8x8).
        """
        heights = mean[:, 3]
        pos_std = self._std_weight_position * heights
        vel_std = self._std_weight_velocity * heights
        std_pos = [pos_std, pos_std, 1e-2 * np.ones_like(heights), pos_std]
        std_vel = [vel_std, vel_std, 1e-5 * np.ones_like(heights), vel_std]
        # One row of eight variances per state.
        variances = np.square(np.r_[std_pos, std_vel]).T
        motion_cov = np.asarray([np.diag(row) for row in variances])

        predicted_mean = np.dot(mean, self._motion_mat.T)
        # dot() puts the batch axis second; move it back to the front.
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        predicted_cov = np.dot(left, self._motion_mat.T) + motion_cov
        return predicted_mean, predicted_cov

    def update(self, mean, covariance, measurement):
        """Run the Kalman correction step.

        Parameters
        ----------
        mean : ndarray
            Predicted 8-dimensional state mean.
        covariance : ndarray
            8x8 state covariance.
        measurement : ndarray
            Measured box (x, y, a, h): centre, aspect ratio and height.

        Returns
        -------
        (ndarray, ndarray)
            The measurement-corrected mean and covariance.
        """
        projected_mean, projected_cov = self.project(mean, covariance)

        # Solve for the gain through a Cholesky factorisation of the
        # projected covariance.
        factorisation = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            factorisation,
            np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        corrected_mean = mean + np.dot(innovation, kalman_gain.T)
        corrected_cov = covariance - np.linalg.multi_dot(
            (kalman_gain, projected_cov, kalman_gain.T))
        return corrected_mean, corrected_cov

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False, metric='maha'):
        """Distance between a state distribution and N measurements.

        A suitable threshold can be taken from `chi2inv95`: 4 degrees of
        freedom normally, 2 when `only_position` is set.

        Parameters
        ----------
        mean : ndarray
            8-dimensional state mean.
        covariance : ndarray
            8x8 state covariance.
        measurements : ndarray
            Nx4 matrix of measurements in (x, y, a, h) format.
        only_position : Optional[bool]
            If True, only the box centre (x, y) is compared.
        metric : str
            'maha' for the squared Mahalanobis distance, 'gaussian' for the
            squared Euclidean distance.

        Returns
        -------
        ndarray
            Array of length N with the distance to each measurement.

        Raises
        ------
        ValueError
            If `metric` is neither 'gaussian' nor 'maha'.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == 'gaussian':
            return np.sum(d * d, axis=1)
        if metric != 'maha':
            raise ValueError('invalid distance metric')

        cholesky_factor = np.linalg.cholesky(covariance)
        z = scipy.linalg.solve_triangular(
            cholesky_factor, d.T, lower=True, check_finite=False,
            overwrite_b=True)
        return np.sum(z * z, axis=0)
|
|
@ -0,0 +1,230 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2021 Yifu Zhang
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
import time
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy
|
||||
from cython_bbox import bbox_overlaps as bbox_ious
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from easycv.thirdparty.mot.bytetrack import kalman_filter
|
||||
|
||||
|
||||
def merge_matches(m1, m2, shape):
    """Compose two match lists that share a middle index set.

    Args:
        m1: pairs ``(o, p)`` matching the first set to the middle set.
        m2: pairs ``(p, q)`` matching the middle set to the last set.
        shape: ``(O, P, Q)``, the sizes of the three index sets.

    Returns:
        ``(match, unmatched_O, unmatched_Q)``: the list of composed
        ``(o, q)`` pairs, and tuples of the indices of the first and last
        set that take part in no composed pair.
    """
    O, P, Q = shape
    # reshape keeps the column indexing below valid for empty match lists,
    # which would otherwise be one-dimensional.
    m1 = np.asarray(m1, dtype=int).reshape(-1, 2)
    m2 = np.asarray(m2, dtype=int).reshape(-1, 2)

    M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])),
                                 shape=(O, P))
    M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])),
                                 shape=(P, Q))

    # Non-zero entries of the product are the (o, q) pairs linked through
    # a common middle index.
    mask = M1 @ M2
    rows, cols = mask.nonzero()
    match = list(zip(rows, cols))
    unmatched_O = tuple(set(range(O)) - {i for i, _ in match})
    unmatched_Q = tuple(set(range(Q)) - {j for _, j in match})

    return match, unmatched_O, unmatched_Q
|
||||
|
||||
|
||||
def _indices_to_matches(cost_matrix, indices, thresh):
|
||||
matched_cost = cost_matrix[tuple(zip(*indices))]
|
||||
matched_mask = (matched_cost <= thresh)
|
||||
|
||||
matches = indices[matched_mask]
|
||||
unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
|
||||
unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
|
||||
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh):
    """Solve the assignment problem on ``cost_matrix`` with a cost limit.

    Args:
        cost_matrix: 2-D array of assignment costs.
        thresh: passed to ``lap.lapjv`` as ``cost_limit``.

    Returns:
        ``(matches, unmatched_a, unmatched_b)``: an ``(M, 2)`` integer array
        of ``(row, col)`` pairs, plus the unassigned row and column indices.
        For an empty cost matrix the unmatched indices are tuples covering
        every row and column; otherwise they are index arrays.
    """
    if cost_matrix.size == 0:
        return (np.empty((0, 2), dtype=int),
                tuple(range(cost_matrix.shape[0])),
                tuple(range(cost_matrix.shape[1])))

    # Imported here, after the empty-matrix early return.
    import lap
    _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)

    # x[i] is the column assigned to row i, negative when unassigned.
    # reshape keeps the (M, 2) layout even when nothing was matched, in
    # line with the empty-matrix result above.
    matches = np.asarray(
        [[ix, mx] for ix, mx in enumerate(x) if mx >= 0],
        dtype=int).reshape(-1, 2)
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def ious(atlbrs, btlbrs):
    """
    Compute pairwise IoU between two sets of boxes
    :type atlbrs: list[tlbr] | np.ndarray
    :type btlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray of shape (len(atlbrs), len(btlbrs))
    """
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype
    # that alias used to resolve to.
    overlaps = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float64)
    if overlaps.size == 0:
        # Nothing to compare: return the empty matrix without calling the
        # compiled overlap routine.
        return overlaps

    return bbox_ious(
        np.ascontiguousarray(atlbrs, dtype=np.float64),
        np.ascontiguousarray(btlbrs, dtype=np.float64))
|
||||
|
||||
|
||||
def iou_distance(atracks, btracks):
    """
    IoU-based cost between two groups of tracks (``1 - IoU``).

    If the first element of either argument is an ``np.ndarray`` both
    arguments are used as tlbr boxes directly; otherwise each track's
    ``tlbr`` is read.
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """
    given_raw_boxes = any(
        len(group) > 0 and isinstance(group[0], np.ndarray)
        for group in (atracks, btracks))
    if given_raw_boxes:
        atlbrs, btlbrs = atracks, btracks
    else:
        atlbrs = [track.tlbr for track in atracks]
        btlbrs = [track.tlbr for track in btracks]

    return 1 - ious(atlbrs, btlbrs)
|
||||
|
||||
|
||||
def v_iou_distance(atracks, btracks):
    """
    IoU-based cost (``1 - IoU``) computed from each track's ``pred_bbox``.

    If the first element of either argument is an ``np.ndarray`` both
    arguments are used as tlbr boxes directly; otherwise each track's
    ``pred_bbox`` is converted with ``tlwh_to_tlbr``.
    :type atracks: list[STrack]
    :type btracks: list[STrack]

    :rtype cost_matrix np.ndarray
    """
    given_raw_boxes = any(
        len(group) > 0 and isinstance(group[0], np.ndarray)
        for group in (atracks, btracks))
    if given_raw_boxes:
        atlbrs, btlbrs = atracks, btracks
    else:
        atlbrs = [t.tlwh_to_tlbr(t.pred_bbox) for t in atracks]
        btlbrs = [t.tlwh_to_tlbr(t.pred_bbox) for t in btracks]

    return 1 - ious(atlbrs, btlbrs)
|
||||
|
||||
|
||||
def embedding_distance(tracks, detections, metric='cosine'):
    """
    Pairwise feature distance between tracks and detections.

    :param tracks: list[STrack], each providing ``smooth_feat``
    :param detections: list[BaseTrack], each providing ``curr_feat``
    :param metric: metric name forwarded to ``scipy.spatial.distance.cdist``
    :return: cost_matrix np.ndarray of shape (len(tracks), len(detections)),
        with negative distances clipped to 0
    """
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype
    # that alias used to resolve to.
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)
    if cost_matrix.size == 0:
        return cost_matrix

    det_features = np.asarray([det.curr_feat for det in detections],
                              dtype=np.float64)
    track_features = np.asarray([track.smooth_feat for track in tracks],
                                dtype=np.float64)
    return np.maximum(0.0, cdist(track_features, det_features, metric))
|
||||
|
||||
|
||||
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
    """Set entries of ``cost_matrix`` to ``np.inf`` where the Kalman gating
    distance exceeds the threshold.

    The threshold is ``kalman_filter.chi2inv95[d]`` with ``d = 2`` when
    ``only_position`` is true and ``d = 4`` otherwise. ``cost_matrix`` is
    modified in place and returned; an empty matrix is returned untouched.
    """
    if cost_matrix.size == 0:
        return cost_matrix

    threshold = kalman_filter.chi2inv95[2 if only_position else 4]
    xyah = np.asarray([det.to_xyah() for det in detections])

    for idx, trk in enumerate(tracks):
        distance = kf.gating_distance(trk.mean, trk.covariance, xyah,
                                      only_position)
        too_far = distance > threshold
        cost_matrix[idx, too_far] = np.inf

    return cost_matrix
|
||||
|
||||
|
||||
def fuse_motion(kf,
                cost_matrix,
                tracks,
                detections,
                only_position=False,
                lambda_=0.98):
    """Gate ``cost_matrix`` with the Kalman distance and blend both costs.

    For each track row, entries whose gating distance (requested with
    ``metric='maha'``) exceeds ``kalman_filter.chi2inv95[d]`` are set to
    ``np.inf``; afterwards the row becomes
    ``lambda_ * cost + (1 - lambda_) * gating_distance``.
    ``d`` is 2 when ``only_position`` is true, else 4. The matrix is updated
    in place and returned; an empty matrix is returned untouched.
    """
    if cost_matrix.size == 0:
        return cost_matrix

    dims = 2 if only_position else 4
    limit = kalman_filter.chi2inv95[dims]
    xyah = np.asarray([det.to_xyah() for det in detections])

    for idx, trk in enumerate(tracks):
        maha = kf.gating_distance(
            trk.mean, trk.covariance, xyah, only_position, metric='maha')
        cost_matrix[idx, maha > limit] = np.inf
        blended = lambda_ * cost_matrix[idx] + (1 - lambda_) * maha
        cost_matrix[idx] = blended

    return cost_matrix
|
||||
|
||||
|
||||
def fuse_iou(cost_matrix, tracks, detections):
    """Fuse a re-identification cost matrix with IoU similarity.

    The fused similarity is ``reid_sim * (1 + iou_sim) / 2`` where
    ``reid_sim = 1 - cost_matrix`` and ``iou_sim`` is one minus
    ``iou_distance(tracks, detections)``.

    :param cost_matrix: np.ndarray of shape (len(tracks), len(detections))
    :param tracks: tracks forwarded to ``iou_distance``
    :param detections: detections forwarded to ``iou_distance``
    :return: fused cost matrix (``1 - fused similarity``); an empty
        ``cost_matrix`` is returned unchanged
    """
    if cost_matrix.size == 0:
        return cost_matrix

    reid_sim = 1 - cost_matrix
    iou_sim = 1 - iou_distance(tracks, detections)
    # Detection scores used to be gathered and tiled here but never entered
    # the result, so that dead computation has been dropped.
    fuse_sim = reid_sim * (1 + iou_sim) / 2
    return 1 - fuse_sim
|
||||
|
||||
|
||||
def fuse_score(cost_matrix, detections):
    """Scale the similarity ``1 - cost_matrix`` by each detection's score.

    Every column is multiplied by the ``score`` attribute of the matching
    detection and the result is turned back into a cost (``1 - similarity``).
    An empty ``cost_matrix`` is returned unchanged.
    """
    if cost_matrix.size == 0:
        return cost_matrix

    n_rows = cost_matrix.shape[0]
    scores = np.array([det.score for det in detections])
    # One identical row of scores per track row.
    score_grid = scores[np.newaxis].repeat(n_rows, axis=0)

    similarity = (1 - cost_matrix) * score_grid
    return 1 - similarity
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import mmcv
|
||||
|
||||
from easycv.predictors import DetectionPredictor
|
||||
from easycv.thirdparty.mot.bytetrack.byte_tracker import BYTETracker
|
||||
from easycv.thirdparty.mot.utils import detection_result_filter, show_result
|
||||
|
||||
|
||||
def main():
    """Run detection plus ByteTrack on a video or an image folder.

    The frames are read from ``--input`` (a directory of numerically named
    ``.jpg``/``.png``/``.jpeg`` files, or anything ``mmcv.VideoReader`` can
    open), every frame is passed through ``DetectionPredictor`` and
    ``BYTETracker`` and the result is shown (``--show``) and/or written to
    ``--output`` (an ``.mp4`` file or a folder of images).
    """
    parser = ArgumentParser()
    parser.add_argument('config', help='config file')
    parser.add_argument(
        '--input', required=True, help='input video file or folder')
    parser.add_argument(
        '--output', help='output video file (mp4 format) or folder')
    parser.add_argument('--checkpoint', help='checkpoint file')
    parser.add_argument(
        '--score-thr',
        type=float,
        default=0.0,
        help='The threshold of score to filter bboxes.')
    # NOTE(review): --device is parsed but never forwarded to the predictor
    # in this script.
    parser.add_argument(
        '--device', default='cuda:0', help='device used for inference')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether show the results on the fly')
    # Parsed as float so values such as 29.97 do not break int() below.
    parser.add_argument('--fps', type=float, help='FPS of the output video')
    args = parser.parse_args()
    if not (args.output or args.show):
        parser.error('at least one of --output and --show is required')

    # load images
    if osp.isdir(args.input):
        imgs = sorted(
            filter(lambda x: x.endswith(('.jpg', '.png', '.jpeg')),
                   os.listdir(args.input)),
            key=lambda x: int(x.split('.')[0]))
        in_video = False
    else:
        imgs = mmcv.VideoReader(args.input)
        in_video = True

    # define output; the defaults keep every name bound when only --show
    # is used
    out_video = False
    out_dir = None
    out_path = None
    if args.output is not None:
        if args.output.endswith('.mp4'):
            out_video = True
            # frames are first dumped to a temporary folder, then encoded
            out_dir = tempfile.TemporaryDirectory()
            out_path = out_dir.name
            parent = osp.dirname(args.output)
            if parent:
                os.makedirs(parent, exist_ok=True)
        else:
            out_path = args.output
            os.makedirs(out_path, exist_ok=True)

    try:
        fps = args.fps
        if args.show or out_video:
            if fps is None and in_video:
                fps = imgs.fps
            if not fps:
                raise ValueError('Please set the FPS for the output video.')
            fps = int(fps)

        # build the model from a config file and a checkpoint file
        model = DetectionPredictor(
            args.checkpoint, args.config, score_threshold=0)
        tracker = BYTETracker(
            det_high_thresh=0.2,
            det_low_thresh=0.05,
            match_thresh=1.0,
            match_thresh_second=1.0,
            match_thresh_init=1.0,
            track_buffer=2,
            frame_rate=25)

        prog_bar = mmcv.ProgressBar(len(imgs))

        # test and show/save the images
        # Start from an empty result: show_result requires a dict, so frames
        # that come before the first detection can still be rendered.
        track_result = {'track_bboxes': []}
        for idx, img in enumerate(imgs):
            if isinstance(img, str):
                img = osp.join(args.input, img)
            result = model(img)[0]

            detection_boxes, detection_scores, detection_classes = \
                detection_result_filter(
                    result['detection_boxes'],
                    result['detection_scores'],
                    result['detection_classes'],
                    target_classes=[0],
                    target_thresholds=[0])
            # The tracker is only updated when something was detected; on
            # other frames the previous track result is drawn again.
            if len(detection_boxes) > 0:
                track_result = tracker.update(detection_boxes,
                                              detection_scores,
                                              detection_classes)

            if args.output is None:
                out_file = None
            elif in_video or out_video:
                out_file = osp.join(out_path, f'{idx:06d}.jpg')
            else:
                out_file = osp.join(out_path, osp.basename(img))

            show_result(
                img,
                track_result,
                score_thr=args.score_thr,
                show=args.show,
                wait_time=int(1000. / fps) if fps else 0,
                out_file=out_file)
            prog_bar.update()

        if out_video:
            print(
                f'making the output video at {args.output} with a FPS of {fps}'
            )
            mmcv.frames2video(out_path, args.output, fps=fps, fourcc='mp4v')
    finally:
        # remove the temporary frame folder even when processing failed
        if out_dir is not None:
            out_dir.cleanup()
|
||||
|
||||
|
||||
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
|
||||
def detection_result_filter(bboxes,
                            scores,
                            classes,
                            target_classes,
                            target_thresholds=None):
    """Keep the detections of selected classes that pass a score threshold.

    Args:
        bboxes: per-detection boxes, indexable by detection index.
        scores: per-detection scores, indexable by detection index.
        classes: iterable with the class of every detection.
        target_classes (list): classes to keep.
        target_thresholds (list | None): one threshold per entry of
            ``target_classes``; a detection is kept when its score is
            strictly greater than the threshold of its class. ``None``
            disables the score check.

    Returns:
        tuple: ``(bboxes, scores, classes)`` of the kept detections, each
        converted with ``np.array``.

    Raises:
        ValueError: if ``target_classes`` and ``target_thresholds`` differ
            in length.
    """
    if target_thresholds is None:
        # The default used to crash on len(None); treat it as "no threshold".
        target_thresholds = [float('-inf')] * len(target_classes)
    if len(target_classes) != len(target_thresholds):
        raise ValueError(
            'detection post process, class filter need target_classes and target_thresholds both, and should be same length!'
        )

    # post process to filter result
    keep = [
        bidx for bidx, bcls in enumerate(classes)
        if bcls in target_classes
        and scores[bidx] > target_thresholds[target_classes.index(bcls)]
    ]
    bboxes = np.array([bboxes[bidx] for bidx in keep])
    scores = np.array([scores[bidx] for bidx in keep])
    classes = np.array([classes[bidx] for bidx in keep])
    return bboxes, scores, classes
|
||||
|
||||
|
||||
def results2outs(bbox_results=None, **kwargs):
    """Split a result array into boxes and, when present, track ids.

    Args:
        bbox_results (np.ndarray | None): 2-D array with 5 columns (boxes
            only) or 6 columns (track id in column 0 followed by the 5 box
            columns). ``None`` or an empty input yields an empty dict.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        dict: It may contain keys as belows:

        - bboxes (np.ndarray): shape (n, 5)
        - ids (np.ndarray): shape (n, )

    Raises:
        NotImplementedError: if the second dimension is neither 5 nor 6.
    """
    outputs = dict()

    # ``None`` is the declared default, so it must not reach len().
    if bbox_results is None or len(bbox_results) == 0:
        return outputs

    # No copy is made when an ndarray is passed in.
    bboxes = np.asarray(bbox_results)
    num_cols = bboxes.shape[1]
    if num_cols == 5:
        outputs['bboxes'] = bboxes
    elif num_cols == 6:
        outputs['ids'] = bboxes[:, 0].astype(np.int64)
        outputs['bboxes'] = bboxes[:, 1:]
    else:
        raise NotImplementedError(
            f'Not supported bbox shape: (N, {bboxes.shape[1]})')

    return outputs
|
||||
|
||||
|
||||
def random_color(seed):
    """Pick a color from seaborn's default palette, fixed by ``seed``.

    Args:
        seed: value used to seed the random choice; the same seed always
            yields the same palette entry.

    Returns:
        One entry of ``seaborn.color_palette()``.
    """
    # seaborn is imported lazily so it is only needed when colors are drawn.
    import seaborn as sns

    # Python >= 3.11 only accepts None, int, float, str, bytes and bytearray
    # as seeds, so integer-like objects such as NumPy ints are converted.
    if hasattr(seed, '__index__'):
        seed = int(seed)
    # A private generator gives the same choice as seeding the module-level
    # one, without resetting the process-wide random state on every call.
    rng = random.Random(seed)
    return rng.choice(sns.color_palette())
|
||||
|
||||
|
||||
def imshow_tracks(img,
                  bboxes,
                  ids,
                  classes=None,
                  score_thr=0.0,
                  thickness=2,
                  font_scale=0.4,
                  show=False,
                  wait_time=0,
                  out_file=None):
    """Show the tracks with opencv.

    Args:
        img (str | ndarray): image path (loaded with ``mmcv.imread``) or an
            image array; an array is drawn on in place.
        bboxes (ndarray | None): shape (n, 5); the first four columns are
            used as x1, y1, x2, y2 and the last one as score.
        ids (ndarray | None): shape (n, ), one id per box.
        classes: accepted but not used.
        score_thr (float): only boxes with a score above it are drawn.
        thickness (int): rectangle line thickness.
        font_scale (float): font scale of the score / id labels.
        show (bool): display the image with ``mmcv.imshow``.
        wait_time (int): forwarded to ``mmcv.imshow``.
        out_file (str | None): if given, the image is written there.

    Returns:
        ndarray: the image with the tracks drawn.
    """
    if isinstance(img, str):
        img = mmcv.imread(img)
    if bboxes is not None and ids is not None:
        assert bboxes.ndim == 2
        assert ids.ndim == 1
        assert bboxes.shape[1] == 5

        img_h, img_w = img.shape[:2]
        # Clip a copy so the caller's array is left untouched, and restrict
        # the clipping to the coordinate columns (``0::2`` also hit the
        # score column).
        bboxes = bboxes.copy()
        bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, img_w)
        bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, img_h)

        keep = np.where(bboxes[:, -1] > score_thr)[0]
        bboxes = bboxes[keep]
        ids = ids[keep]

        text_width, text_height = 9, 13
        for bbox, track_id in zip(bboxes, ids):
            # plain Python ints for the OpenCV point arguments
            x1, y1, x2, y2 = bbox[:4].astype(np.int32).tolist()
            score = float(bbox[-1])

            # bbox
            # random.seed() rejects NumPy integers on Python >= 3.11.
            seed = int(track_id) if isinstance(track_id,
                                               np.integer) else track_id
            bbox_color = random_color(seed)
            # scale the palette floats to 0-255 and reverse the channel
            # order (presumably RGB -> BGR for OpenCV)
            bbox_color = [int(255 * _c) for _c in bbox_color][::-1]
            cv2.rectangle(
                img, (x1, y1), (x2, y2), bbox_color, thickness=thickness)

            # score
            text = '{:.02f}'.format(score)
            width = len(text) * text_width
            img[y1:y1 + text_height, x1:x1 + width, :] = bbox_color
            cv2.putText(
                img,
                text, (x1, y1 + text_height - 2),
                cv2.FONT_HERSHEY_COMPLEX,
                font_scale,
                color=(0, 0, 0))

            # id
            text = str(track_id)
            width = len(text) * text_width
            img[y1 + text_height:y1 + 2 * text_height,
                x1:x1 + width, :] = bbox_color
            cv2.putText(
                img,
                text, (x1, y1 + 2 * text_height - 2),
                cv2.FONT_HERSHEY_COMPLEX,
                font_scale,
                color=(0, 0, 0))

    if show:
        mmcv.imshow(img, wait_time=wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
|
||||
|
||||
|
||||
def show_result(img,
                result,
                score_thr=0.0,
                thickness=1,
                font_scale=0.5,
                show=False,
                out_file=None,
                wait_time=0,
                **kwargs):
    """Visualize tracking results.

    Args:
        img (str | ndarray): Filename of loaded image.
        result (dict): Tracking result. Only the key 'track_bboxes' is
            read: an ndarray with shape (n, 6) in
            [id, tl_x, tl_y, br_x, br_y, score] format. A missing key or
            an empty value draws nothing.
        score_thr (float, optional): Boxes with a score not above this
            value are not drawn. Defaults to 0.0.
        thickness (int, optional): Thickness of lines. Defaults to 1.
        font_scale (float, optional): Font scales of texts. Defaults
            to 0.5.
        show (bool, optional): Whether show the visualizations on the
            fly. Defaults to False.
        out_file (str | None, optional): Output filename. Defaults to None.
        wait_time (int, optional): Forwarded to the image display call.
            Defaults to 0.
        **kwargs: accepted and ignored.

    Returns:
        ndarray: Visualized image.
    """
    assert isinstance(result, dict)
    track_bboxes = result.get('track_bboxes', None)
    if isinstance(img, str):
        img = mmcv.imread(img)
    # A result without 'track_bboxes' has nothing to draw; the frame is
    # still shown / written below.
    if track_bboxes is None:
        outs_track = {}
    else:
        outs_track = results2outs(bbox_results=track_bboxes)
    img = imshow_tracks(
        img,
        outs_track.get('bboxes', None),
        outs_track.get('ids', None),
        score_thr=score_thr,
        thickness=thickness,
        font_scale=font_scale,
        show=show,
        out_file=out_file,
        wait_time=wait_time)
    return img
|
|
@ -3,8 +3,9 @@ http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/panopt
|
|||
http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/blade_compression-0.0.2-py3-none-any.whl
|
||||
https://developer.download.nvidia.com/compute/redist/nvidia-dali-cuda100/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
|
||||
|
||||
# detection3d
|
||||
lap
|
||||
nuscenes-devkit
|
||||
open3d
|
||||
pyquaternion
|
||||
seaborn
|
||||
trimesh
|
||||
|
|
|
@ -55,6 +55,58 @@ class FCOSTest(unittest.TestCase):
|
|||
[189.96198, 108.948654, 297.10025, 154.80592]]),
|
||||
decimal=1)
|
||||
|
||||
@unittest.skip('skip bytetrack unittest')
def test_bytetrack(self):
    """Track the detections of one demo image and compare the resulting
    ``track_bboxes`` with reference values (one decimal of precision)."""
    from easycv.thirdparty.mot.bytetrack.byte_tracker import BYTETracker

    model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/fcos_epoch_12.pth'
    config_path = 'configs/detection/fcos/fcos_r50_torch_1x_coco.py'
    img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'

    # Reference rows, six values each, as returned under 'track_bboxes'.
    expected_tracks = np.array([
        [1.0, 294.0, 116.0, 378.0, 149.0, 0.714209914],
        [2.0, 480.0, 110.0, 523.0, 130.0, 0.616470039],
        [3.0, 398.0, 110.0, 433.0, 133.0, 0.585758626],
        [4.0, 608.0, 111.0, 636.0, 137.0, 0.583925486],
        [5.0, 591.0, 109.0, 619.0, 126.0, 0.537827313],
        [6.0, 431.0, 104.0, 482.0, 131.0, 0.512700200],
        [7.0, 189.0, 108.0, 297.0, 154.0, 0.507710576],
    ])

    predictor = DetectionPredictor(model_path, config_path)
    detections = predictor(img)[0]

    tracker_kwargs = dict(
        det_high_thresh=0.2,
        det_low_thresh=0.05,
        match_thresh=1.0,
        match_thresh_second=1.0,
        match_thresh_init=1.0,
        track_buffer=2,
        frame_rate=25)
    tracker = BYTETracker(**tracker_kwargs)
    tracked = tracker.update(detections['detection_boxes'],
                             detections['detection_scores'],
                             detections['detection_classes'])

    assert_array_almost_equal(
        tracked['track_bboxes'], expected_tracks, decimal=1)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
||||
|
|
Loading…
Reference in New Issue