From 68c190b53cc2228e8396a6e06d86e7205db260a6 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Tue, 9 Mar 2021 20:07:13 +0800 Subject: [PATCH] replace list in evaluator process with dict --- fastreid/evaluation/reid_evaluation.py | 57 ++++++++++++++------------ 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py index fa2b4d9..0ca6ec0 100644 --- a/fastreid/evaluation/reid_evaluation.py +++ b/fastreid/evaluation/reid_evaluation.py @@ -5,6 +5,7 @@ """ import copy import logging +import itertools from collections import OrderedDict import numpy as np @@ -28,50 +29,54 @@ class ReidEvaluator(DatasetEvaluator): self._num_query = num_query self._output_dir = output_dir - self.features = [] - self.pids = [] - self.camids = [] + self._cpu_device = torch.device('cpu') + + self._predictions = [] def reset(self): - self.features = [] - self.pids = [] - self.camids = [] + self._predictions = [] def process(self, inputs, outputs): - self.pids.extend(inputs["targets"]) - self.camids.extend(inputs["camids"]) - self.features.append(outputs.cpu()) + prediction = { + 'feats': outputs.to(self._cpu_device, torch.float32), + 'pids': inputs['targets'].to(self._cpu_device), + 'camids': inputs['camids'].to(self._cpu_device) + + } + self._predictions.append(prediction) def evaluate(self): if comm.get_world_size() > 1: comm.synchronize() - features = comm.gather(self.features) - features = sum(features, []) + predictions = comm.gather(self._predictions, dst=0) + predictions = list(itertools.chain(*predictions)) - pids = comm.gather(self.pids) - pids = sum(pids, []) + if not comm.is_main_process(): + return {} - camids = comm.gather(self.camids) - camids = sum(camids, []) - - # fmt: off - if not comm.is_main_process(): return {} - # fmt: on else: - features = self.features - pids = self.pids - camids = self.camids + predictions = self._predictions + + features = [] + pids = [] + camids = [] + for prediction in predictions: + features.append(prediction['feats']) + pids.append(prediction['pids']) + camids.append(prediction['camids']) features = torch.cat(features, dim=0) + pids = torch.cat(pids, dim=0).numpy() + camids = torch.cat(camids, dim=0).numpy() # query feature, person ids and camera ids query_features = features[:self._num_query] - query_pids = np.asarray(pids[:self._num_query]) - query_camids = np.asarray(camids[:self._num_query]) + query_pids = pids[:self._num_query] + query_camids = camids[:self._num_query] # gallery features, person ids and camera ids gallery_features = features[self._num_query:] - gallery_pids = np.asarray(pids[self._num_query:]) - gallery_camids = np.asarray(camids[self._num_query:]) + gallery_pids = pids[self._num_query:] + gallery_camids = camids[self._num_query:] self._results = OrderedDict()