mirror of https://github.com/JDAI-CV/fast-reid.git
evaluation success
parent
786c442391
commit
1feda07ce6

@@ -9,6 +9,7 @@ import tempfile
 import time
 from collections import Counter
 
+import numpy as np
 import torch
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel

@@ -355,6 +356,11 @@ class EvalHook(HookBase):
                 results, dict
             ), "Eval function must return a dict. Got {} instead.".format(results)
 
+            # drop np.array values in results
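+            # (the remaining entries are flattened and logged as scalar metrics below,
+            # so list / array valued results cannot pass through)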
+            for k, v in list(results.items()):
+                if isinstance(v, (list, np.ndarray)):
+                    results.pop(k)
+
             flattened_results = flatten_results_dict(results)
             for k, v in flattened_results.items():
                 try:

@@ -0,0 +1,109 @@
+# coding: utf-8
+
+import copy
+import itertools
+import logging
+from collections import OrderedDict
+
+import numpy as np
+import torch
+from fastreid.utils import comm
+from sklearn import metrics as skmetrics
+
+from .clas_evaluator import ClasEvaluator
+
+logger = logging.getLogger(__name__)
+
+
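+# Evaluator for binary pair classification: accumulates accuracy, average precision,
+# ROC AUC, and precision / recall / F1 over a sweep of decision thresholds.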
+class PairEvaluator(ClasEvaluator):
+    def __init__(self, cfg, output_dir=None):
+        super(PairEvaluator, self).__init__(cfg=cfg, output_dir=output_dir)
+        self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
+
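+    # process() is called once per evaluation batch; it accumulates per-batch metric sums
+    # (scaled by the batch size) that evaluate() later averages over the whole dataset.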
+    def process(self, inputs, outputs):
+        pred_logits = outputs.to(self._cpu_device, torch.float32)
+        labels = inputs["targets"].to(self._cpu_device)
+
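+        # convert logits to probabilities; the maximum class probability is kept as the pair's confidence score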
+        with torch.no_grad():
+            probs = torch.softmax(pred_logits, dim=-1)
+            probs, _ = torch.max(probs, dim=-1)
+
+        labels = labels.numpy()
+        probs = probs.numpy()
+        batch_size = probs.shape[0]
+
+        # compute the three overall metrics, plus precision, recall and f1 at each threshold in the list
+        acc = skmetrics.accuracy_score(labels, probs > 0.5) * batch_size
+        ap = skmetrics.average_precision_score(labels, probs) * batch_size
+        auc = skmetrics.roc_auc_score(labels, probs) * batch_size  # area under the ROC curve
+
+        precisions = []
+        recalls = []
+        f1s = []
+        for thresh in self._threshold_list:
+            precision = skmetrics.precision_score(labels, probs >= thresh, zero_division=0) * batch_size
+            recall = skmetrics.recall_score(labels, probs >= thresh, zero_division=0) * batch_size
+            if precision + recall == 0:
+                f1 = 0
+            else:
+                # precision and recall already carry the batch_size factor, so the ratio is f1 * batch_size
+                f1 = 2 * precision * recall / (precision + recall)
+
+            precisions.append(precision)
+            recalls.append(recall)
+            f1s.append(f1)
+
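+        # every stored value is scaled by batch_size so evaluate() can form
+        # sample-weighted averages across batches and workers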
+        self._predictions.append({
+            'acc': acc,
+            'ap': ap,
+            'auc': auc,
+            'precisions': np.asarray(precisions),
+            'recalls': np.asarray(recalls),
+            'f1s': np.asarray(f1s),
+            'num_samples': batch_size
+        })
+
+    def evaluate(self):
+        if comm.get_world_size() > 1:
+            comm.synchronize()
+            predictions = comm.gather(self._predictions, dst=0)
+            predictions = list(itertools.chain(*predictions))
+
+            if not comm.is_main_process():
+                return {}
+        else:
+            predictions = self._predictions
+
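+        # sum the per-batch values, then normalize by the total number of samples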
+        total_acc = 0
+        total_ap = 0
+        total_auc = 0
+        total_precisions = np.zeros(len(self._threshold_list))
+        total_recalls = np.zeros(len(self._threshold_list))
+        total_f1s = np.zeros(len(self._threshold_list))
+        total_samples = 0
+        for prediction in predictions:
+            total_acc += prediction['acc']
+            total_ap += prediction['ap']
+            total_auc += prediction['auc']
+            total_precisions += prediction['precisions']
+            total_recalls += prediction['recalls']
+            total_f1s += prediction['f1s']
+            total_samples += prediction['num_samples']
+
+        acc = total_acc / total_samples
+        ap = total_ap / total_samples
+        auc = total_auc / total_samples
+        precisions = total_precisions / total_samples
+        recalls = total_recalls / total_samples
+        f1s = total_f1s / total_samples
+
+        self._results = OrderedDict()
+        self._results['Acc'] = acc
+        self._results['Ap'] = ap
+        self._results['Auc'] = auc
+        self._results['Thresholds'] = self._threshold_list
+        self._results['Precisions'] = precisions
+        self._results['Recalls'] = recalls
+        self._results['F1_Scores'] = f1s
+
+        return copy.deepcopy(self._results)

@@ -21,18 +21,32 @@ def print_csv_format(results):
     logger = logging.getLogger(__name__)
 
     dataset_name = results.pop('dataset')
-    metrics = ["Dataset"] + [k for k in results]
-    csv_results = [(dataset_name, *list(results.values()))]
+    metrics = ["Dataset"] + [k for k, v in results.items() if not isinstance(v, (list, np.ndarray))]
+    csv_results = [[dataset_name] + [v for v in results.values() if not isinstance(v, (list, np.ndarray))]]
 
     # tabulate it
     table = tabulate(
         csv_results,
         tablefmt="pipe",
-        floatfmt=".2f",
+        floatfmt=".4f",
         headers=metrics,
         numalign="left",
     )
     logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
 
+    # show precision, recall and f1 at each given threshold
+    metrics = [k for k, v in results.items() if isinstance(v, (list, np.ndarray))]
+    csv_results = [v for v in results.values() if isinstance(v, (list, np.ndarray))]
+    csv_results = [v.tolist() if isinstance(v, np.ndarray) else v for v in csv_results]
+    csv_results = np.array(csv_results).T.tolist()
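+    # transpose so that each table row corresponds to one threshold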
+
+    table = tabulate(
+        csv_results,
+        tablefmt="pipe",
+        floatfmt=".4f",
+        headers=metrics,
+        numalign="left",
+    )
+    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))

@@ -85,3 +99,4 @@ def flatten_results_dict(results):
         else:
             r[k] = v
     return r
+

@@ -11,7 +11,7 @@ from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.utils import comm
 from fastreid.data.transforms import build_transforms
 from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
-from fastreid.evaluation.clas_evaluator import ClasEvaluator
+from fastreid.evaluation.pair_evaluator import PairEvaluator
 
 from projects.FastShoe.fastshoe.data import PairDataset

@@ -57,4 +57,4 @@ class PairTrainer(DefaultTrainer):
     @classmethod
     def build_evaluator(cls, cfg, dataset_name, output_dir=None):
         data_loader = cls.build_test_loader(cfg, dataset_name)
-        return data_loader, ClasEvaluator(cfg, output_dir)
+        return data_loader, PairEvaluator(cfg, output_dir)