From 1feda07ce6ede45d3d560ad268da7333bbf6f2e2 Mon Sep 17 00:00:00 2001
From: "zuchen.wang"
Date: Sat, 16 Oct 2021 21:24:16 +0800
Subject: [PATCH] evaluation success

---
 fastreid/engine/hooks.py              |   6 ++
 fastreid/evaluation/pair_evaluator.py | 109 ++++++++++++++++++++++++++
 fastreid/evaluation/testing.py        |  21 ++++-
 projects/FastShoe/fastshoe/trainer.py |   4 +-
 4 files changed, 135 insertions(+), 5 deletions(-)
 create mode 100644 fastreid/evaluation/pair_evaluator.py

diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py
index 0de4a4a..fecfe46 100644
--- a/fastreid/engine/hooks.py
+++ b/fastreid/engine/hooks.py
@@ -9,6 +9,7 @@ import tempfile
 import time
 from collections import Counter
 
+import numpy as np
 import torch
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel
@@ -355,6 +356,11 @@ class EvalHook(HookBase):
                 results, dict
             ), "Eval function must return a dict. Got {} instead.".format(results)
 
+            # drop list / np.ndarray values from results: only scalar metrics can be logged
+            for k, v in list(results.items()):
+                if isinstance(v, (list, np.ndarray)):
+                    results.pop(k)
+
             flattened_results = flatten_results_dict(results)
             for k, v in flattened_results.items():
                 try:
diff --git a/fastreid/evaluation/pair_evaluator.py b/fastreid/evaluation/pair_evaluator.py
new file mode 100644
index 0000000..f90b62f
--- /dev/null
+++ b/fastreid/evaluation/pair_evaluator.py
@@ -0,0 +1,109 @@
+# coding: utf-8
+
+import copy
+import itertools
+import logging
+from collections import OrderedDict
+
+import numpy as np
+import torch
+from fastreid.utils import comm
+from sklearn import metrics as skmetrics
+
+from .clas_evaluator import ClasEvaluator
+
+logger = logging.getLogger(__name__)
+
+
+class PairEvaluator(ClasEvaluator):
+    def __init__(self, cfg, output_dir=None):
+        super(PairEvaluator, self).__init__(cfg=cfg, output_dir=output_dir)
+        self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
+
+    def process(self, inputs, outputs):
+        pred_logits = outputs.to(self._cpu_device, torch.float32)
+        labels = inputs["targets"].to(self._cpu_device)
+
+        with torch.no_grad():
+            probs = torch.softmax(pred_logits, dim=-1)
+            probs, _ = torch.max(probs, dim=-1)
+
+        labels = labels.numpy()
+        probs = probs.numpy()
+        batch_size = probs.shape[0]
+
+        # overall metrics (acc / ap / auc), plus precision, recall and f1 at each threshold
+        acc = skmetrics.accuracy_score(labels, probs > 0.5) * batch_size
+        ap = skmetrics.average_precision_score(labels, probs) * batch_size
+        auc = skmetrics.roc_auc_score(labels, probs) * batch_size  # area under the ROC curve
+
+        precisions = []
+        recalls = []
+        f1s = []
+        for thresh in self._threshold_list:
+            precision = skmetrics.precision_score(labels, probs >= thresh, zero_division=0) * batch_size
+            recall = skmetrics.recall_score(labels, probs >= thresh, zero_division=0) * batch_size
+            if precision + recall == 0:
+                f1 = 0
+            else:
+                f1 = 2 * precision * recall / (precision + recall)  # already carries the batch_size factor
+
+            precisions.append(precision)
+            recalls.append(recall)
+            f1s.append(f1)
+
+        self._predictions.append({
+            'acc': acc,
+            'ap': ap,
+            'auc': auc,
+            'precisions': np.asarray(precisions),
+            'recalls': np.asarray(recalls),
+            'f1s': np.asarray(f1s),
+            'num_samples': batch_size
+        })
+
+    def evaluate(self):
+        if comm.get_world_size() > 1:
+            comm.synchronize()
+            predictions = comm.gather(self._predictions, dst=0)
+            predictions = list(itertools.chain(*predictions))
+
+            if not comm.is_main_process():
+                return {}
+        else:
+            predictions = self._predictions
+
+        total_acc = 0
+        total_ap = 0
+        total_auc = 0
+        total_precisions = np.zeros(len(self._threshold_list))
+        total_recalls = np.zeros(len(self._threshold_list))
+        total_f1s = np.zeros(len(self._threshold_list))
+        total_samples = 0
+        for prediction in predictions:
+            total_acc += prediction['acc']
+            total_ap += prediction['ap']
+            total_auc += prediction['auc']
+            total_precisions += prediction['precisions']
+            total_recalls += prediction['recalls']
+            total_f1s += prediction['f1s']
+            total_samples += prediction['num_samples']
+
+        acc = total_acc / total_samples
+        ap = total_ap / total_samples
+        auc = total_auc / total_samples
+        precisions = total_precisions / total_samples
+        recalls = total_recalls / total_samples
+        f1s = total_f1s / total_samples
+
+        self._results = OrderedDict()
+        self._results['Acc'] = acc
+        self._results['Ap'] = ap
+        self._results['Auc'] = auc
+        self._results['Thresholds'] = self._threshold_list
+        self._results['Precisions'] = precisions
+        self._results['Recalls'] = recalls
+        self._results['F1_Scores'] = f1s
+
+        return copy.deepcopy(self._results)
+
diff --git a/fastreid/evaluation/testing.py b/fastreid/evaluation/testing.py
index cf4abc3..3cca0d6 100644
--- a/fastreid/evaluation/testing.py
+++ b/fastreid/evaluation/testing.py
@@ -21,18 +21,32 @@ def print_csv_format(results):
     logger = logging.getLogger(__name__)
 
     dataset_name = results.pop('dataset')
-    metrics = ["Dataset"] + [k for k in results]
-    csv_results = [(dataset_name, *list(results.values()))]
+    metrics = ["Dataset"] + [k for k, v in results.items() if not isinstance(v, (list, np.ndarray))]
+    csv_results = [[dataset_name] + [v for v in results.values() if not isinstance(v, (list, np.ndarray))]]
 
     # tabulate it
     table = tabulate(
         csv_results,
         tablefmt="pipe",
-        floatfmt=".2f",
+        floatfmt=".4f",
         headers=metrics,
         numalign="left",
     )
+    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
+    # show precision, recall and f1 at each given threshold
+    metrics = [k for k, v in results.items() if isinstance(v, (list, np.ndarray))]
+    csv_results = [v for v in results.values() if isinstance(v, (list, np.ndarray))]
+    csv_results = [v.tolist() if isinstance(v, np.ndarray) else v for v in csv_results]
+    csv_results = np.array(csv_results).T.tolist()
+
+    table = tabulate(
+        csv_results,
+        tablefmt="pipe",
+        floatfmt=".4f",
+        headers=metrics,
+        numalign="left",
+    )
 
     logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
 
 
@@ -85,3 +99,4 @@ def flatten_results_dict(results):
         else:
             r[k] = v
     return r
+
diff --git a/projects/FastShoe/fastshoe/trainer.py b/projects/FastShoe/fastshoe/trainer.py
index 1abb2b8..3546d18 100644
--- a/projects/FastShoe/fastshoe/trainer.py
+++ b/projects/FastShoe/fastshoe/trainer.py
@@ -11,7 +11,7 @@ from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.utils import comm
 from fastreid.data.transforms import build_transforms
 from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
-from fastreid.evaluation.clas_evaluator import ClasEvaluator
+from fastreid.evaluation.pair_evaluator import PairEvaluator
 
 from projects.FastShoe.fastshoe.data import PairDataset
 
@@ -57,4 +57,4 @@ class PairTrainer(DefaultTrainer):
     @classmethod
     def build_evaluator(cls, cfg, dataset_name, output_dir=None):
         data_loader = cls.build_test_loader(cfg, dataset_name)
-        return data_loader, ClasEvaluator(cfg, output_dir)
+        return data_loader, PairEvaluator(cfg, output_dir)
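
For reference, a minimal standalone sketch (not part of the patch) of the batch-weighted aggregation that PairEvaluator.process and PairEvaluator.evaluate implement: each batch stores metric * batch_size, and evaluate divides the accumulated sums by the total sample count. The helper names (process_batch, evaluate) and the toy random data below are illustrative assumptions, not code from the repository.

    # Illustrative sketch only: batch-weighted metric aggregation,
    # mirroring PairEvaluator.process / PairEvaluator.evaluate.
    import numpy as np
    from sklearn import metrics as skmetrics

    threshold_list = [0.3, 0.5, 0.7]
    predictions = []

    def process_batch(labels, probs):
        """Accumulate per-batch metrics, each scaled by the batch size."""
        batch_size = probs.shape[0]
        acc = skmetrics.accuracy_score(labels, probs > 0.5) * batch_size
        precisions = np.array([
            skmetrics.precision_score(labels, probs >= t, zero_division=0) * batch_size
            for t in threshold_list
        ])
        predictions.append({'acc': acc, 'precisions': precisions, 'num_samples': batch_size})

    def evaluate():
        """Divide the summed, batch-weighted metrics by the total number of samples."""
        total = sum(p['num_samples'] for p in predictions)
        acc = sum(p['acc'] for p in predictions) / total
        precisions = sum(p['precisions'] for p in predictions) / total
        return {'Acc': acc, 'Thresholds': threshold_list, 'Precisions': precisions}

    # toy usage with random pair labels / scores
    rng = np.random.default_rng(0)
    for _ in range(3):
        labels = rng.integers(0, 2, size=32)
        probs = rng.random(32)
        process_batch(labels, probs)
    print(evaluate())

Note that this averages per-batch sklearn scores weighted by batch size (the patch's design choice), which is not necessarily identical to computing each metric once over the full prediction set.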