add eps in attr_evaluation.py

pull/406/head
Jinkai Zheng 2021-02-05 21:52:39 +08:00
parent 254a489eb1
commit c9537c97d1
1 changed files with 8 additions and 6 deletions

View File

@ -36,6 +36,8 @@ class AttrEvaluator(DatasetEvaluator):
@staticmethod
def get_attr_metrics(gt_labels, pred_logits, thres):
eps = 1e-20
pred_labels = copy.deepcopy(pred_logits)
pred_labels[pred_logits < thres] = 0
pred_labels[pred_logits >= thres] = 1
@ -53,13 +55,13 @@ class AttrEvaluator(DatasetEvaluator):
gt_labels = gt_labels.astype(bool)
intersect = (pred_labels & gt_labels).astype(float)
union = (pred_labels | gt_labels).astype(float)
ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean()
ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean()
ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean()
ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec)
ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean()
ins_prec = (intersect.sum(axis=1) / (pred_labels.astype(float).sum(axis=1) + eps)).mean()
ins_rec = (intersect.sum(axis=1) / (gt_labels.astype(float).sum(axis=1) + eps)).mean()
ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec + eps)
term1 = correct_pos / real_pos
term2 = correct_neg / real_neg
term1 = correct_pos / (real_pos + eps)
term2 = correct_neg / (real_neg + eps)
label_mA_verbose = (term1 + term2) * 0.5
label_mA = label_mA_verbose.mean()