[Feature] Add --eval-options in test.py (#158)
* add --eval-options in test.py
* fix typo
* revise according to comments
parent 7f49632d7c
commit ddc2a14177
@@ -115,7 +115,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
     def evaluate(self,
                  results,
                  metric='accuracy',
-                 metric_options={'topk': (1, 5)},
+                 metric_options=None,
                  logger=None):
         """Evaluate the dataset.
 
@@ -123,13 +123,16 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
             results (list): Testing results of the dataset.
             metric (str | list[str]): Metrics to be evaluated.
                 Default value is `accuracy`.
-            metric_options (dict): Options for calculating metrics. Allowed
-                keys are 'topk', 'thrs' and 'average_mode'.
-            logger (logging.Logger | None | str): Logger used for printing
-                related information during evaluation. Default: None.
+            metric_options (dict, optional): Options for calculating metrics.
+                Allowed keys are 'topk', 'thrs' and 'average_mode'.
+                Defaults to None.
+            logger (logging.Logger | str, optional): Logger used for printing
+                related information during evaluation. Defaults to None.
         Returns:
             dict: evaluation results
         """
+        if metric_options is None:
+            metric_options = {'topk': (1, 5)}
         if isinstance(metric, str):
             metrics = [metric]
         else:
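
The change to the default value is more than cosmetic: a dict literal used as a default argument is created once, at function definition time, and shared by every call, so in-place modifications leak between calls. Replacing it with None and filling in {'topk': (1, 5)} inside the body sidesteps that. A minimal standalone sketch of the two behaviours (illustrative only, not the mmcls code):

def evaluate_shared_default(metric_options={'topk': (1, 5)}):
    # The same dict object is reused for every call, so in-place edits
    # persist across calls.
    metric_options.setdefault('thrs', 0.0)
    return metric_options


def evaluate_none_sentinel(metric_options=None):
    # A fresh dict is built on each call when the caller passes nothing,
    # mirroring the pattern introduced in the diff.
    if metric_options is None:
        metric_options = {'topk': (1, 5)}
    metric_options.setdefault('thrs', 0.0)
    return metric_options


assert evaluate_shared_default() is evaluate_shared_default()      # shared state
assert evaluate_none_sentinel() is not evaluate_none_sentinel()    # independent dicts
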
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 
 from mmcls.core import average_performance, mAP
@@ -21,7 +23,12 @@ class MultiLabelDataset(BaseDataset):
         cat_ids = np.where(gt_labels == 1)[0]
         return cat_ids
 
-    def evaluate(self, results, metric='mAP', logger=None, **eval_kwargs):
+    def evaluate(self,
+                 results,
+                 metric='mAP',
+                 metric_options=None,
+                 logger=None,
+                 **deprecated_kwargs):
         """Evaluate the dataset.
 
         Args:
@@ -29,11 +36,23 @@ class MultiLabelDataset(BaseDataset):
             metric (str | list[str]): Metrics to be evaluated.
                 Default value is 'mAP'. Options are 'mAP', 'CP', 'CR', 'CF1',
                 'OP', 'OR' and 'OF1'.
-            logger (logging.Logger | None | str): Logger used for printing
-                related information during evaluation. Default: None.
+            metric_options (dict, optional): Options for calculating metrics.
+                Allowed keys are 'k' and 'thr'. Defaults to None
+            logger (logging.Logger | str, optional): Logger used for printing
+                related information during evaluation. Defaults to None.
+            deprecated_kwargs (dict): Used for containing deprecated arguments.
 
         Returns:
             dict: evaluation results
         """
+        if metric_options is None:
+            metric_options = {'thr': 0.5}
+
+        if deprecated_kwargs != {}:
+            warnings.warn('Option arguments for metrics has been changed to '
+                          '`metric_options`.')
+            metric_options = {**deprecated_kwargs}
+
         if isinstance(metric, str):
             metrics = [metric]
         else:
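
The **deprecated_kwargs catch-all keeps old call sites such as evaluate(results, thr=0.6) working while steering callers toward the new metric_options dict. A self-contained sketch of that shim, reusing the names from the hunk above but with a stubbed body (illustrative, not the real MultiLabelDataset method):

import warnings


def evaluate(results, metric='mAP', metric_options=None, **deprecated_kwargs):
    # Default filled in at call time, as in the diff.
    if metric_options is None:
        metric_options = {'thr': 0.5}

    # Leftover keyword arguments are treated as old-style options:
    # warn and funnel them into metric_options.
    if deprecated_kwargs != {}:
        warnings.warn('Option arguments for metrics has been changed to '
                      '`metric_options`.')
        metric_options = {**deprecated_kwargs}

    return metric_options


print(evaluate([]))                                # {'thr': 0.5}
print(evaluate([], thr=0.6))                       # warns, then {'thr': 0.6}
print(evaluate([], metric_options={'thr': 0.3}))   # {'thr': 0.3}
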
@@ -48,7 +67,7 @@ class MultiLabelDataset(BaseDataset):
 
         invalid_metrics = set(metrics) - set(allowed_metrics)
         if len(invalid_metrics) != 0:
-            raise ValueError(f'metirc {invalid_metrics} is not supported.')
+            raise ValueError(f'metric {invalid_metrics} is not supported.')
 
         if 'mAP' in metrics:
             mAP_value = mAP(results, gt_labels)
@@ -56,7 +75,7 @@ class MultiLabelDataset(BaseDataset):
         if len(set(metrics) - {'mAP'}) != 0:
             performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
             performance_values = average_performance(results, gt_labels,
-                                                      **eval_kwargs)
+                                                      **metric_options)
             for k, v in zip(performance_keys, performance_values):
                 if k in metrics:
                     eval_results[k] = v
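
With the last change, whatever ends up in metric_options is unpacked straight into average_performance, so its keys must match that function's keyword parameters (the docstring above allows 'k' and 'thr'). A tiny illustration of the unpacking against a stand-in function (not the mmcls implementation):

def average_performance_stub(pred, target, thr=None, k=None):
    # Stand-in that only echoes the keyword arguments it received; the
    # parameter names follow the keys allowed in the docstring above.
    return {'thr': thr, 'k': k}


metric_options = {'thr': 0.5}
print(average_performance_stub('pred', 'target', **metric_options))
# -> {'thr': 0.5, 'k': None}; an unknown key would raise a TypeError.
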
@@ -39,6 +39,14 @@ def parse_args():
         action=DictAction,
         help='override some settings in the used config, the key-value pair '
         'in xxx=yyy format will be merged into config file.')
+    parser.add_argument(
+        '--metric-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='custom options for evaluation, the key-value pair in xxx=yyy '
+        'format will be parsed as a dict metric_options for dataset.evaluate()'
+        ' function.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
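
The new --metric-options flag relies on mmcv's DictAction, which collects space-separated key=value tokens into a single dict on the parsed namespace. A quick way to exercise the flag in isolation (assumes mmcv is installed; exact value parsing, e.g. lists vs. tuples and numeric casting, depends on the mmcv version):

import argparse

from mmcv import DictAction  # the action class referenced in the diff

parser = argparse.ArgumentParser()
parser.add_argument(
    '--metric-options', nargs='+', action=DictAction, default={})

# Equivalent to `--metric-options topk=1,5 average_mode=macro` on the CLI.
args = parser.parse_args(['--metric-options', 'topk=1,5', 'average_mode=macro'])
print(args.metric_options)
# Typically {'topk': [1, 5], 'average_mode': 'macro'} with the mmcv parser.
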
@@ -101,7 +109,8 @@ def main():
     rank, _ = get_dist_info()
     if rank == 0:
         if args.metrics:
-            results = dataset.evaluate(outputs, args.metrics)
+            results = dataset.evaluate(outputs, args.metrics,
+                                       args.metric_options)
             for k, v in results.items():
                 print(f'\n{k} : {v:.2f}')
         else:
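
Taken together, an invocation along the lines of

    python test.py CONFIG CHECKPOINT --metrics accuracy --metric-options topk=1,5

(assuming test.py's usual positional config and checkpoint arguments) parses the key=value pairs into args.metric_options and hands them to dataset.evaluate(outputs, args.metrics, args.metric_options), where they arrive as the metric_options parameter of the evaluate() methods updated above.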