diff --git a/projects/FastAttr/fastattr/attr_evaluation.py b/projects/FastAttr/fastattr/attr_evaluation.py index 3ab7247..eb06dde 100644 --- a/projects/FastAttr/fastattr/attr_evaluation.py +++ b/projects/FastAttr/fastattr/attr_evaluation.py @@ -36,6 +36,8 @@ class AttrEvaluator(DatasetEvaluator): @staticmethod def get_attr_metrics(gt_labels, pred_logits, thres): + eps = 1e-20 + pred_labels = copy.deepcopy(pred_logits) pred_labels[pred_logits < thres] = 0 pred_labels[pred_logits >= thres] = 1 @@ -53,13 +55,13 @@ class AttrEvaluator(DatasetEvaluator): gt_labels = gt_labels.astype(bool) intersect = (pred_labels & gt_labels).astype(float) union = (pred_labels | gt_labels).astype(float) - ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean() - ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean() - ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean() - ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec) + ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean() + ins_prec = (intersect.sum(axis=1) / (pred_labels.astype(float).sum(axis=1) + eps)).mean() + ins_rec = (intersect.sum(axis=1) / (gt_labels.astype(float).sum(axis=1) + eps)).mean() + ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec + eps) - term1 = correct_pos / real_pos - term2 = correct_neg / real_neg + term1 = correct_pos / (real_pos + eps) + term2 = correct_neg / (real_neg + eps) label_mA_verbose = (term1 + term2) * 0.5 label_mA = label_mA_verbose.mean()