fast-reid/projects/FastFace/fastface/face_evaluator.py

84 lines
2.2 KiB
Python

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import copy
import io
import logging
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from PIL import Image
from fastreid.evaluation import DatasetEvaluator
from fastreid.utils import comm
from fastreid.utils.file_io import PathManager
from .verification import evaluate
# Module-level logger, registered under the "fastreid" logger hierarchy.
logger = logging.getLogger("fastreid.face_evaluator")
def gen_plot(fpr, tpr):
    """Render an ROC curve (TPR against FPR) into an in-memory JPEG image.

    Args:
        fpr: false-positive rates, plotted on the x axis.
        tpr: true-positive rates, plotted on the y axis.

    Returns:
        io.BytesIO: the JPEG-encoded figure, rewound to offset 0 so callers
        can read it straight away.
    """
    figure = plt.figure()
    axes = figure.gca()
    axes.set_xlabel("FPR", fontsize=14)
    axes.set_ylabel("TPR", fontsize=14)
    axes.set_title("ROC Curve", fontsize=14)
    axes.plot(fpr, tpr, linewidth=2)

    jpeg_buffer = io.BytesIO()
    figure.savefig(jpeg_buffer, format='jpeg')
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close(figure)
    jpeg_buffer.seek(0)
    return jpeg_buffer
class FaceEvaluator(DatasetEvaluator):
    """Face-verification evaluator built on accumulated model embeddings.

    Embeddings are collected batch by batch in :meth:`process`. :meth:`evaluate`
    then does the following:

    - gathers them across processes when running distributed;
    - L2-normalizes them;
    - hands them, together with ``labels``, to ``verification.evaluate``;
    - optionally writes an ROC plot to ``output_dir``.
    """

    def __init__(self, cfg, labels, dataset_name, output_dir=None):
        """
        Args:
            cfg: config object. It is stored on the instance but not read by
                this class.
            labels: forwarded unchanged to ``verification.evaluate``. Presumably
                these are same/different flags for feature pairs; confirm
                against that function.
            dataset_name (str): used to name the saved ROC image.
            output_dir (str or None): directory for the ROC image. When None,
                the image is not written.
        """
        self.cfg = cfg
        self.labels = labels
        self.dataset_name = dataset_name
        self._output_dir = output_dir
        self.features = []

    def reset(self):
        """Drop all accumulated features before a new evaluation pass."""
        self.features = []

    def process(self, inputs, outputs):
        """Store one batch of model outputs on the CPU.

        ``inputs`` is unused.
        """
        # Assumes ``outputs`` is a (batch, dim) tensor; TODO confirm. evaluate()
        # concatenates on dim 0 and normalizes on dim 1.
        self.features.append(outputs.cpu())

    def evaluate(self):
        """Compute verification metrics over every feature seen since reset().

        Returns:
            dict: a deep copy of an OrderedDict with these keys:
            "Accuracy" (mean accuracy, in percent), "Threshold" (mean best
            threshold) and "metric" (same value as "Accuracy").
            Returns ``{}`` on non-main processes in distributed runs, and when
            no features were collected.
        """
        if comm.get_world_size() > 1:
            # Every rank must reach the barrier before the gather.
            comm.synchronize()
            # One list of tensors per rank comes back; flatten into one list.
            # NOTE(review): the concatenation is rank-major. This assumes
            # ``labels`` is ordered to match; confirm against the data sampler.
            features = sum(comm.gather(self.features), [])
            if not comm.is_main_process():
                return {}
        else:
            features = self.features

        if not features:
            # torch.cat would raise on an empty list; report and bail out instead.
            logger.warning(
                "No features were collected for %s; skipping evaluation.",
                self.dataset_name,
            )
            return {}

        features = torch.cat(features, dim=0)
        # Unit-normalize each embedding row, then convert to NumPy for the
        # verification routine.
        features = F.normalize(features, p=2, dim=1).numpy()

        # ``evaluate`` here is the module-level function imported from
        # .verification, not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(features, self.labels)
        mean_accuracy = accuracy.mean() * 100

        self._results = OrderedDict()
        self._results["Accuracy"] = mean_accuracy
        self._results["Threshold"] = best_thresholds.mean()
        self._results["metric"] = mean_accuracy

        # output_dir defaults to None; joining a path onto None would raise,
        # so only export the plot when a directory was configured.
        if self._output_dir is not None:
            self._save_roc_curve(fpr, tpr)

        return copy.deepcopy(self._results)

    def _save_roc_curve(self, fpr, tpr):
        """Write the ROC plot to ``<output_dir>/<dataset_name>_roc.png``."""
        buf = gen_plot(fpr, tpr)
        PathManager.mkdirs(self._output_dir)
        save_path = os.path.join(self._output_dir, self.dataset_name + "_roc.png")
        # Close the decoded image deterministically instead of leaking it.
        with Image.open(buf) as roc_curve:
            roc_curve.save(save_path)