diff --git a/fastreid/evaluation/clas_evaluator.py b/fastreid/evaluation/clas_evaluator.py
index f40cc8e..763616a 100644
--- a/fastreid/evaluation/clas_evaluator.py
+++ b/fastreid/evaluation/clas_evaluator.py
@@ -46,11 +46,14 @@ class ClasEvaluator(DatasetEvaluator):
         self._predictions = []
 
     def process(self, inputs, outputs):
-        predictions = {
-            "logits": outputs.to(self._cpu_device, torch.float32),
-            "labels": inputs["targets"].to(self._cpu_device),
-        }
-        self._predictions.append(predictions)
+        pred_logits = outputs.to(self._cpu_device, torch.float32)
+        labels = inputs["targets"].to(self._cpu_device)
+
+        # measure accuracy
+        acc1, = accuracy(pred_logits, labels, topk=(1,))
+        num_correct_acc1 = acc1 * labels.size(0) / 100
+
+        self._predictions.append({"num_correct": num_correct_acc1, "num_samples": labels.size(0)})
 
     def evaluate(self):
         if comm.get_world_size() > 1:
@@ -63,21 +66,16 @@ class ClasEvaluator(DatasetEvaluator):
 
         else:
             predictions = self._predictions
 
-        pred_logits = []
-        labels = []
+        total_correct_num = 0
+        total_samples = 0
         for prediction in predictions:
-            pred_logits.append(prediction['logits'])
-            labels.append(prediction['labels'])
+            total_correct_num += prediction["num_correct"]
+            total_samples += prediction["num_samples"]
 
-        pred_logits = torch.cat(pred_logits, dim=0)
-        labels = torch.cat(labels, dim=0)
-
-        # measure accuracy and record loss
-        acc1, = accuracy(pred_logits, labels, topk=(1,))
+        acc1 = total_correct_num / total_samples * 100
 
         self._results = OrderedDict()
         self._results["Acc@1"] = acc1
-        self._results["metric"] = acc1
 
         return copy.deepcopy(self._results)
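
For context, the following is a minimal, self-contained sketch (not part of the patch) of the accumulation scheme the diff introduces: each process() call converts the per-batch top-1 percentage back into a raw correct count, and evaluate() recovers the global accuracy from the running totals, so the full set of logits never has to be kept in memory or concatenated. The local accuracy() helper here is only a simplified stand-in for fastreid's utility of the same name.

import torch

def accuracy(logits, labels, topk=(1,)):
    # Simplified stand-in for fastreid's accuracy(): returns percentages per k.
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1)                 # (batch, maxk) predicted classes
    correct = pred.t().eq(labels.view(1, -1))          # (maxk, batch) hit matrix
    return [correct[:k].any(dim=0).float().sum() * 100.0 / labels.size(0) for k in topk]

predictions = []
for _ in range(3):  # three dummy batches standing in for process() calls
    logits = torch.randn(8, 5)
    labels = torch.randint(0, 5, (8,))
    acc1, = accuracy(logits, labels, topk=(1,))
    # acc1 is a percentage; batch_size * acc1 / 100 recovers the raw correct count
    predictions.append({"num_correct": acc1 * labels.size(0) / 100,
                        "num_samples": labels.size(0)})

# Aggregation as done in evaluate(): sum counts, then convert back to a percentage.
total_correct = sum(p["num_correct"] for p in predictions)
total_samples = sum(p["num_samples"] for p in predictions)
print("Acc@1:", (total_correct / total_samples * 100).item())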