mirror of https://github.com/JDAI-CV/fast-reid.git
replace list in evaluator process with dict
parent
44ad4b83b1
commit
68c190b53c
|
@ -5,6 +5,7 @@
|
|||
"""
|
||||
import copy
|
||||
import logging
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
@ -28,50 +29,54 @@ class ReidEvaluator(DatasetEvaluator):
|
|||
self._num_query = num_query
|
||||
self._output_dir = output_dir
|
||||
|
||||
self.features = []
|
||||
self.pids = []
|
||||
self.camids = []
|
||||
self._cpu_device = torch.device('cpu')
|
||||
|
||||
self._predictions = []
|
||||
|
||||
def reset(self):
|
||||
self.features = []
|
||||
self.pids = []
|
||||
self.camids = []
|
||||
self._predictions = []
|
||||
|
||||
def process(self, inputs, outputs):
|
||||
self.pids.extend(inputs["targets"])
|
||||
self.camids.extend(inputs["camids"])
|
||||
self.features.append(outputs.cpu())
|
||||
prediction = {
|
||||
'feats': outputs.to(self._cpu_device, torch.float32),
|
||||
'pids': inputs['targets'].to(self._cpu_device),
|
||||
'camids': inputs['camids'].to(self._cpu_device)
|
||||
|
||||
}
|
||||
self._predictions.append(prediction)
|
||||
|
||||
def evaluate(self):
|
||||
if comm.get_world_size() > 1:
|
||||
comm.synchronize()
|
||||
features = comm.gather(self.features)
|
||||
features = sum(features, [])
|
||||
predictions = comm.gather(self._predictions, dst=0)
|
||||
predictions = list(itertools.chain(*predictions))
|
||||
|
||||
pids = comm.gather(self.pids)
|
||||
pids = sum(pids, [])
|
||||
if not comm.is_main_process():
|
||||
return {}
|
||||
|
||||
camids = comm.gather(self.camids)
|
||||
camids = sum(camids, [])
|
||||
|
||||
# fmt: off
|
||||
if not comm.is_main_process(): return {}
|
||||
# fmt: on
|
||||
else:
|
||||
features = self.features
|
||||
pids = self.pids
|
||||
camids = self.camids
|
||||
predictions = self._predictions
|
||||
|
||||
features = []
|
||||
pids = []
|
||||
camids = []
|
||||
for prediction in predictions:
|
||||
features.append(prediction['feats'])
|
||||
pids.append(prediction['pids'])
|
||||
camids.append(prediction['camids'])
|
||||
|
||||
features = torch.cat(features, dim=0)
|
||||
pids = torch.cat(pids, dim=0).numpy()
|
||||
camids = torch.cat(camids, dim=0).numpy()
|
||||
# query feature, person ids and camera ids
|
||||
query_features = features[:self._num_query]
|
||||
query_pids = np.asarray(pids[:self._num_query])
|
||||
query_camids = np.asarray(camids[:self._num_query])
|
||||
query_pids = pids[:self._num_query]
|
||||
query_camids = camids[:self._num_query]
|
||||
|
||||
# gallery features, person ids and camera ids
|
||||
gallery_features = features[self._num_query:]
|
||||
gallery_pids = np.asarray(pids[self._num_query:])
|
||||
gallery_camids = np.asarray(camids[self._num_query:])
|
||||
gallery_pids = pids[self._num_query:]
|
||||
gallery_camids = camids[self._num_query:]
|
||||
|
||||
self._results = OrderedDict()
|
||||
|
||||
|
|
Loading…
Reference in New Issue