diff --git a/configs/_base_/schedules/schedule_160k.py b/configs/_base_/schedules/schedule_160k.py index 52603890b..39630f215 100644 --- a/configs/_base_/schedules/schedule_160k.py +++ b/configs/_base_/schedules/schedule_160k.py @@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) # runtime settings runner = dict(type='IterBasedRunner', max_iters=160000) checkpoint_config = dict(by_epoch=False, interval=16000) -evaluation = dict(interval=16000, metric='mIoU') +evaluation = dict(interval=16000, metric='mIoU', pre_eval=True) diff --git a/configs/_base_/schedules/schedule_20k.py b/configs/_base_/schedules/schedule_20k.py index bf780a1b6..73c702197 100644 --- a/configs/_base_/schedules/schedule_20k.py +++ b/configs/_base_/schedules/schedule_20k.py @@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) # runtime settings runner = dict(type='IterBasedRunner', max_iters=20000) checkpoint_config = dict(by_epoch=False, interval=2000) -evaluation = dict(interval=2000, metric='mIoU') +evaluation = dict(interval=2000, metric='mIoU', pre_eval=True) diff --git a/configs/_base_/schedules/schedule_40k.py b/configs/_base_/schedules/schedule_40k.py index cdbf841ab..d2c502325 100644 --- a/configs/_base_/schedules/schedule_40k.py +++ b/configs/_base_/schedules/schedule_40k.py @@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) # runtime settings runner = dict(type='IterBasedRunner', max_iters=40000) checkpoint_config = dict(by_epoch=False, interval=4000) -evaluation = dict(interval=4000, metric='mIoU') +evaluation = dict(interval=4000, metric='mIoU', pre_eval=True) diff --git a/configs/_base_/schedules/schedule_80k.py b/configs/_base_/schedules/schedule_80k.py index c190cee6b..8365a878e 100644 --- a/configs/_base_/schedules/schedule_80k.py +++ b/configs/_base_/schedules/schedule_80k.py @@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) # runtime settings runner = dict(type='IterBasedRunner', max_iters=80000) checkpoint_config = dict(by_epoch=False, interval=8000) -evaluation = dict(interval=8000, metric='mIoU') +evaluation = dict(interval=8000, metric='mIoU', pre_eval=True) diff --git a/docs/inference.md b/docs/inference.md index d7bc21b65..65f1e4602 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -21,11 +21,11 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [- Optional arguments: -- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. +- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. (After mmseg v0.17, the output results become pre-evaluation results or format result paths) - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `mIoU` is available for all dataset. Cityscapes could be evaluated by `cityscapes` as well as standard `mIoU` metrics. - `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that GUI is available in your environment, otherwise you may encounter the error like `cannot connect to X server`. - `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option. -- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB). +- `--eval-options`: Optional parameters for `dataset.format_results` and `dataset.evaluate` during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB). (`efficient_test` argument does not have effect after mmseg v0.17, we use a progressive mode to evaluation and format results which can largely save memory cost and evaluation time.) Examples: @@ -98,4 +98,4 @@ Assume that you have already downloaded the checkpoints to the directory `checkp --eval mIoU ``` - Using ```pmap``` to view CPU memory footprint, it used 2.25GB CPU memory with ```efficient_test=True``` and 11.06GB CPU memory with ```efficient_test=False``` . This optional parameter can save a lot of memory. + Using ```pmap``` to view CPU memory footprint, it used 2.25GB CPU memory with ```efficient_test=True``` and 11.06GB CPU memory with ```efficient_test=False``` . This optional parameter can save a lot of memory. (After mmseg v0.17, efficient_test has not effect and we use a progressive mode to evaluation and format results efficiently by default.) diff --git a/docs_zh-CN/inference.md b/docs_zh-CN/inference.md index 85d9ff085..7d14bb980 100644 --- a/docs_zh-CN/inference.md +++ b/docs_zh-CN/inference.md @@ -20,11 +20,11 @@ python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}] 可选参数: -- `RESULT_FILE`: pickle 格式的输出结果的文件名,如果不专门指定,结果将不会被专门保存成文件 -- `EVAL_METRICS`: 在结果里将被评估的指标,这主要取决于数据集, `mIoU` 对于所有数据集都可获得,像 Cityscapes 数据集可以通过 `cityscapes` 命令来专门评估,就像标准的 `mIoU`一样 -- `--show`: 如果被指定,分割结果将会在一张图像里画出来并且在另一个窗口展示,它仅仅是用来调试与可视化,并且仅针对单卡 GPU 测试,请确认 GUI 在您的环境里可用,否则您也许会遇到报错 `cannot connect to X server` -- `--show-dir`: 如果被指定,分割结果将会在一张图像里画出来并且保存在指定文件夹里,它仅仅是用来调试与可视化,并且仅针对单卡GPU测试,使用该参数时,您的环境不需要 GUI -- `--eval-options`: 评估时的可选参数,当设置 `efficient_test=True` 时,它将会保存中间结果至本地文件里以节约 CPU 内存,请确认您本地硬盘有足够的存储空间(大于20GB) +- `RESULT_FILE`: pickle 格式的输出结果的文件名,如果不专门指定,结果将不会被专门保存成文件。(MMseg v0.17 之后,args.out 将只会保存评估时的中间结果或者是分割图的保存路径。) +- `EVAL_METRICS`: 在结果里将被评估的指标。这主要取决于数据集, `mIoU` 对于所有数据集都可获得,像 Cityscapes 数据集可以通过 `cityscapes` 命令来专门评估,就像标准的 `mIoU`一样。 +- `--show`: 如果被指定,分割结果将会在一张图像里画出来并且在另一个窗口展示。它仅仅是用来调试与可视化,并且仅针对单卡 GPU 测试。请确认 GUI 在您的环境里可用,否则您也许会遇到报错 `cannot connect to X server` +- `--show-dir`: 如果被指定,分割结果将会在一张图像里画出来并且保存在指定文件夹里。它仅仅是用来调试与可视化,并且仅针对单卡GPU测试。使用该参数时,您的环境不需要 GUI。 +- `--eval-options`: 评估时的可选参数,当设置 `efficient_test=True` 时,它将会保存中间结果至本地文件里以节约 CPU 内存。请确认您本地硬盘有足够的存储空间(大于20GB)。(MMseg v0.17 之后,`efficient_test` 不再生效,我们重构了 test api,通过使用一种渐近式的方式来提升评估和保存结果的效率。) 例子: @@ -96,4 +96,4 @@ python tools/test.py ${配置文件} ${检查点文件} [--out ${结果文件}] --eval mIoU ``` - 使用 ```pmap``` 可查看 CPU 内存情况, ```efficient_test=True``` 会使用约 2.25GB 的 CPU 内存, ```efficient_test=False``` 会使用约 11.06GB 的 CPU 内存。 这个可选参数可以节约很多 CPU 内存。 + 使用 ```pmap``` 可查看 CPU 内存情况, ```efficient_test=True``` 会使用约 2.25GB 的 CPU 内存, ```efficient_test=False``` 会使用约 11.06GB 的 CPU 内存。 这个可选参数可以节约很多 CPU 内存。(MMseg v0.17 之后, `efficient_test` 参数将不再生效, 我们使用了一种渐近的方式来更加有效快速地评估和保存结果。) diff --git a/mmseg/apis/test.py b/mmseg/apis/test.py index fb0bb9361..2b11adfdc 100644 --- a/mmseg/apis/test.py +++ b/mmseg/apis/test.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import tempfile +import warnings import mmcv import numpy as np @@ -19,7 +20,6 @@ def np2tmp(array, temp_file_name=None, tmpdir=None): function will generate a file name with tempfile.NamedTemporaryFile to save ndarray. Default: None. tmpdir (str): Temporary directory to save Ndarray files. Default: None. - Returns: str: The numpy file name. """ @@ -36,8 +36,11 @@ def single_gpu_test(model, show=False, out_dir=None, efficient_test=False, - opacity=0.5): - """Test with single GPU. + opacity=0.5, + pre_eval=False, + format_only=False, + format_args={}): + """Test with single GPU by progressive mode. Args: model (nn.Module): Model to be tested. @@ -46,24 +49,60 @@ def single_gpu_test(model, out_dir (str, optional): If specified, the results will be dumped into the directory to save output results. efficient_test (bool): Whether save the results as local numpy files to - save CPU memory during evaluation. Default: False. + save CPU memory during evaluation. Mutually exclusive with + pre_eval and format_results. Default: False. opacity(float): Opacity of painted segmentation map. Default 0.5. Must be in (0, 1] range. + pre_eval (bool): Use dataset.pre_eval() function to generate + pre_results for metric evaluation. Mutually exclusive with + efficient_test and format_results. Default: False. + format_only (bool): Only format result for results commit. + Mutually exclusive with pre_eval and efficient_test. + Default: False. + format_args (dict): The args for format_results. Default: {}. Returns: - list: The prediction results. + list: list of evaluation pre-results or list of save file names. """ + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` will be deprecated, the ' + 'evaluation is CPU memory friendly with pre_eval=True') + mmcv.mkdir_or_exist('.efficient_test') + # when none of them is set true, return segmentation results as + # a list of np.array. + assert [efficient_test, pre_eval, format_only].count(True) <= 1, \ + '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \ + 'exclusive, only one of them could be true .' model.eval() results = [] dataset = data_loader.dataset prog_bar = mmcv.ProgressBar(len(dataset)) - if efficient_test: - mmcv.mkdir_or_exist('.efficient_test') - for i, data in enumerate(data_loader): + # The pipeline about how the data_loader retrieval samples from dataset: + # sampler -> batch_sampler -> indices + # The indices are passed to dataset_fetcher to get data from dataset. + # data_fetcher -> collate_fn(dataset[index]) -> data_sample + # we use batch_sampler to get correct data idx + loader_indices = data_loader.batch_sampler + + for batch_indices, data in zip(loader_indices, data_loader): with torch.no_grad(): result = model(return_loss=False, **data) + if efficient_test: + result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] + + if format_only: + result = dataset.format_results( + result, indices=batch_indices, **format_args) + if pre_eval: + # TODO: adapt samples_per_gpu > 1. + # only samples_per_gpu=1 valid now + result = dataset.pre_eval(result, indices=batch_indices) + + results.extend(result) + if show or out_dir: img_tensor = data['img'][0] img_metas = data['img_metas'][0].data[0] @@ -90,18 +129,10 @@ def single_gpu_test(model, out_file=out_file, opacity=opacity) - if isinstance(result, list): - if efficient_test: - result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] - results.extend(result) - else: - if efficient_test: - result = np2tmp(result, tmpdir='.efficient_test') - results.append(result) - batch_size = len(result) for _ in range(batch_size): prog_bar.update() + return results @@ -109,8 +140,11 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False, - efficient_test=False): - """Test model with multiple gpus. + efficient_test=False, + pre_eval=False, + format_only=False, + format_args={}): + """Test model with multiple gpus by progressive mode. This method tests model with multiple gpus and collects the results under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' @@ -123,39 +157,71 @@ def multi_gpu_test(model, data_loader (utils.data.Dataloader): Pytorch data loader. tmpdir (str): Path of directory to save the temporary results from different gpus under cpu mode. The same path is used for efficient - test. + test. Default: None. gpu_collect (bool): Option to use either gpu or cpu to collect results. + Default: False. efficient_test (bool): Whether save the results as local numpy files to - save CPU memory during evaluation. Default: False. + save CPU memory during evaluation. Mutually exclusive with + pre_eval and format_results. Default: False. + pre_eval (bool): Use dataset.pre_eval() function to generate + pre_results for metric evaluation. Mutually exclusive with + efficient_test and format_results. Default: False. + format_only (bool): Only format result for results commit. + Mutually exclusive with pre_eval and efficient_test. + Default: False. + format_args (dict): The args for format_results. Default: {}. Returns: - list: The prediction results. + list: list of evaluation pre-results or list of save file names. """ + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` will be deprecated, the ' + 'evaluation is CPU memory friendly with pre_eval=True') + mmcv.mkdir_or_exist('.efficient_test') + # when none of them is set true, return segmentation results as + # a list of np.array. + assert [efficient_test, pre_eval, format_only].count(True) <= 1, \ + '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \ + 'exclusive, only one of them could be true .' model.eval() results = [] dataset = data_loader.dataset + # The pipeline about how the data_loader retrieval samples from dataset: + # sampler -> batch_sampler -> indices + # The indices are passed to dataset_fetcher to get data from dataset. + # data_fetcher -> collate_fn(dataset[index]) -> data_sample + # we use batch_sampler to get correct data idx + + # batch_sampler based on DistributedSampler, the indices only point to data + # samples of related machine. + loader_indices = data_loader.batch_sampler + rank, world_size = get_dist_info() if rank == 0: prog_bar = mmcv.ProgressBar(len(dataset)) - if efficient_test: - mmcv.mkdir_or_exist('.efficient_test') - for i, data in enumerate(data_loader): + + for batch_indices, data in zip(loader_indices, data_loader): with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) - if isinstance(result, list): - if efficient_test: - result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] - results.extend(result) - else: - if efficient_test: - result = np2tmp(result, tmpdir='.efficient_test') - results.append(result) + if efficient_test: + result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] + + if format_only: + result = dataset.format_results( + result, indices=batch_indices, **format_args) + if pre_eval: + # TODO: adapt samples_per_gpu > 1. + # only samples_per_gpu=1 valid now + result = dataset.pre_eval(result, indices=batch_indices) + + results.extend(result) if rank == 0: - batch_size = len(result) - for _ in range(batch_size * world_size): + batch_size = len(result) * world_size + for _ in range(batch_size): prog_bar.update() # collect results from all ranks diff --git a/mmseg/core/evaluation/__init__.py b/mmseg/core/evaluation/__init__.py index 237cf2476..3d16d17e5 100644 --- a/mmseg/core/evaluation/__init__.py +++ b/mmseg/core/evaluation/__init__.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .class_names import get_classes, get_palette from .eval_hooks import DistEvalHook, EvalHook -from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou +from .metrics import (eval_metrics, intersect_and_union, mean_dice, + mean_fscore, mean_iou, pre_eval_to_metrics) __all__ = [ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', - 'eval_metrics', 'get_classes', 'get_palette' + 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', + 'intersect_and_union' ] diff --git a/mmseg/core/evaluation/eval_hooks.py b/mmseg/core/evaluation/eval_hooks.py index a2f08d775..952db3b0b 100644 --- a/mmseg/core/evaluation/eval_hooks.py +++ b/mmseg/core/evaluation/eval_hooks.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +import warnings import torch.distributed as dist from mmcv.runner import DistEvalHook as _DistEvalHook @@ -16,15 +17,28 @@ class EvalHook(_EvalHook): Default: False. efficient_test (bool): Whether save the results as local numpy files to save CPU memory during evaluation. Default: False. + pre_eval (bool): Whether to use progressive mode to evaluate model. + Default: False. Returns: list: The prediction results. """ greater_keys = ['mIoU', 'mAcc', 'aAcc'] - def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + def __init__(self, + *args, + by_epoch=False, + efficient_test=False, + pre_eval=False, + **kwargs): super().__init__(*args, by_epoch=by_epoch, **kwargs) - self.efficient_test = efficient_test + self.pre_eval = pre_eval + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` for evaluation hook ' + 'is deprecated, the evaluation hook is CPU memory friendly ' + 'with ``pre_eval=True`` as argument for ``single_gpu_test()`` ' + 'function') def _do_evaluate(self, runner): """perform evaluation and save ckpt.""" @@ -33,10 +47,8 @@ class EvalHook(_EvalHook): from mmseg.apis import single_gpu_test results = single_gpu_test( - runner.model, - self.dataloader, - show=False, - efficient_test=self.efficient_test) + runner.model, self.dataloader, show=False, pre_eval=self.pre_eval) + runner.log_buffer.clear() runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) key_score = self.evaluate(runner, results) if self.save_best: @@ -52,15 +64,28 @@ class DistEvalHook(_DistEvalHook): Default: False. efficient_test (bool): Whether save the results as local numpy files to save CPU memory during evaluation. Default: False. + pre_eval (bool): Whether to use progressive mode to evaluate model. + Default: False. Returns: list: The prediction results. """ greater_keys = ['mIoU', 'mAcc', 'aAcc'] - def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + def __init__(self, + *args, + by_epoch=False, + efficient_test=False, + pre_eval=False, + **kwargs): super().__init__(*args, by_epoch=by_epoch, **kwargs) - self.efficient_test = efficient_test + self.pre_eval = pre_eval + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` for evaluation hook ' + 'is deprecated, the evaluation hook is CPU memory friendly ' + 'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` ' + 'function') def _do_evaluate(self, runner): """perform evaluation and save ckpt.""" @@ -90,7 +115,10 @@ class DistEvalHook(_DistEvalHook): self.dataloader, tmpdir=tmpdir, gpu_collect=self.gpu_collect, - efficient_test=self.efficient_test) + pre_eval=self.pre_eval) + + runner.log_buffer.clear() + if runner.rank == 0: print('\n') runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) diff --git a/mmseg/core/evaluation/metrics.py b/mmseg/core/evaluation/metrics.py index 3c5f63fb4..f64967c6c 100644 --- a/mmseg/core/evaluation/metrics.py +++ b/mmseg/core/evaluation/metrics.py @@ -97,8 +97,8 @@ def total_intersect_and_union(results, Args: results (list[ndarray] | list[str]): List of prediction segmentation maps or list of prediction result filenames. - gt_seg_maps (list[ndarray] | list[str]): list of ground truth - segmentation maps or list of label filenames. + gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground + truth segmentation maps or list of label filenames. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. label_map (dict): Mapping old labels to new labels. Default: dict(). @@ -113,15 +113,15 @@ def total_intersect_and_union(results, ndarray: The ground truth histogram on all classes. """ num_imgs = len(results) - assert len(gt_seg_maps) == num_imgs + assert len(list(gt_seg_maps)) == num_imgs total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) - for i in range(num_imgs): + for result, gt_seg_map in zip(results, gt_seg_maps): area_intersect, area_union, area_pred_label, area_label = \ intersect_and_union( - results[i], gt_seg_maps[i], num_classes, ignore_index, + result, gt_seg_map, num_classes, ignore_index, label_map, reduce_zero_label) total_area_intersect += area_intersect total_area_union += area_union @@ -268,8 +268,8 @@ def eval_metrics(results, Args: results (list[ndarray] | list[str]): List of prediction segmentation maps or list of prediction result filenames. - gt_seg_maps (list[ndarray] | list[str]): list of ground truth - segmentation maps or list of label filenames. + gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground + truth segmentation maps or list of label filenames. num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. @@ -282,16 +282,86 @@ def eval_metrics(results, ndarray: Per category accuracy, shape (num_classes, ). ndarray: Per category evaluation metrics, shape (num_classes, ). """ + + total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label = total_intersect_and_union( + results, gt_seg_maps, num_classes, ignore_index, label_map, + reduce_zero_label) + ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union, + total_area_pred_label, + total_area_label, metrics, nan_to_num, + beta) + + return ret_metrics + + +def pre_eval_to_metrics(pre_eval_results, + metrics=['mIoU'], + nan_to_num=None, + beta=1): + """Convert pre-eval results to metrics. + + Args: + pre_eval_results (list[tuple[torch.Tensor]]): per image eval results + for computing evaluation metric + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + + # convert list of tuples to tuple of lists, e.g. + # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to + # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) + pre_eval_results = tuple(zip(*pre_eval_results)) + assert len(pre_eval_results) == 4 + + total_area_intersect = sum(pre_eval_results[0]) + total_area_union = sum(pre_eval_results[1]) + total_area_pred_label = sum(pre_eval_results[2]) + total_area_label = sum(pre_eval_results[3]) + + ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union, + total_area_pred_label, + total_area_label, metrics, nan_to_num, + beta) + + return ret_metrics + + +def total_area_to_metrics(total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics=['mIoU'], + nan_to_num=None, + beta=1): + """Calculate evaluation metrics + Args: + total_area_intersect (ndarray): The intersection of prediction and + ground truth histogram on all classes. + total_area_union (ndarray): The union of prediction and ground truth + histogram on all classes. + total_area_pred_label (ndarray): The prediction histogram on all + classes. + total_area_label (ndarray): The ground truth histogram on all classes. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ if isinstance(metrics, str): metrics = [metrics] allowed_metrics = ['mIoU', 'mDice', 'mFscore'] if not set(metrics).issubset(set(allowed_metrics)): raise KeyError('metrics {} is not supported'.format(metrics)) - total_area_intersect, total_area_union, total_area_pred_label, \ - total_area_label = total_intersect_and_union( - results, gt_seg_maps, num_classes, ignore_index, label_map, - reduce_zero_label) all_acc = total_area_intersect.sum() / total_area_label.sum() ret_metrics = OrderedDict({'aAcc': all_acc}) for metric in metrics: diff --git a/mmseg/datasets/ade.py b/mmseg/datasets/ade.py index 9af437126..d807a001a 100644 --- a/mmseg/datasets/ade.py +++ b/mmseg/datasets/ade.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -import tempfile import mmcv import numpy as np @@ -91,7 +90,7 @@ class ADE20KDataset(CustomDataset): reduce_zero_label=True, **kwargs) - def results2img(self, results, imgfile_prefix, to_label_id): + def results2img(self, results, imgfile_prefix, to_label_id, indices=None): """Write the segmentation results to images. Args: @@ -101,17 +100,21 @@ class ADE20KDataset(CustomDataset): If the prefix is "somepath/xxx", the png files will be named "somepath/xxx.png". to_label_id (bool): whether convert output to label_id for - submission + submission. + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. Returns: list[str: str]: result txt files which contains corresponding semantic segmentation images. """ + if indices is None: + indices = list(range(len(self))) + mmcv.mkdir_or_exist(imgfile_prefix) result_files = [] - prog_bar = mmcv.ProgressBar(len(self)) - for idx in range(len(self)): - result = results[idx] + for result, idx in zip(results, indices): filename = self.img_infos[idx]['filename'] basename = osp.splitext(osp.basename(filename))[0] @@ -127,21 +130,25 @@ class ADE20KDataset(CustomDataset): output.save(png_filename) result_files.append(png_filename) - prog_bar.update() - return result_files - def format_results(self, results, imgfile_prefix=None, to_label_id=True): + def format_results(self, + results, + imgfile_prefix, + to_label_id=True, + indices=None): """Format the results into dir (standard format for ade20k evaluation). Args: results (list): Testing results of the dataset. imgfile_prefix (str | None): The prefix of images files. It includes the file path and the prefix of filename, e.g., - "a/b/prefix". If not specified, a temp file will be created. - Default: None. + "a/b/prefix". to_label_id (bool): whether convert output to label_id for submission. Default: False + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. Returns: tuple: (result_files, tmp_dir), result_files is a list containing @@ -149,16 +156,12 @@ class ADE20KDataset(CustomDataset): for saving json/png files when img_prefix is not specified. """ - assert isinstance(results, list), 'results must be a list' - assert len(results) == len(self), ( - 'The length of results is not equal to the dataset len: ' - f'{len(results)} != {len(self)}') + if indices is None: + indices = list(range(len(self))) - if imgfile_prefix is None: - tmp_dir = tempfile.TemporaryDirectory() - imgfile_prefix = tmp_dir.name - else: - tmp_dir = None + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' - result_files = self.results2img(results, imgfile_prefix, to_label_id) - return result_files, tmp_dir + result_files = self.results2img(results, imgfile_prefix, to_label_id, + indices) + return result_files diff --git a/mmseg/datasets/cityscapes.py b/mmseg/datasets/cityscapes.py index fd814f92c..5802622e7 100644 --- a/mmseg/datasets/cityscapes.py +++ b/mmseg/datasets/cityscapes.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -import tempfile import mmcv import numpy as np @@ -48,7 +47,7 @@ class CityscapesDataset(CustomDataset): return result_copy - def results2img(self, results, imgfile_prefix, to_label_id): + def results2img(self, results, imgfile_prefix, to_label_id, indices=None): """Write the segmentation results to images. Args: @@ -58,17 +57,21 @@ class CityscapesDataset(CustomDataset): If the prefix is "somepath/xxx", the png files will be named "somepath/xxx.png". to_label_id (bool): whether convert output to label_id for - submission + submission. + indices (list[int], optional): Indices of input results, + if not set, all the indices of the dataset will be used. + Default: None. Returns: list[str: str]: result txt files which contains corresponding semantic segmentation images. """ + if indices is None: + indices = list(range(len(self))) + mmcv.mkdir_or_exist(imgfile_prefix) result_files = [] - prog_bar = mmcv.ProgressBar(len(self)) - for idx in range(len(self)): - result = results[idx] + for result, idx in zip(results, indices): if to_label_id: result = self._convert_to_label_id(result) filename = self.img_infos[idx]['filename'] @@ -85,49 +88,49 @@ class CityscapesDataset(CustomDataset): output.putpalette(palette) output.save(png_filename) result_files.append(png_filename) - prog_bar.update() return result_files - def format_results(self, results, imgfile_prefix=None, to_label_id=True): + def format_results(self, + results, + imgfile_prefix, + to_label_id=True, + indices=None): """Format the results into dir (standard format for Cityscapes evaluation). Args: results (list): Testing results of the dataset. - imgfile_prefix (str | None): The prefix of images files. It + imgfile_prefix (str): The prefix of images files. It includes the file path and the prefix of filename, e.g., - "a/b/prefix". If not specified, a temp file will be created. - Default: None. + "a/b/prefix". to_label_id (bool): whether convert output to label_id for submission. Default: False + indices (list[int], optional): Indices of input results, + if not set, all the indices of the dataset will be used. + Default: None. Returns: tuple: (result_files, tmp_dir), result_files is a list containing the image paths, tmp_dir is the temporal directory created for saving json/png files when img_prefix is not specified. """ + if indices is None: + indices = list(range(len(self))) - assert isinstance(results, list), 'results must be a list' - assert len(results) == len(self), ( - 'The length of results is not equal to the dataset len: ' - f'{len(results)} != {len(self)}') + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' - if imgfile_prefix is None: - tmp_dir = tempfile.TemporaryDirectory() - imgfile_prefix = tmp_dir.name - else: - tmp_dir = None - result_files = self.results2img(results, imgfile_prefix, to_label_id) + result_files = self.results2img(results, imgfile_prefix, to_label_id, + indices) - return result_files, tmp_dir + return result_files def evaluate(self, results, metric='mIoU', logger=None, - imgfile_prefix=None, - efficient_test=False): + imgfile_prefix=None): """Evaluation in Cityscapes/default protocol. Args: @@ -158,7 +161,7 @@ class CityscapesDataset(CustomDataset): if len(metrics) > 0: eval_results.update( super(CityscapesDataset, - self).evaluate(results, metrics, logger, efficient_test)) + self).evaluate(results, metrics, logger)) return eval_results @@ -184,12 +187,7 @@ class CityscapesDataset(CustomDataset): msg = '\n' + msg print_log(msg, logger=logger) - result_files, tmp_dir = self.format_results(results, imgfile_prefix) - - if tmp_dir is None: - result_dir = imgfile_prefix - else: - result_dir = tmp_dir.name + result_dir = imgfile_prefix eval_results = dict() print_log(f'Evaluating results under {result_dir} ...', logger=logger) @@ -212,7 +210,4 @@ class CityscapesDataset(CustomDataset): eval_results.update( CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) - if tmp_dir is not None: - tmp_dir.cleanup() - return eval_results diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index a86fabb97..e366c0da2 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import os.path as osp +import warnings from collections import OrderedDict from functools import reduce @@ -10,7 +10,7 @@ from mmcv.utils import print_log from prettytable import PrettyTable from torch.utils.data import Dataset -from mmseg.core import eval_metrics +from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics from mmseg.utils import get_root_logger from .builder import DATASETS from .pipelines import Compose @@ -226,21 +226,55 @@ class CustomDataset(Dataset): self.pre_pipeline(results) return self.pipeline(results) - def format_results(self, results, **kwargs): + def format_results(self, results, imgfile_prefix, indices=None, **kwargs): """Place holder to format result to dataset specific output.""" + raise NotImplementedError - def get_gt_seg_maps(self, efficient_test=False): + def get_gt_seg_maps(self, efficient_test=None): """Get ground truth segmentation maps for evaluation.""" - gt_seg_maps = [] + if efficient_test is not None: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` has been deprecated ' + 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory ' + 'friendly by default. ') + for img_info in self.img_infos: seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) - if efficient_test: - gt_seg_map = seg_map - else: - gt_seg_map = mmcv.imread( - seg_map, flag='unchanged', backend='pillow') - gt_seg_maps.append(gt_seg_map) - return gt_seg_maps + gt_seg_map = mmcv.imread( + seg_map, flag='unchanged', backend='pillow') + yield gt_seg_map + + def pre_eval(self, preds, indices): + """Collect eval result from each iteration. + + Args: + preds (list[torch.Tensor] | torch.Tensor): the segmentation logit + after argmax, shape (N, H, W). + indices (list[int] | int): the prediction related ground truth + indices. + + Returns: + list[torch.Tensor]: (area_intersect, area_union, area_prediction, + area_ground_truth). + """ + # In order to compat with batch inference + if not isinstance(indices, list): + indices = [indices] + if not isinstance(preds, list): + preds = [preds] + + pre_eval_results = [] + + for pred, index in zip(preds, indices): + seg_map = osp.join(self.ann_dir, + self.img_infos[index]['ann']['seg_map']) + seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow') + pre_eval_results.append( + intersect_and_union(pred, seg_map, len(self.CLASSES), + self.ignore_index, self.label_map, + self.reduce_zero_label)) + + return pre_eval_results def get_classes_and_palette(self, classes=None, palette=None): """Get class names of current dataset. @@ -305,16 +339,13 @@ class CustomDataset(Dataset): return palette - def evaluate(self, - results, - metric='mIoU', - logger=None, - efficient_test=False, - **kwargs): + def evaluate(self, results, metric='mIoU', logger=None, **kwargs): """Evaluate the dataset. Args: - results (list): Testing results of the dataset. + results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval + results or predict segmentation map for computing evaluation + metric. metric (str | list[str]): Metrics to be evaluated. 'mIoU', 'mDice' and 'mFscore' are supported. logger (logging.Logger | None | str): Logger used for printing @@ -323,28 +354,37 @@ class CustomDataset(Dataset): Returns: dict[str, float]: Default metrics. """ - if isinstance(metric, str): metric = [metric] allowed_metrics = ['mIoU', 'mDice', 'mFscore'] if not set(metric).issubset(set(allowed_metrics)): raise KeyError('metric {} is not supported'.format(metric)) - eval_results = {} - gt_seg_maps = self.get_gt_seg_maps(efficient_test) - if self.CLASSES is None: - num_classes = len( - reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) - else: - num_classes = len(self.CLASSES) - ret_metrics = eval_metrics( - results, - gt_seg_maps, - num_classes, - self.ignore_index, - metric, - label_map=self.label_map, - reduce_zero_label=self.reduce_zero_label) + eval_results = {} + # test a list of files + if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( + results, str): + gt_seg_maps = self.get_gt_seg_maps() + if self.CLASSES is None: + num_classes = len( + reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) + else: + num_classes = len(self.CLASSES) + # reset generator + gt_seg_maps = self.get_gt_seg_maps() + ret_metrics = eval_metrics( + results, + gt_seg_maps, + num_classes, + self.ignore_index, + metric, + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label) + # test a list of pre_eval_results + else: + ret_metrics = pre_eval_to_metrics(results, metric) + + # Because dataset.CLASSES is required for per-eval. if self.CLASSES is None: class_names = tuple(range(num_classes)) else: @@ -396,7 +436,4 @@ class CustomDataset(Dataset): for idx, name in enumerate(class_names) }) - if mmcv.is_list_of(results, str): - for file_name in results: - os.remove(file_name) return eval_results diff --git a/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_instanceIds.png b/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_instanceIds.png new file mode 100644 index 000000000..dfe7aea9b Binary files /dev/null and b/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_instanceIds.png differ diff --git a/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelIds.png b/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelIds.png new file mode 100644 index 000000000..faab6f554 Binary files /dev/null and b/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelIds.png differ diff --git a/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png b/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png new file mode 100644 index 000000000..659229b92 Binary files /dev/null and b/tests/data/pseudo_cityscapes_dataset/gtFine/frankfurt_000000_000294_gtFine_labelTrainIds.png differ diff --git a/tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png b/tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png new file mode 100644 index 000000000..2c83ee4f5 Binary files /dev/null and b/tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png differ diff --git a/tests/test_apis/test_single_gpu.py b/tests/test_apis/test_single_gpu.py new file mode 100644 index 000000000..b741896e5 --- /dev/null +++ b/tests/test_apis/test_single_gpu.py @@ -0,0 +1,72 @@ +import shutil +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset, dataloader + +from mmseg.apis import single_gpu_test + + +class ExampleDataset(Dataset): + + def __getitem__(self, idx): + results = dict(img=torch.tensor([1]), img_metas=dict()) + return results + + def __len__(self): + return 1 + + +class ExampleModel(nn.Module): + + def __init__(self): + super(ExampleModel, self).__init__() + self.test_cfg = None + self.conv = nn.Conv2d(3, 3, 3) + + def forward(self, img, img_metas, return_loss=False, **kwargs): + return img + + +def test_single_gpu(): + test_dataset = ExampleDataset() + data_loader = DataLoader( + test_dataset, + batch_size=1, + sampler=None, + num_workers=0, + shuffle=False, + ) + model = ExampleModel() + + # Test efficient test compatibility (will be deprecated) + results = single_gpu_test(model, data_loader, efficient_test=True) + assert len(results) == 1 + pred = np.load(results[0]) + assert isinstance(pred, np.ndarray) + assert pred.shape == (1, ) + assert pred[0] == 1 + + shutil.rmtree('.efficient_test') + + # Test pre_eval + test_dataset.pre_eval = MagicMock(return_value=['success']) + results = single_gpu_test(model, data_loader, pre_eval=True) + assert results == ['success'] + + # Test format_only + test_dataset.format_results = MagicMock(return_value=['success']) + results = single_gpu_test(model, data_loader, format_only=True) + assert results == ['success'] + + # efficient_test, pre_eval and format_only are mutually exclusive + with pytest.raises(AssertionError): + single_gpu_test( + model, + dataloader, + efficient_test=True, + format_only=True, + pre_eval=True) diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index 7ef59f27d..ebc173669 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +import shutil +from typing import Generator from unittest.mock import MagicMock, patch import numpy as np import pytest +from PIL import Image from mmseg.core.evaluation import get_classes, get_palette from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, @@ -152,10 +155,16 @@ def test_custom_dataset(): assert isinstance(test_data, dict) # get gt seg map - gt_seg_maps = train_dataset.get_gt_seg_maps() + gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True) + assert isinstance(gt_seg_maps, Generator) + gt_seg_maps = list(gt_seg_maps) assert len(gt_seg_maps) == 5 - # evaluation + # format_results not implemented + with pytest.raises(NotImplementedError): + test_dataset.format_results([], '') + + # test past evaluation pseudo_results = [] for gt_seg_map in gt_seg_maps: h, w = gt_seg_map.shape @@ -180,7 +189,7 @@ def test_custom_dataset(): assert 'mAcc' in eval_results assert 'aAcc' in eval_results - # evaluation with CLASSES + # test past evaluation with CLASSES train_dataset.CLASSES = tuple(['a'] * 7) eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU') assert isinstance(eval_results, dict) @@ -212,6 +221,95 @@ def test_custom_dataset(): assert 'mPrecision' in eval_results assert 'mRecall' in eval_results + # test evaluation with pre-eval and the dataset.CLASSES is necessary + train_dataset.CLASSES = tuple(['a'] * 7) + pseudo_results = [] + for idx in range(len(train_dataset)): + h, w = gt_seg_maps[idx].shape + pseudo_result = np.random.randint(low=0, high=7, size=(h, w)) + pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx)) + eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU']) + assert isinstance(eval_results, dict) + assert 'mIoU' in eval_results + assert 'mAcc' in eval_results + assert 'aAcc' in eval_results + + eval_results = train_dataset.evaluate(pseudo_results, metric='mDice') + assert isinstance(eval_results, dict) + assert 'mDice' in eval_results + assert 'mAcc' in eval_results + assert 'aAcc' in eval_results + + eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore') + assert isinstance(eval_results, dict) + assert 'mRecall' in eval_results + assert 'mPrecision' in eval_results + assert 'mFscore' in eval_results + assert 'aAcc' in eval_results + + eval_results = train_dataset.evaluate( + pseudo_results, metric=['mIoU', 'mDice', 'mFscore']) + assert isinstance(eval_results, dict) + assert 'mIoU' in eval_results + assert 'mDice' in eval_results + assert 'mAcc' in eval_results + assert 'aAcc' in eval_results + assert 'mFscore' in eval_results + assert 'mPrecision' in eval_results + assert 'mRecall' in eval_results + + +def test_ade(): + test_dataset = ADE20KDataset( + pipeline=[], + img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')) + assert len(test_dataset) == 5 + + # Test format_results + pseudo_results = [] + for _ in range(len(test_dataset)): + h, w = (2, 2) + pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w))) + + file_paths = test_dataset.format_results(pseudo_results, '.format_ade') + assert len(file_paths) == len(test_dataset) + temp = np.array(Image.open(file_paths[0])) + assert np.allclose(temp, pseudo_results[0] + 1) + + shutil.rmtree('.format_ade') + + +def test_cityscapes(): + test_dataset = CityscapesDataset( + pipeline=[], + img_dir=osp.join( + osp.dirname(__file__), + '../data/pseudo_cityscapes_dataset/leftImg8bit'), + ann_dir=osp.join( + osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine')) + assert len(test_dataset) == 1 + + gt_seg_maps = list(test_dataset.get_gt_seg_maps()) + + # Test format_results + pseudo_results = [] + for idx in range(len(test_dataset)): + h, w = gt_seg_maps[idx].shape + pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w))) + + file_paths = test_dataset.format_results(pseudo_results, '.format_city') + assert len(file_paths) == len(test_dataset) + temp = np.array(Image.open(file_paths[0])) + assert np.allclose(temp, + test_dataset._convert_to_label_id(pseudo_results[0])) + + # Test cityscapes evaluate + + test_dataset.evaluate( + pseudo_results, metric='cityscapes', imgfile_prefix='.format_city') + + shutil.rmtree('.format_city') + @patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) @patch('mmseg.datasets.CustomDataset.__getitem__', diff --git a/tests/test_eval_hook.py b/tests/test_eval_hook.py index 54d2a4353..5267438c3 100644 --- a/tests/test_eval_hook.py +++ b/tests/test_eval_hook.py @@ -53,6 +53,7 @@ def test_iter_eval_hook(): EvalHook(data_loader) test_dataset = ExampleDataset() + test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) test_dataset.evaluate = MagicMock(return_value=dict(test='success')) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() @@ -64,7 +65,7 @@ def test_iter_eval_hook(): # test EvalHook with tempfile.TemporaryDirectory() as tmpdir: - eval_hook = EvalHook(data_loader, by_epoch=False) + eval_hook = EvalHook(data_loader, by_epoch=False, efficient_test=True) runner = mmcv.runner.IterBasedRunner( model=model, optimizer=optimizer, @@ -90,6 +91,7 @@ def test_epoch_eval_hook(): EvalHook(data_loader, by_epoch=True) test_dataset = ExampleDataset() + test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) test_dataset.evaluate = MagicMock(return_value=dict(test='success')) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() @@ -117,8 +119,9 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False, - efficient_test=False): - results = single_gpu_test(model, data_loader) + pre_eval=False): + # Pre eval is set by default when training. + results = single_gpu_test(model, data_loader, pre_eval=True) return results @@ -137,6 +140,7 @@ def test_dist_eval_hook(): DistEvalHook(data_loader) test_dataset = ExampleDataset() + test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) test_dataset.evaluate = MagicMock(return_value=dict(test='success')) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() @@ -148,7 +152,8 @@ def test_dist_eval_hook(): # test DistEvalHook with tempfile.TemporaryDirectory() as tmpdir: - eval_hook = DistEvalHook(data_loader, by_epoch=False) + eval_hook = DistEvalHook( + data_loader, by_epoch=False, efficient_test=True) runner = mmcv.runner.IterBasedRunner( model=model, optimizer=optimizer, @@ -175,6 +180,7 @@ def test_dist_eval_hook_epoch(): DistEvalHook(data_loader) test_dataset = ExampleDataset() + test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])]) test_dataset.evaluate = MagicMock(return_value=dict(test='success')) loader = DataLoader(test_dataset, batch_size=1) model = ExampleModel() diff --git a/tools/deploy_test.py b/tools/deploy_test.py index 6e709b8c9..593532c0b 100644 --- a/tools/deploy_test.py +++ b/tools/deploy_test.py @@ -2,6 +2,7 @@ import argparse import os import os.path as osp +import shutil import warnings from typing import Any, Iterable @@ -234,24 +235,61 @@ def main(): model.CLASSES = dataset.CLASSES model.PALETTE = dataset.PALETTE - efficient_test = False - if args.eval_options is not None: - efficient_test = args.eval_options.get('efficient_test', False) + # clean gpu memory when starting a new evaluation. + torch.cuda.empty_cache() + eval_kwargs = {} if args.eval_options is None else args.eval_options + + # Deprecated + efficient_test = eval_kwargs.get('efficient_test', False) + if efficient_test: + warnings.warn( + '``efficient_test=True`` does not have effect in tools/test.py, ' + 'the evaluation and format results are CPU memory efficient by ' + 'default') + + eval_on_format_results = ( + args.eval is not None and 'cityscapes' in args.eval) + if eval_on_format_results: + assert len(args.eval) == 1, 'eval on format results is not ' \ + 'applicable for metrics other than ' \ + 'cityscapes' + if args.format_only or eval_on_format_results: + if 'imgfile_prefix' in eval_kwargs: + tmpdir = eval_kwargs['imgfile_prefix'] + else: + tmpdir = '.format_cityscapes' + eval_kwargs.setdefault('imgfile_prefix', tmpdir) + mmcv.mkdir_or_exist(tmpdir) + else: + tmpdir = None model = MMDataParallel(model, device_ids=[0]) - outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, - efficient_test, args.opacity) + results = single_gpu_test( + model, + data_loader, + args.show, + args.show_dir, + False, + args.opacity, + pre_eval=args.eval is not None and not eval_on_format_results, + format_only=args.format_only or eval_on_format_results, + format_args=eval_kwargs) rank, _ = get_dist_info() if rank == 0: if args.out: + warnings.warn( + 'The behavior of ``args.out`` has been changed since MMSeg ' + 'v0.16, the pickled outputs could be seg map as type of ' + 'np.array, pre-eval results or file paths for ' + '``dataset.format_results()``.') print(f'\nwriting results to {args.out}') - mmcv.dump(outputs, args.out) - kwargs = {} if args.eval_options is None else args.eval_options - if args.format_only: - dataset.format_results(outputs, **kwargs) + mmcv.dump(results, args.out) if args.eval: - dataset.evaluate(outputs, args.eval, **kwargs) + dataset.evaluate(results, args.eval, **eval_kwargs) + if tmpdir is not None and eval_on_format_results: + # remove tmp dir when cityscapes evaluation + shutil.rmtree(tmpdir) if __name__ == '__main__': diff --git a/tools/test.py b/tools/test.py index 87bd3659d..7420a44ad 100644 --- a/tools/test.py +++ b/tools/test.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse import os +import shutil +import warnings import mmcv import torch @@ -134,32 +136,76 @@ def main(): print('"PALETTE" not found in meta, use dataset.PALETTE instead') model.PALETTE = dataset.PALETTE - efficient_test = False - if args.eval_options is not None: - efficient_test = args.eval_options.get('efficient_test', False) + # clean gpu memory when starting a new evaluation. + torch.cuda.empty_cache() + eval_kwargs = {} if args.eval_options is None else args.eval_options + + # Deprecated + efficient_test = eval_kwargs.get('efficient_test', False) + if efficient_test: + warnings.warn( + '``efficient_test=True`` does not have effect in tools/test.py, ' + 'the evaluation and format results are CPU memory efficient by ' + 'default') + + eval_on_format_results = ( + args.eval is not None and 'cityscapes' in args.eval) + if eval_on_format_results: + assert len(args.eval) == 1, 'eval on format results is not ' \ + 'applicable for metrics other than ' \ + 'cityscapes' + if args.format_only or eval_on_format_results: + if 'imgfile_prefix' in eval_kwargs: + tmpdir = eval_kwargs['imgfile_prefix'] + else: + tmpdir = '.format_cityscapes' + eval_kwargs.setdefault('imgfile_prefix', tmpdir) + mmcv.mkdir_or_exist(tmpdir) + else: + tmpdir = None if not distributed: model = MMDataParallel(model, device_ids=[0]) - outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, - efficient_test, args.opacity) + results = single_gpu_test( + model, + data_loader, + args.show, + args.show_dir, + False, + args.opacity, + pre_eval=args.eval is not None and not eval_on_format_results, + format_only=args.format_only or eval_on_format_results, + format_args=eval_kwargs) else: model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False) - outputs = multi_gpu_test(model, data_loader, args.tmpdir, - args.gpu_collect, efficient_test) + results = multi_gpu_test( + model, + data_loader, + args.tmpdir, + args.gpu_collect, + False, + pre_eval=args.eval is not None and not eval_on_format_results, + format_only=args.format_only or eval_on_format_results, + format_args=eval_kwargs) rank, _ = get_dist_info() if rank == 0: if args.out: + warnings.warn( + 'The behavior of ``args.out`` has been changed since MMSeg ' + 'v0.16, the pickled outputs could be seg map as type of ' + 'np.array, pre-eval results or file paths for ' + '``dataset.format_results()``.') print(f'\nwriting results to {args.out}') - mmcv.dump(outputs, args.out) - kwargs = {} if args.eval_options is None else args.eval_options - if args.format_only: - dataset.format_results(outputs, **kwargs) + mmcv.dump(results, args.out) if args.eval: - dataset.evaluate(outputs, args.eval, **kwargs) + dataset.evaluate(results, args.eval, **eval_kwargs) + if tmpdir is not None and eval_on_format_results: + # remove tmp dir when cityscapes evaluation + shutil.rmtree(tmpdir) if __name__ == '__main__':