fast-reid/fastreid/evaluation/testing.py

88 lines
2.4 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import pprint
import sys
from collections import OrderedDict
from collections.abc import Mapping

import numpy as np
2020-07-15 14:56:18 +08:00
from tabulate import tabulate
2020-07-17 19:13:45 +08:00
from termcolor import colored
2020-02-10 07:38:56 +08:00
def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron2,
    so that they are easy to copypaste into a spreadsheet.

    The table is emitted through this module's logger at INFO level as a
    single-row, pipe-formatted table whose first column is the dataset name.

    Args:
        results (OrderedDict): {metric -> score}, plus a ``'dataset'`` entry
            holding the dataset name. The ``'dataset'`` entry is removed from
            ``results`` in place. Empty ``results`` are accepted and nothing
            is tabulated for them.

    Raises:
        AssertionError: if ``results`` is non-empty and not an OrderedDict.
        KeyError: if ``results`` is non-empty but has no ``'dataset'`` entry.
    """
    # unordered results cannot be properly printed
    assert isinstance(results, OrderedDict) or not len(results), results
    logger = logging.getLogger(__name__)

    # The assertion above deliberately lets empty results through; without
    # this guard the pop() below would raise KeyError on them.
    if len(results) == 0:
        logger.warning("No evaluation results to print.")
        return

    # Removing the key in place leaves only metric -> score pairs in `results`.
    dataset_name = results.pop('dataset')
    metrics = ["Dataset"] + list(results)
    csv_results = [(dataset_name, *results.values())]

    # tabulate it
    table = tabulate(
        csv_results,
        tablefmt="pipe",
        floatfmt=".2f",
        headers=metrics,
        numalign="left",
    )

    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
2020-02-10 07:38:56 +08:00
def verify_results(cfg, results):
    """
    Compare evaluation results against the expectations stored in the config.

    Every ``(task, metric, expected, tolerance)`` entry of
    ``cfg.TEST.EXPECTED_RESULTS`` is checked: the actual score must be finite
    and must not differ from ``expected`` by more than ``tolerance``.
    On failure the expected and actual results are logged as errors and the
    process exits with status 1.

    Args:
        cfg: config object exposing ``TEST.EXPECTED_RESULTS``.
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: whether the verification succeeds or not
    """
    expectations = cfg.TEST.EXPECTED_RESULTS
    if len(expectations) == 0:
        # Nothing to verify counts as success.
        return True

    passed = True
    for task_name, metric_name, target, tol in expectations:
        score = results[task_name][metric_name]
        non_finite = not np.isfinite(score)
        out_of_tolerance = abs(score - target) > tol
        if non_finite or out_of_tolerance:
            passed = False

    log = logging.getLogger(__name__)
    if passed:
        log.info("Results verification passed.")
    else:
        log.error("Result verification failed!")
        log.error("Expected Results: " + str(expectations))
        log.error("Actual Results: " + pprint.pformat(results))
        sys.exit(1)
    return passed
def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Nested levels are detected with ``collections.abc.Mapping`` (the alias
    formerly importable from ``collections`` was removed in Python 3.10).
    Nested mappings that are empty contribute no entries.

    Args:
        results (dict): possibly nested mapping; keys that lead to a nested
            mapping are joined with "/" and therefore must be strings.

    Returns:
        dict: a new flat dict; ``results`` itself is not modified.
    """
    flat = {}
    for key, value in results.items():
        if not isinstance(value, Mapping):
            # Leaf value: keep it under its own key.
            flat[key] = value
            continue
        # Nested mapping: flatten it first, then prefix each of its keys.
        for sub_key, sub_value in flatten_results_dict(value).items():
            flat[key + "/" + sub_key] = sub_value
    return flat