# fast-reid/projects/FastAttr/fastattr/attr_evaluation.py
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import copy
import logging
from collections import OrderedDict

import torch

from fastreid.evaluation.evaluator import DatasetEvaluator
from fastreid.utils import comm

logger = logging.getLogger("fastreid.attr_evaluation")

class AttrEvaluator(DatasetEvaluator):
    def __init__(self, cfg, attr_dict, thres=0.5, output_dir=None):
        self.cfg = cfg
        self.attr_dict = attr_dict
        self.thres = thres
        self._output_dir = output_dir
        self._cpu_device = torch.device("cpu")

        self.pred_logits = []
        self.gt_labels = []

    def reset(self):
        self.pred_logits = []
        self.gt_labels = []

    def process(self, inputs, outputs):
        # Move ground-truth targets and predicted logits to the CPU so they
        # can be accumulated over the whole dataset without holding GPU memory.
        self.gt_labels.extend(inputs["targets"].to(self._cpu_device))
        self.pred_logits.extend(outputs.to(self._cpu_device, torch.float32))
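        # Shape note (an assumption inferred from usage, not documented in the
        # original): inputs["targets"] and outputs are (batch_size, num_attrs)
        # tensors, so extend() stores one 1-D tensor of attribute scores per image.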

    @staticmethod
    def get_attr_metrics(gt_labels, pred_logits, thres):
        eps = 1e-20

        # Binarize the predicted scores at the given threshold.
        pred_labels = copy.deepcopy(pred_logits)
        pred_labels[pred_logits < thres] = 0
        pred_labels[pred_logits >= thres] = 1

        # Compute label-based counts: per-attribute true positives/negatives
        # and the number of real positive/negative samples per attribute.
        overlaps = pred_labels * gt_labels
        correct_pos = overlaps.sum(axis=0)
        real_pos = gt_labels.sum(axis=0)
        inv_overlaps = (1 - pred_labels) * (1 - gt_labels)
        correct_neg = inv_overlaps.sum(axis=0)
        real_neg = (1 - gt_labels).sum(axis=0)

        # Compute instance-based metrics: per-sample Jaccard accuracy,
        # precision, recall, and F1 over the predicted attribute sets.
        pred_labels = pred_labels.astype(bool)
        gt_labels = gt_labels.astype(bool)
        intersect = (pred_labels & gt_labels).astype(float)
        union = (pred_labels | gt_labels).astype(float)
        ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean()
        ins_prec = (intersect.sum(axis=1) / (pred_labels.astype(float).sum(axis=1) + eps)).mean()
        ins_rec = (intersect.sum(axis=1) / (gt_labels.astype(float).sum(axis=1) + eps)).mean()
        ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec + eps)

        # Label-based mean accuracy (mA): the average of true-positive rate
        # and true-negative rate per attribute, then averaged over attributes.
        term1 = correct_pos / (real_pos + eps)
        term2 = correct_neg / (real_neg + eps)
        label_mA_verbose = (term1 + term2) * 0.5
        label_mA = label_mA_verbose.mean()

        results = OrderedDict()
        results["Accu"] = ins_acc * 100
        results["Prec"] = ins_prec * 100
        results["Recall"] = ins_rec * 100
        results["F1"] = ins_f1 * 100
        results["mA"] = label_mA * 100
        # "metric" mirrors mA as the single summary value.
        results["metric"] = label_mA * 100
        return results
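
    # Worked example of the metrics above (comment only, an illustrative
    # addition): with thres=0.5, a sample with gt = [1, 0] and scores
    # = [0.7, 0.6] binarizes to pred = [1, 1]. Its intersection with gt has
    # size 1 and its union size 2, so instance accuracy is 1/2, precision 1/2,
    # and recall 1/1. mA instead scores attribute columns: attribute 0 has
    # TPR = 1 and TNR = 0 here, contributing (1 + 0) / 2 = 0.5 to the mean.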

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            # Gather the per-process prediction/label lists onto the main
            # process and flatten them into single lists.
            pred_logits = comm.gather(self.pred_logits)
            pred_logits = sum(pred_logits, [])

            gt_labels = comm.gather(self.gt_labels)
            gt_labels = sum(gt_labels, [])

            if not comm.is_main_process():
                return {}
        else:
            pred_logits = self.pred_logits
            gt_labels = self.gt_labels

        pred_logits = torch.stack(pred_logits, dim=0).numpy()
        gt_labels = torch.stack(gt_labels, dim=0).numpy()

        # Pedestrian attribute metrics
        thres = self.cfg.TEST.THRES
        self._results = self.get_attr_metrics(gt_labels, pred_logits, thres)
        return copy.deepcopy(self._results)
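

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # module): it exercises the static get_attr_metrics on random numpy data,
    # bypassing the cfg/comm machinery that a full AttrEvaluator run requires.
    import numpy as np

    rng = np.random.default_rng(0)
    num_samples, num_attrs = 8, 5

    # Random binary ground truth and uniform [0, 1) scores (assumed to be
    # post-sigmoid probabilities, matching the thresholding above).
    gt = rng.integers(0, 2, size=(num_samples, num_attrs)).astype(float)
    scores = rng.random((num_samples, num_attrs))

    metrics = AttrEvaluator.get_attr_metrics(gt, scores, thres=0.5)
    for name, value in metrics.items():
        print(f"{name}: {value:.2f}")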