From cb09ed54e5adc7dab5a84a05126c4411f107c329 Mon Sep 17 00:00:00 2001
From: Ma Zerun
Date: Wed, 15 Sep 2021 11:06:16 +0800
Subject: [PATCH] [Enhance] Support setting `--out-items` in `tools/test.py`. (#437)

* Support setting out_details in `tools/test.py`.

* Add assertion in `eval_metric` and `analyze_results`.
---
 tools/analysis_tools/analyze_results.py |  9 ++++--
 tools/analysis_tools/eval_metric.py     |  5 +++-
 tools/test.py                           | 37 ++++++++++++++++++-------
 3 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/tools/analysis_tools/analyze_results.py b/tools/analysis_tools/analyze_results.py
index 0029a028..9f63439e 100644
--- a/tools/analysis_tools/analyze_results.py
+++ b/tools/analysis_tools/analyze_results.py
@@ -67,6 +67,13 @@ def save_imgs(result_dir, folder_name, results, model):
 
 def main():
     args = parse_args()
+    # load test results
+    outputs = mmcv.load(args.result)
+    assert ('pred_score' in outputs and 'pred_class' in outputs
+            and 'pred_label' in outputs), \
+        'No "pred_label", "pred_score" or "pred_class" in result file, ' \
+        'please set "--out-items" in test.py'
+
     cfg = mmcv.Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
@@ -86,8 +93,6 @@ def main():
     gt_labels = list(dataset.get_gt_labels())
     gt_classes = [dataset.CLASSES[x] for x in gt_labels]
 
-    # load test results
-    outputs = mmcv.load(args.result)
     outputs['filename'] = filenames
     outputs['gt_label'] = gt_labels
     outputs['gt_class'] = gt_classes
diff --git a/tools/analysis_tools/eval_metric.py b/tools/analysis_tools/eval_metric.py
index fbc75edb..c5a5c7a6 100644
--- a/tools/analysis_tools/eval_metric.py
+++ b/tools/analysis_tools/eval_metric.py
@@ -41,6 +41,10 @@ def parse_args():
 
 def main():
     args = parse_args()
+    outputs = mmcv.load(args.pkl_results)
+    assert 'class_scores' in outputs, \
+        'No "class_scores" in result file, please set "--out-items" in test.py'
+
     cfg = Config.fromfile(args.config)
     assert args.metrics, (
         'Please specify at least one metric the argument "--metrics".')
@@ -54,7 +58,6 @@ def main():
     cfg.data.test.test_mode = True
     dataset = build_dataset(cfg.data.test)
 
-    outputs = mmcv.load(args.pkl_results)
     pred_score = outputs['class_scores']
 
     kwargs = {} if args.eval_options is None else args.eval_options
diff --git a/tools/test.py b/tools/test.py
index b553ac44..f62c4ccc 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -28,6 +28,17 @@ def parse_args():
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument('--out', help='output result file')
+    out_options = ['class_scores', 'pred_score', 'pred_label', 'pred_class']
+    parser.add_argument(
+        '--out-items',
+        nargs='+',
+        default=['all'],
+        choices=out_options + ['none', 'all'],
+        help='Besides metrics, what items will be included in the output '
+        f'result file. You can choose some of ({", ".join(out_options)}), '
+        'or use "all" to include all above, or use "none" to disable all of '
+        'above. Defaults to output all.',
+        metavar='')
     parser.add_argument(
         '--metrics',
         type=str,
@@ -177,16 +188,22 @@ def main():
             for k, v in eval_results.items():
                 print(f'\n{k} : {v:.2f}')
         if args.out:
-            scores = np.vstack(outputs)
-            pred_score = np.max(scores, axis=1)
-            pred_label = np.argmax(scores, axis=1)
-            pred_class = [CLASSES[lb] for lb in pred_label]
-            results.update({
-                'class_scores': scores,
-                'pred_score': pred_score,
-                'pred_label': pred_label,
-                'pred_class': pred_class
-            })
+            if 'none' not in args.out_items:
+                scores = np.vstack(outputs)
+                pred_score = np.max(scores, axis=1)
+                pred_label = np.argmax(scores, axis=1)
+                pred_class = [CLASSES[lb] for lb in pred_label]
+                res_items = {
+                    'class_scores': scores,
+                    'pred_score': pred_score,
+                    'pred_label': pred_label,
+                    'pred_class': pred_class
+                }
+                if 'all' in args.out_items:
+                    results.update(res_items)
+                else:
+                    for key in args.out_items:
+                        results[key] = res_items[key]
             print(f'\ndumping results to {args.out}')
             mmcv.dump(results, args.out)
 
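
Usage sketch (not part of the patch; the config, checkpoint and output paths below are placeholders): the new flag selects which prediction items are stored alongside the evaluation metrics in the file passed to `--out`, and the dumped file can then be inspected with `mmcv.load`:

    # Hypothetical invocation keeping only two items besides the metrics:
    #   python tools/test.py configs/resnet/resnet50_b32x8_imagenet.py \
    #       checkpoints/resnet50.pth --out results.pkl \
    #       --out-items pred_score pred_class
    import mmcv

    results = mmcv.load('results.pkl')  # the path given to --out above
    print(sorted(results.keys()))
    # Expect the metric names plus 'pred_score' and 'pred_class'.
    # With the default '--out-items all', 'class_scores' and 'pred_label'
    # are included as well; with '--out-items none', only the metrics remain.

The same file is what `tools/analysis_tools/analyze_results.py` and `tools/analysis_tools/eval_metric.py` consume, which is why both now assert that the items they need were included when running `tools/test.py`.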