# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import pprint
import sys
from collections import OrderedDict
from collections.abc import Mapping

import numpy as np
from tabulate import tabulate
from termcolor import colored

def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron2,
    so that they are easy to copy-paste into a spreadsheet.

    Args:
        results (OrderedDict): {metric -> score}
    """
    # unordered results cannot be properly printed
    assert isinstance(results, OrderedDict) or not len(results), results
    logger = logging.getLogger(__name__)

    dataset_name = results.pop('dataset')
    metrics = ["Dataset"] + [k for k, v in results.items() if not isinstance(v, (list, np.ndarray))]
    csv_results = [[dataset_name] + [v for v in results.values() if not isinstance(v, (list, np.ndarray))]]

    # tabulate the scalar metrics
    table = tabulate(
        csv_results,
        tablefmt="pipe",
        floatfmt=".4f",
        headers=metrics,
        numalign="left",
    )
    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))

    # show precision, recall and f1 under the given threshold
    metrics = [k for k, v in results.items() if isinstance(v, (list, np.ndarray))]
    csv_results = [v for v in results.values() if isinstance(v, (list, np.ndarray))]
    csv_results = [v.tolist() if isinstance(v, np.ndarray) else v for v in csv_results]
    csv_results = np.array(csv_results).T.tolist()

    table = tabulate(
        csv_results,
        tablefmt="pipe",
        floatfmt=".4f",
        headers=metrics,
        numalign="left",
    )
    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
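

# A minimal usage sketch for print_csv_format (illustrative values, not part
# of the original file). Scalar metrics go into the first table; list/array
# metrics (e.g. precision under several thresholds) go into the second.
#
#   from collections import OrderedDict
#   results = OrderedDict([
#       ("dataset", "Market1501"),          # consumed by results.pop('dataset')
#       ("Rank-1", 0.9562),
#       ("mAP", 0.8871),
#       ("thresholds", [0.3, 0.5, 0.7]),    # array-valued -> second table
#       ("precision", [0.98, 0.95, 0.90]),
#   ])
#   print_csv_format(results)
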
def verify_results(cfg, results):
    """
    Args:
        cfg (CfgNode): config whose cfg.TEST.EXPECTED_RESULTS holds entries of
            the form (task, metric, expected value, tolerance)
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: whether the verification succeeds or not
    """
    expected_results = cfg.TEST.EXPECTED_RESULTS
    if not len(expected_results):
        return True

    ok = True
    for task, metric, expected, tolerance in expected_results:
        actual = results[task][metric]
        if not np.isfinite(actual):
            ok = False
        diff = abs(actual - expected)
        if diff > tolerance:
            ok = False

    logger = logging.getLogger(__name__)
    if not ok:
        logger.error("Result verification failed!")
        logger.error("Expected Results: " + str(expected_results))
        logger.error("Actual Results: " + pprint.pformat(results))

        sys.exit(1)
    else:
        logger.info("Results verification passed.")
    return ok
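

# A minimal usage sketch for verify_results (illustrative values, not part of
# the original file); each expected-results entry unpacks as
# (task, metric, expected, tolerance), matching the loop above.
#
#   cfg.TEST.EXPECTED_RESULTS = [["Market1501", "Rank-1", 0.95, 0.02]]
#   results = {"Market1501": {"Rank-1": 0.9562}}
#   verify_results(cfg, results)  # |0.9562 - 0.95| <= 0.02, so this passes
#   # On a mismatch the function logs both dicts and calls sys.exit(1).
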
def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict): a possibly-nested dict of scalars
    """
    r = {}
    for k, v in results.items():
        if isinstance(v, Mapping):
            v = flatten_results_dict(v)
            for kk, vv in v.items():
                r[k + "/" + kk] = vv
        else:
            r[k] = v
    return r
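
# A minimal usage sketch for flatten_results_dict (illustrative values, not
# part of the original file): nested mappings are joined into "/"-separated keys.
#
#   nested = {"Market1501": {"Rank-1": 0.9562, "mAP": 0.8871}, "epoch": 120}
#   flatten_results_dict(nested)
#   # -> {"Market1501/Rank-1": 0.9562, "Market1501/mAP": 0.8871, "epoch": 120}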