replace list in evaluator process with dict

pull/425/head
liaoxingyu 2021-03-09 20:07:13 +08:00
parent 44ad4b83b1
commit 68c190b53c
1 changed files with 31 additions and 26 deletions

View File

@ -5,6 +5,7 @@
"""
import copy
import logging
import itertools
from collections import OrderedDict
import numpy as np
@ -28,50 +29,54 @@ class ReidEvaluator(DatasetEvaluator):
self._num_query = num_query
self._output_dir = output_dir
self.features = []
self.pids = []
self.camids = []
self._cpu_device = torch.device('cpu')
self._predictions = []
def reset(self):
self.features = []
self.pids = []
self.camids = []
self._predictions = []
def process(self, inputs, outputs):
self.pids.extend(inputs["targets"])
self.camids.extend(inputs["camids"])
self.features.append(outputs.cpu())
prediction = {
'feats': outputs.to(self._cpu_device, torch.float32),
'pids': inputs['targets'].to(self._cpu_device),
'camids': inputs['camids'].to(self._cpu_device)
}
self._predictions.append(prediction)
def evaluate(self):
if comm.get_world_size() > 1:
comm.synchronize()
features = comm.gather(self.features)
features = sum(features, [])
predictions = comm.gather(self._predictions, dst=0)
predictions = list(itertools.chain(*predictions))
pids = comm.gather(self.pids)
pids = sum(pids, [])
if not comm.is_main_process():
return {}
camids = comm.gather(self.camids)
camids = sum(camids, [])
# fmt: off
if not comm.is_main_process(): return {}
# fmt: on
else:
features = self.features
pids = self.pids
camids = self.camids
predictions = self._predictions
features = []
pids = []
camids = []
for prediction in predictions:
features.append(prediction['feats'])
pids.append(prediction['pids'])
camids.append(prediction['camids'])
features = torch.cat(features, dim=0)
pids = torch.cat(pids, dim=0).numpy()
camids = torch.cat(camids, dim=0).numpy()
# query feature, person ids and camera ids
query_features = features[:self._num_query]
query_pids = np.asarray(pids[:self._num_query])
query_camids = np.asarray(camids[:self._num_query])
query_pids = pids[:self._num_query]
query_camids = camids[:self._num_query]
# gallery features, person ids and camera ids
gallery_features = features[self._num_query:]
gallery_pids = np.asarray(pids[self._num_query:])
gallery_camids = np.asarray(camids[self._num_query:])
gallery_pids = pids[self._num_query:]
gallery_camids = camids[self._num_query:]
self._results = OrderedDict()