[Feature] Add --eval-options in test.py (#158)

* add --eval-options in test.py

* fix typo

* revise according to comments
LXXXXR 2021-02-05 17:46:43 +08:00 committed by GitHub
parent 7f49632d7c
commit ddc2a14177
3 changed files with 42 additions and 11 deletions


@@ -115,7 +115,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
     def evaluate(self,
                  results,
                  metric='accuracy',
-                 metric_options={'topk': (1, 5)},
+                 metric_options=None,
                  logger=None):
         """Evaluate the dataset.
 
@@ -123,13 +123,16 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
             results (list): Testing results of the dataset.
             metric (str | list[str]): Metrics to be evaluated.
                 Default value is `accuracy`.
-            metric_options (dict): Options for calculating metrics. Allowed
-                keys are 'topk', 'thrs' and 'average_mode'.
-            logger (logging.Logger | None | str): Logger used for printing
-                related information during evaluation. Default: None.
+            metric_options (dict, optional): Options for calculating metrics.
+                Allowed keys are 'topk', 'thrs' and 'average_mode'.
+                Defaults to None.
+            logger (logging.Logger | str, optional): Logger used for printing
+                related information during evaluation. Defaults to None.
         Returns:
             dict: evaluation results
         """
+        if metric_options is None:
+            metric_options = {'topk': (1, 5)}
         if isinstance(metric, str):
             metrics = [metric]
         else:
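
The switch from a dict literal to None as the default (with the dict filled in inside the function body) avoids Python's shared-mutable-default pitfall. A minimal standalone sketch of the pattern, not mmcls code and with illustrative names:

def evaluate_safe(results, metric_options=None):
    # A fresh dict is created on every call, so callers can mutate it
    # without leaking state into later calls.
    if metric_options is None:
        metric_options = {'topk': (1, 5)}
    return metric_options

print(evaluate_safe([]))                    # {'topk': (1, 5)}
print(evaluate_safe([], {'topk': (1,)}))    # {'topk': (1,)}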


@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 from mmcls.core import average_performance, mAP
@@ -21,7 +23,12 @@ class MultiLabelDataset(BaseDataset):
         cat_ids = np.where(gt_labels == 1)[0]
         return cat_ids
 
-    def evaluate(self, results, metric='mAP', logger=None, **eval_kwargs):
+    def evaluate(self,
+                 results,
+                 metric='mAP',
+                 metric_options=None,
+                 logger=None,
+                 **deprecated_kwargs):
         """Evaluate the dataset.
 
         Args:
@@ -29,11 +36,23 @@ class MultiLabelDataset(BaseDataset):
             metric (str | list[str]): Metrics to be evaluated.
                 Default value is 'mAP'. Options are 'mAP', 'CP', 'CR', 'CF1',
                 'OP', 'OR' and 'OF1'.
-            logger (logging.Logger | None | str): Logger used for printing
-                related information during evaluation. Default: None.
+            metric_options (dict, optional): Options for calculating metrics.
+                Allowed keys are 'k' and 'thr'. Defaults to None
+            logger (logging.Logger | str, optional): Logger used for printing
+                related information during evaluation. Defaults to None.
+            deprecated_kwargs (dict): Used for containing deprecated arguments.
         Returns:
             dict: evaluation results
         """
+        if metric_options is None:
+            metric_options = {'thr': 0.5}
+        if deprecated_kwargs != {}:
+            warnings.warn('Option arguments for metrics has been changed to '
+                          '`metric_options`.')
+            metric_options = {**deprecated_kwargs}
         if isinstance(metric, str):
             metrics = [metric]
         else:
@@ -48,7 +67,7 @@ class MultiLabelDataset(BaseDataset):
         invalid_metrics = set(metrics) - set(allowed_metrics)
         if len(invalid_metrics) != 0:
-            raise ValueError(f'metirc {invalid_metrics} is not supported.')
+            raise ValueError(f'metric {invalid_metrics} is not supported.')
 
         if 'mAP' in metrics:
             mAP_value = mAP(results, gt_labels)
@@ -56,7 +75,7 @@ class MultiLabelDataset(BaseDataset):
         if len(set(metrics) - {'mAP'}) != 0:
             performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
             performance_values = average_performance(results, gt_labels,
                                                      **eval_kwargs)
-                                                     **eval_kwargs)
+                                                     **metric_options)
             for k, v in zip(performance_keys, performance_values):
                 if k in metrics:
                     eval_results[k] = v
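
The **deprecated_kwargs shim keeps old call sites working: any keyword that is not one of the named parameters is collected, a warning is emitted, and the collected values are forwarded as metric_options. A standalone sketch of that behaviour, mirroring the diff above rather than importing mmcls:

import warnings


def evaluate(results, metric='mAP', metric_options=None, **deprecated_kwargs):
    if metric_options is None:
        metric_options = {'thr': 0.5}
    if deprecated_kwargs != {}:
        # Old-style keyword arguments (e.g. thr=0.3) still work but warn.
        warnings.warn('Option arguments for metrics has been changed to '
                      '`metric_options`.')
        metric_options = {**deprecated_kwargs}
    return metric_options


print(evaluate([], metric_options={'thr': 0.3}))  # new style: {'thr': 0.3}
print(evaluate([], thr=0.3))                      # old style: warns, then {'thr': 0.3}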


@@ -39,6 +39,14 @@ def parse_args():
         action=DictAction,
         help='override some settings in the used config, the key-value pair '
         'in xxx=yyy format will be merged into config file.')
+    parser.add_argument(
+        '--metric-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='custom options for evaluation, the key-value pair in xxx=yyy '
+        'format will be parsed as a dict metric_options for dataset.evaluate()'
+        ' function.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -101,7 +109,8 @@ def main():
     rank, _ = get_dist_info()
     if rank == 0:
         if args.metrics:
-            results = dataset.evaluate(outputs, args.metrics)
+            results = dataset.evaluate(outputs, args.metrics,
+                                       args.metric_options)
             for k, v in results.items():
                 print(f'\n{k} : {v:.2f}')
         else:
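
With the new flag, each key=value pair on the command line is parsed by mmcv's DictAction into the dict that reaches dataset.evaluate() as metric_options. A minimal sketch of that parsing step; it assumes mmcv is installed, and the example value thr=0.3 is illustrative, not part of the commit:

import argparse

from mmcv import DictAction

parser = argparse.ArgumentParser()
parser.add_argument('--metric-options', nargs='+', action=DictAction, default={})

# e.g. `python tools/test.py CONFIG CHECKPOINT --metrics mAP --metric-options thr=0.3`
args = parser.parse_args(['--metric-options', 'thr=0.3'])
print(args.metric_options)  # {'thr': 0.3}, forwarded to dataset.evaluate()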