mirror of https://github.com/JDAI-CV/fast-reid.git
102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
import copy
|
|
import logging
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
from fastreid.evaluation.evaluator import DatasetEvaluator
|
|
from fastreid.utils import comm
|
|
|
|
# Module-level logger registered under the "fastreid.attr_evaluation" name.
logger = logging.getLogger("fastreid.attr_evaluation")
|
|
|
|
|
|
class AttrEvaluator(DatasetEvaluator):
    """Collect per-sample attribute predictions and score them.

    Usage follows three steps: ``reset()`` clears the buffers, ``process()``
    is called once per batch to accumulate targets and predictions on the
    CPU, and ``evaluate()`` turns everything accumulated so far into a
    dictionary of percentages.
    """

    def __init__(self, cfg, attr_dict, thres=0.5, output_dir=None):
        """Store the configuration and create empty accumulation buffers.

        Args:
            cfg: Config object; ``evaluate`` reads ``cfg.TEST.THRES`` from it.
            attr_dict: Stored as ``self.attr_dict``; not read in this class.
            thres: Stored as ``self.thres``.
            output_dir: Stored as ``self._output_dir``; not read in this class.
        """
        self.cfg = cfg
        self.attr_dict = attr_dict
        # NOTE(review): ``self.thres`` is stored but never read here --
        # ``evaluate`` takes its threshold from ``cfg.TEST.THRES`` instead.
        # Confirm which of the two is meant to win before changing either.
        self.thres = thres
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")

        self.pred_logits = []
        self.gt_labels = []

    def reset(self):
        """Drop everything accumulated by previous ``process`` calls."""
        self.pred_logits = []
        self.gt_labels = []

    def process(self, inputs, outputs):
        """Accumulate one batch of targets and predictions on the CPU.

        Args:
            inputs: Mapping whose ``"targets"`` entry is a tensor.
            outputs: Prediction tensor; converted to ``float32`` on the CPU.

        ``list.extend`` iterates each tensor along its first dimension, so
        the buffers end up holding one entry per row of the batch.
        """
        self.gt_labels.extend(inputs["targets"].to(self._cpu_device))
        self.pred_logits.extend(outputs.to(self._cpu_device, torch.float32))

    @staticmethod
    def get_attr_metrics(gt_labels, pred_logits, thres):
        """Compute label-based and instance-based multi-label metrics.

        Args:
            gt_labels: 2-D array of 0/1 ground-truth labels (must support
                ``.astype``, i.e. a NumPy array as produced by ``evaluate``).
                Axis 0 indexes samples; axis 1 is presumably the attribute
                axis -- TODO confirm against the dataset.
            pred_logits: Array of prediction scores with the same shape.
            thres: Scores ``>= thres`` count as positive, scores ``< thres``
                as negative.

        Returns:
            OrderedDict with keys ``"Accu"``, ``"Prec"``, ``"Recall"``,
            ``"F1"``, ``"mA"`` and ``"metric"``, each scaled by 100.
            ``"metric"`` carries the same value as ``"mA"``.
        """
        # Keeps every denominator non-zero when a row/column has no positives.
        eps = 1e-20

        # Binarise a copy so the caller's score array is left untouched.
        pred_labels = copy.deepcopy(pred_logits)
        pred_labels[pred_logits < thres] = 0
        pred_labels[pred_logits >= thres] = 1

        # Compute label-based metric (reductions over axis 0, one value per column)
        overlaps = pred_labels * gt_labels
        correct_pos = overlaps.sum(axis=0)
        real_pos = gt_labels.sum(axis=0)
        inv_overlaps = (1 - pred_labels) * (1 - gt_labels)
        correct_neg = inv_overlaps.sum(axis=0)
        real_neg = (1 - gt_labels).sum(axis=0)

        # Compute instance-based accuracy (reductions over axis 1, one value per row)
        pred_labels = pred_labels.astype(bool)
        gt_labels = gt_labels.astype(bool)
        intersect = (pred_labels & gt_labels).astype(float)
        union = (pred_labels | gt_labels).astype(float)
        ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean()
        ins_prec = (intersect.sum(axis=1) / (pred_labels.astype(float).sum(axis=1) + eps)).mean()
        ins_rec = (intersect.sum(axis=1) / (gt_labels.astype(float).sum(axis=1) + eps)).mean()
        ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec + eps)

        # Per-column average of the positive and negative hit rates,
        # then averaged over all columns.
        term1 = correct_pos / (real_pos + eps)
        term2 = correct_neg / (real_neg + eps)
        label_mA_verbose = (term1 + term2) * 0.5
        label_mA = label_mA_verbose.mean()

        results = OrderedDict()
        results["Accu"] = ins_acc * 100
        results["Prec"] = ins_prec * 100
        results["Recall"] = ins_rec * 100
        results["F1"] = ins_f1 * 100
        results["mA"] = label_mA * 100
        results["metric"] = label_mA * 100
        return results

    def evaluate(self):
        """Score everything accumulated since the last ``reset``.

        When more than one process is running, the per-process buffers are
        gathered and only the main process computes metrics; every other
        process returns an empty dict.

        Returns:
            A deep copy of the metrics dict built by ``get_attr_metrics``
            (so callers cannot mutate ``self._results``), or ``{}`` on
            non-main processes.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            # ``comm.gather`` is expected to yield one list per process
            # (TODO confirm). Flatten in a single pass rather than with
            # ``sum(lists, [])``, which re-copies the accumulator each step.
            # The gather order (predictions, then labels) is kept as-is.
            pred_logits = [
                row for per_rank in comm.gather(self.pred_logits) for row in per_rank
            ]
            gt_labels = [
                row for per_rank in comm.gather(self.gt_labels) for row in per_rank
            ]

            if not comm.is_main_process():
                return {}
        else:
            pred_logits = self.pred_logits
            gt_labels = self.gt_labels

        # One entry per sample -> a single array with samples along axis 0.
        pred_logits = torch.stack(pred_logits, dim=0).numpy()
        gt_labels = torch.stack(gt_labels, dim=0).numpy()

        # Pedestrian attribute metrics
        thres = self.cfg.TEST.THRES
        self._results = self.get_attr_metrics(gt_labels, pred_logits, thres)

        return copy.deepcopy(self._results)
|