[Refactor] Support progressive test with lower memory cost (#709)

* Support progressive test with lower memory cost.

* Temp code.

* Use processor to refactor the evaluation workflow.

* Refactor eval hook.

* Fix progress bar.

* Fix middle save argument.

* Modify some variable names of the dataset evaluate API.

* Modify some variable names of the eval hook.

* Fix some priority bugs of the eval hook.

* Deprecate efficient_test.

* Fix training progress blocked by the eval hook.

* Deprecate old test API.

* Fix test API error.

* Modify outer API.

* Build a sampler test API.

* TODO: Refactor format_results.

* Modify variable names.

* Fix num_classes bug.

* Fix sampler index bug.

* Fix grammar issues.

* Support batch sampler.

* Make the test API more readable.

* Remove some command args and fix an eval hook bug.

* Support format-only arg.

* Modify format_results of datasets.

* Modify tools which use the test APIs.

* Support Cityscapes eval.

* Fix Cityscapes eval.

* 1. Add comments for batch_sampler;

2. Keep the eval hook API the same and add a deprecation warning;

3. Add docstring for dataset.pre_eval;

* Add efficient_test docstring.

* Modify test tool to be compatible with the old version.

* Modify eval hook to be compatible with the old version.

* Modify test API to be compatible with the old version API.

* Add sampler explanation.

* Update warning.

* Modify deploy_test.py.

* Keep compatibility with old outputs and add efficient_test back.

* Clarify mutual-exclusion logic.

* Add warning about efficient_test.

* Modify format_results save folder.

* Fix bugs of format_results.

* Modify deploy_test.py.

* Update docs.

* Fix deploy test bugs.

* Fix custom dataset unit tests.

* Fix dataset unit tests.

* Fix eval hook unit tests.

* Fix some incompatibilities.

* Add pre_eval argument for eval hooks.

* Update eval hook docstring.

* Make pre_eval False by default.

* Add unit tests for dataset format_results.

* Fix some comments and a BC-breaking bug.

* Fix pre_eval cfg field setting.

* Remove redundant code.

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
sennnnn 2021-08-20 11:44:58 +08:00 committed by GitHub
parent 99d8376d2e
commit e235c1a824
22 changed files with 654 additions and 193 deletions

View File

@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(interval=16000, metric='mIoU')
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)

View File

@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')
evaluation = dict(interval=2000, metric='mIoU', pre_eval=True)

View File

@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=40000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')
evaluation = dict(interval=4000, metric='mIoU', pre_eval=True)

View File

@ -6,4 +6,4 @@ lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=8000)
evaluation = dict(interval=8000, metric='mIoU')
evaluation = dict(interval=8000, metric='mIoU', pre_eval=True)

View File

@ -21,11 +21,11 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
Optional arguments:
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. (After mmseg v0.17, the output results are pre-evaluation results or formatted result file paths.)
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `mIoU` is available for all dataset. Cityscapes could be evaluated by `cityscapes` as well as standard `mIoU` metrics.
- `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that GUI is available in your environment, otherwise you may encounter the error like `cannot connect to X server`.
- `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option.
- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB).
- `--eval-options`: Optional parameters for `dataset.format_results` and `dataset.evaluate` during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB). (The `efficient_test` argument has no effect after mmseg v0.17; we use a progressive mode to evaluate and format results, which largely reduces memory cost and evaluation time.)
Examples:
@ -98,4 +98,4 @@ Assume that you have already downloaded the checkpoints to the directory `checkp
--eval mIoU
```
Using ```pmap``` to view CPU memory footprint, it used 2.25GB CPU memory with ```efficient_test=True``` and 11.06GB CPU memory with ```efficient_test=False``` . This optional parameter can save a lot of memory.
Using ```pmap``` to view the CPU memory footprint, it used 2.25GB of CPU memory with ```efficient_test=True``` and 11.06GB with ```efficient_test=False```. This optional parameter can save a lot of memory. (After mmseg v0.17, `efficient_test` has no effect; we use a progressive mode to evaluate and format results efficiently by default.)
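For reference, a minimal config sketch (values are placeholders) of switching the new progressive mode on for validation during training, mirroring the `evaluation` entries changed in the configs earlier in this commit:

```python
# Placeholder interval/metric; pre_eval=True enables the progressive,
# CPU-memory-friendly evaluation introduced by this refactor.
evaluation = dict(interval=4000, metric='mIoU', pre_eval=True)
```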

View File

@ -20,11 +20,11 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]
Optional arguments:
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- `EVAL_METRICS`: Items to be evaluated on the results. This mainly depends on the dataset; `mIoU` is available for all datasets, and a dataset like Cityscapes can also be evaluated with the `cityscapes` metric in addition to the standard `mIoU`.
- `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only used for debugging and visualization and only for single-GPU testing. Please make sure a GUI is available in your environment, otherwise you may encounter an error like `cannot connect to X server`.
- `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only used for debugging and visualization and only for single-GPU testing. You do not need a GUI in your environment to use this option.
- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True` is set, intermediate results will be saved to local files to save CPU memory. Please make sure you have enough local storage space (more than 20GB).
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. (Since MMSeg v0.17, `args.out` only saves the intermediate results of evaluation or the save paths of the segmentation maps.)
- `EVAL_METRICS`: Items to be evaluated on the results. This mainly depends on the dataset; `mIoU` is available for all datasets, and a dataset like Cityscapes can also be evaluated with the `cityscapes` metric in addition to the standard `mIoU`.
- `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only used for debugging and visualization and only for single-GPU testing. Please make sure a GUI is available in your environment, otherwise you may encounter an error like `cannot connect to X server`.
- `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only used for debugging and visualization and only for single-GPU testing. You do not need a GUI in your environment to use this option.
- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True` is set, intermediate results will be saved to local files to save CPU memory. Please make sure you have enough local storage space (more than 20GB). (Since MMSeg v0.17, `efficient_test` no longer takes effect; we refactored the test API and use a progressive mode to evaluate and save results more efficiently.)
Examples:
@ -96,4 +96,4 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]
--eval mIoU
```
Using ```pmap``` to view CPU memory usage, `efficient_test=True` uses about 2.25GB of CPU memory and `efficient_test=False` uses about 11.06GB. This optional parameter can save a lot of CPU memory.
Using ```pmap``` to view CPU memory usage, `efficient_test=True` uses about 2.25GB of CPU memory and `efficient_test=False` uses about 11.06GB. This optional parameter can save a lot of CPU memory. (Since MMSeg v0.17, the `efficient_test` argument no longer takes effect; we use a progressive mode to evaluate and save results more efficiently.)

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import warnings
import mmcv
import numpy as np
@ -19,7 +20,6 @@ def np2tmp(array, temp_file_name=None, tmpdir=None):
function will generate a file name with tempfile.NamedTemporaryFile
to save ndarray. Default: None.
tmpdir (str): Temporary directory to save Ndarray files. Default: None.
Returns:
str: The numpy file name.
"""
@ -36,8 +36,11 @@ def single_gpu_test(model,
show=False,
out_dir=None,
efficient_test=False,
opacity=0.5):
"""Test with single GPU.
opacity=0.5,
pre_eval=False,
format_only=False,
format_args={}):
"""Test with single GPU by progressive mode.
Args:
model (nn.Module): Model to be tested.
@ -46,24 +49,60 @@ def single_gpu_test(model,
out_dir (str, optional): If specified, the results will be dumped into
the directory to save output results.
efficient_test (bool): Whether save the results as local numpy files to
save CPU memory during evaluation. Default: False.
save CPU memory during evaluation. Mutually exclusive with
pre_eval and format_only. Default: False.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
pre_eval (bool): Use dataset.pre_eval() function to generate
pre_results for metric evaluation. Mutually exclusive with
efficient_test and format_only. Default: False.
format_only (bool): Only format results for submission without evaluation.
Mutually exclusive with pre_eval and efficient_test.
Default: False.
format_args (dict): The args for format_results. Default: {}.
Returns:
list: The prediction results.
list: list of evaluation pre-results or list of save file names.
"""
if efficient_test:
warnings.warn(
'DeprecationWarning: ``efficient_test`` will be deprecated, the '
'evaluation is CPU memory friendly with pre_eval=True')
mmcv.mkdir_or_exist('.efficient_test')
# when none of them is set true, return segmentation results as
# a list of np.array.
assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
'``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
'exclusive, only one of them can be True.'
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
if efficient_test:
mmcv.mkdir_or_exist('.efficient_test')
for i, data in enumerate(data_loader):
# The pipeline of how the data_loader retrieves samples from the dataset:
# sampler -> batch_sampler -> indices
# The indices are passed to the dataset_fetcher to get data from the dataset.
# data_fetcher -> collate_fn(dataset[index]) -> data_sample
# We use batch_sampler to get the correct data indices.
loader_indices = data_loader.batch_sampler
for batch_indices, data in zip(loader_indices, data_loader):
with torch.no_grad():
result = model(return_loss=False, **data)
if efficient_test:
result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
if format_only:
result = dataset.format_results(
result, indices=batch_indices, **format_args)
if pre_eval:
# TODO: adapt samples_per_gpu > 1.
# only samples_per_gpu=1 valid now
result = dataset.pre_eval(result, indices=batch_indices)
results.extend(result)
if show or out_dir:
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
@ -90,18 +129,10 @@ def single_gpu_test(model,
out_file=out_file,
opacity=opacity)
if isinstance(result, list):
if efficient_test:
result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
results.extend(result)
else:
if efficient_test:
result = np2tmp(result, tmpdir='.efficient_test')
results.append(result)
batch_size = len(result)
for _ in range(batch_size):
prog_bar.update()
return results
@ -109,8 +140,11 @@ def multi_gpu_test(model,
data_loader,
tmpdir=None,
gpu_collect=False,
efficient_test=False):
"""Test model with multiple gpus.
efficient_test=False,
pre_eval=False,
format_only=False,
format_args={}):
"""Test model with multiple gpus by progressive mode.
This method tests model with multiple gpus and collects the results
under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
@ -123,39 +157,71 @@ def multi_gpu_test(model,
data_loader (utils.data.Dataloader): Pytorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode. The same path is used for efficient
test.
test. Default: None.
gpu_collect (bool): Option to use either gpu or cpu to collect results.
Default: False.
efficient_test (bool): Whether save the results as local numpy files to
save CPU memory during evaluation. Default: False.
save CPU memory during evaluation. Mutually exclusive with
pre_eval and format_only. Default: False.
pre_eval (bool): Use dataset.pre_eval() function to generate
pre_results for metric evaluation. Mutually exclusive with
efficient_test and format_only. Default: False.
format_only (bool): Only format results for submission without evaluation.
Mutually exclusive with pre_eval and efficient_test.
Default: False.
format_args (dict): The args for format_results. Default: {}.
Returns:
list: The prediction results.
list: list of evaluation pre-results or list of save file names.
"""
if efficient_test:
warnings.warn(
'DeprecationWarning: ``efficient_test`` will be deprecated, the '
'evaluation is CPU memory friendly with pre_eval=True')
mmcv.mkdir_or_exist('.efficient_test')
# when none of them is set true, return segmentation results as
# a list of np.array.
assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
'``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
'exclusive, only one of them can be True.'
model.eval()
results = []
dataset = data_loader.dataset
# The pipeline of how the data_loader retrieves samples from the dataset:
# sampler -> batch_sampler -> indices
# The indices are passed to the dataset_fetcher to get data from the dataset.
# data_fetcher -> collate_fn(dataset[index]) -> data_sample
# We use batch_sampler to get the correct data indices.
# Since batch_sampler is based on DistributedSampler, the indices only
# point to the data samples of the current machine.
loader_indices = data_loader.batch_sampler
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))
if efficient_test:
mmcv.mkdir_or_exist('.efficient_test')
for i, data in enumerate(data_loader):
for batch_indices, data in zip(loader_indices, data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
if isinstance(result, list):
if efficient_test:
result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
results.extend(result)
else:
if efficient_test:
result = np2tmp(result, tmpdir='.efficient_test')
results.append(result)
if efficient_test:
result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]
if format_only:
result = dataset.format_results(
result, indices=batch_indices, **format_args)
if pre_eval:
# TODO: adapt samples_per_gpu > 1.
# only samples_per_gpu=1 valid now
result = dataset.pre_eval(result, indices=batch_indices)
results.extend(result)
if rank == 0:
batch_size = len(result)
for _ in range(batch_size * world_size):
batch_size = len(result) * world_size
for _ in range(batch_size):
prog_bar.update()
# collect results from all ranks
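To make the new data flow easier to follow, here is a stripped-down, illustrative sketch of the progressive single-GPU loop above (visualization and the deprecated `efficient_test` branch omitted; the function name is made up, the calls follow the diff):

```python
import torch


def progressive_single_gpu_test(model, data_loader, pre_eval=False,
                                format_only=False, format_args={}):
    """Simplified sketch: walk the batch_sampler in lockstep with the
    data_loader so every prediction can be matched to its dataset index and
    reduced (or formatted) immediately, instead of piling up in memory."""
    model.eval()
    results = []
    dataset = data_loader.dataset
    # batch_sampler yields the dataset indices of each batch, in loader order.
    for batch_indices, data in zip(data_loader.batch_sampler, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)
        if format_only:
            # Write segmentation maps to disk and keep only the file paths.
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)
        if pre_eval:
            # Reduce each prediction to per-class (intersect, union,
            # pred_area, label_area) tensors right away.
            result = dataset.pre_eval(result, indices=batch_indices)
        results.extend(result)
    return results
```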

View File

@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .class_names import get_classes, get_palette
from .eval_hooks import DistEvalHook, EvalHook
from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
from .metrics import (eval_metrics, intersect_and_union, mean_dice,
mean_fscore, mean_iou, pre_eval_to_metrics)
__all__ = [
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
'eval_metrics', 'get_classes', 'get_palette'
'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics',
'intersect_and_union'
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
import torch.distributed as dist
from mmcv.runner import DistEvalHook as _DistEvalHook
@ -16,15 +17,28 @@ class EvalHook(_EvalHook):
Default: False.
efficient_test (bool): Whether save the results as local numpy files to
save CPU memory during evaluation. Default: False.
pre_eval (bool): Whether to use progressive mode to evaluate model.
Default: False.
Returns:
list: The prediction results.
"""
greater_keys = ['mIoU', 'mAcc', 'aAcc']
def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
def __init__(self,
*args,
by_epoch=False,
efficient_test=False,
pre_eval=False,
**kwargs):
super().__init__(*args, by_epoch=by_epoch, **kwargs)
self.efficient_test = efficient_test
self.pre_eval = pre_eval
if efficient_test:
warnings.warn(
'DeprecationWarning: ``efficient_test`` for evaluation hook '
'is deprecated, the evaluation hook is CPU memory friendly '
'with ``pre_eval=True`` as argument for ``single_gpu_test()`` '
'function')
def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
@ -33,10 +47,8 @@ class EvalHook(_EvalHook):
from mmseg.apis import single_gpu_test
results = single_gpu_test(
runner.model,
self.dataloader,
show=False,
efficient_test=self.efficient_test)
runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
runner.log_buffer.clear()
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
@ -52,15 +64,28 @@ class DistEvalHook(_DistEvalHook):
Default: False.
efficient_test (bool): Whether save the results as local numpy files to
save CPU memory during evaluation. Default: False.
pre_eval (bool): Whether to use progressive mode to evaluate model.
Default: False.
Returns:
list: The prediction results.
"""
greater_keys = ['mIoU', 'mAcc', 'aAcc']
def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
def __init__(self,
*args,
by_epoch=False,
efficient_test=False,
pre_eval=False,
**kwargs):
super().__init__(*args, by_epoch=by_epoch, **kwargs)
self.efficient_test = efficient_test
self.pre_eval = pre_eval
if efficient_test:
warnings.warn(
'DeprecationWarning: ``efficient_test`` for evaluation hook '
'is deprecated, the evaluation hook is CPU memory friendly '
'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` '
'function')
def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
@ -90,7 +115,10 @@ class DistEvalHook(_DistEvalHook):
self.dataloader,
tmpdir=tmpdir,
gpu_collect=self.gpu_collect,
efficient_test=self.efficient_test)
pre_eval=self.pre_eval)
runner.log_buffer.clear()
if runner.rank == 0:
print('\n')
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
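As a usage sketch of the refactored hook (the toy dataset and mocks below mirror the unit tests added later in this commit; they are placeholders, not part of the hook API):

```python
from unittest.mock import MagicMock

import torch
from torch.utils.data import DataLoader, Dataset

from mmseg.core.evaluation import EvalHook


class ToyDataset(Dataset):
    """Tiny placeholder dataset, just enough for the hook to iterate."""

    def __getitem__(self, idx):
        return dict(img=torch.tensor([1.0]), img_metas=dict())

    def __len__(self):
        return 1


toy_dataset = ToyDataset()
# The hook relies on the dataset providing pre_eval() and evaluate();
# they are mocked here exactly as in the new unit tests.
toy_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
toy_dataset.evaluate = MagicMock(return_value=dict(mIoU=1.0))

val_loader = DataLoader(toy_dataset, batch_size=1)
# pre_eval=True makes the hook call single_gpu_test(..., pre_eval=True) and
# feed the returned per-image pre-eval tuples to dataset.evaluate().
eval_hook = EvalHook(val_loader, by_epoch=False, interval=1, pre_eval=True)
```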

View File

@ -97,8 +97,8 @@ def total_intersect_and_union(results,
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
gt_seg_maps (list[ndarray] | list[str] | Iterable): list of ground
truth segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. Default: dict().
@ -113,15 +113,15 @@ def total_intersect_and_union(results,
ndarray: The ground truth histogram on all classes.
"""
num_imgs = len(results)
assert len(gt_seg_maps) == num_imgs
assert len(list(gt_seg_maps)) == num_imgs
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
for i in range(num_imgs):
for result, gt_seg_map in zip(results, gt_seg_maps):
area_intersect, area_union, area_pred_label, area_label = \
intersect_and_union(
results[i], gt_seg_maps[i], num_classes, ignore_index,
result, gt_seg_map, num_classes, ignore_index,
label_map, reduce_zero_label)
total_area_intersect += area_intersect
total_area_union += area_union
@ -268,8 +268,8 @@ def eval_metrics(results,
Args:
results (list[ndarray] | list[str]): List of prediction segmentation
maps or list of prediction result filenames.
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
segmentation maps or list of label filenames.
gt_seg_maps (list[ndarray] | list[str] | Iterable): list of ground
truth segmentation maps or list of label filenames.
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
@ -282,16 +282,86 @@ def eval_metrics(results,
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(
results, gt_seg_maps, num_classes, ignore_index, label_map,
reduce_zero_label)
ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
total_area_pred_label,
total_area_label, metrics, nan_to_num,
beta)
return ret_metrics
def pre_eval_to_metrics(pre_eval_results,
metrics=['mIoU'],
nan_to_num=None,
beta=1):
"""Convert pre-eval results to metrics.
Args:
pre_eval_results (list[tuple[torch.Tensor]]): per image eval results
for computing evaluation metric
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
# convert list of tuples to tuple of lists, e.g.
# [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
# ([A_1, ..., A_n], ..., [D_1, ..., D_n])
pre_eval_results = tuple(zip(*pre_eval_results))
assert len(pre_eval_results) == 4
total_area_intersect = sum(pre_eval_results[0])
total_area_union = sum(pre_eval_results[1])
total_area_pred_label = sum(pre_eval_results[2])
total_area_label = sum(pre_eval_results[3])
ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union,
total_area_pred_label,
total_area_label, metrics, nan_to_num,
beta)
return ret_metrics
def total_area_to_metrics(total_area_intersect,
total_area_union,
total_area_pred_label,
total_area_label,
metrics=['mIoU'],
nan_to_num=None,
beta=1):
"""Calculate evaluation metrics
Args:
total_area_intersect (ndarray): The intersection of prediction and
ground truth histogram on all classes.
total_area_union (ndarray): The union of prediction and ground truth
histogram on all classes.
total_area_pred_label (ndarray): The prediction histogram on all
classes.
total_area_label (ndarray): The ground truth histogram on all classes.
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
Returns:
float: Overall accuracy on all images.
ndarray: Per category accuracy, shape (num_classes, ).
ndarray: Per category evaluation metrics, shape (num_classes, ).
"""
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError('metrics {} is not supported'.format(metrics))
total_area_intersect, total_area_union, total_area_pred_label, \
total_area_label = total_intersect_and_union(
results, gt_seg_maps, num_classes, ignore_index, label_map,
reduce_zero_label)
all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
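A hedged sketch of how the two new helpers fit together: per-image areas from `intersect_and_union` are summed by `pre_eval_to_metrics`, so only four small tensors per image are kept instead of full segmentation maps. The random arrays and sizes below are placeholders:

```python
import numpy as np

from mmseg.core import intersect_and_union, pre_eval_to_metrics

num_classes, ignore_index = 19, 255
pre_eval_results = []
for _ in range(4):  # pretend four predictions were streamed from the test loop
    pred = np.random.randint(0, num_classes, size=(64, 64))
    gt = np.random.randint(0, num_classes, size=(64, 64))
    # Each entry is (area_intersect, area_union, area_pred_label, area_label).
    pre_eval_results.append(
        intersect_and_union(pred, gt, num_classes, ignore_index,
                            label_map=dict(), reduce_zero_label=False))

# Sum the per-image areas and compute the metrics without holding any seg map.
metrics = pre_eval_to_metrics(pre_eval_results, metrics=['mIoU'])
print(metrics['aAcc'], metrics['IoU'])
```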

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import mmcv
import numpy as np
@ -91,7 +90,7 @@ class ADE20KDataset(CustomDataset):
reduce_zero_label=True,
**kwargs)
def results2img(self, results, imgfile_prefix, to_label_id):
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
"""Write the segmentation results to images.
Args:
@ -101,17 +100,21 @@ class ADE20KDataset(CustomDataset):
If the prefix is "somepath/xxx",
the png files will be named "somepath/xxx.png".
to_label_id (bool): whether convert output to label_id for
submission
submission.
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
list[str: str]: result txt files which contains corresponding
semantic segmentation images.
"""
if indices is None:
indices = list(range(len(self)))
mmcv.mkdir_or_exist(imgfile_prefix)
result_files = []
prog_bar = mmcv.ProgressBar(len(self))
for idx in range(len(self)):
result = results[idx]
for result, idx in zip(results, indices):
filename = self.img_infos[idx]['filename']
basename = osp.splitext(osp.basename(filename))[0]
@ -127,21 +130,25 @@ class ADE20KDataset(CustomDataset):
output.save(png_filename)
result_files.append(png_filename)
prog_bar.update()
return result_files
def format_results(self, results, imgfile_prefix=None, to_label_id=True):
def format_results(self,
results,
imgfile_prefix,
to_label_id=True,
indices=None):
"""Format the results into dir (standard format for ade20k evaluation).
Args:
results (list): Testing results of the dataset.
imgfile_prefix (str | None): The prefix of images files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created.
Default: None.
"a/b/prefix".
to_label_id (bool): whether convert output to label_id for
submission. Default: False
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a list containing
@ -149,16 +156,12 @@ class ADE20KDataset(CustomDataset):
for saving json/png files when img_prefix is not specified.
"""
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(self), (
'The length of results is not equal to the dataset len: '
f'{len(results)} != {len(self)}')
if indices is None:
indices = list(range(len(self)))
if imgfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
imgfile_prefix = tmp_dir.name
else:
tmp_dir = None
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
result_files = self.results2img(results, imgfile_prefix, to_label_id)
return result_files, tmp_dir
result_files = self.results2img(results, imgfile_prefix, to_label_id,
indices)
return result_files

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import mmcv
import numpy as np
@ -48,7 +47,7 @@ class CityscapesDataset(CustomDataset):
return result_copy
def results2img(self, results, imgfile_prefix, to_label_id):
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
"""Write the segmentation results to images.
Args:
@ -58,17 +57,21 @@ class CityscapesDataset(CustomDataset):
If the prefix is "somepath/xxx",
the png files will be named "somepath/xxx.png".
to_label_id (bool): whether convert output to label_id for
submission
submission.
indices (list[int], optional): Indices of input results,
if not set, all the indices of the dataset will be used.
Default: None.
Returns:
list[str: str]: result txt files which contains corresponding
semantic segmentation images.
"""
if indices is None:
indices = list(range(len(self)))
mmcv.mkdir_or_exist(imgfile_prefix)
result_files = []
prog_bar = mmcv.ProgressBar(len(self))
for idx in range(len(self)):
result = results[idx]
for result, idx in zip(results, indices):
if to_label_id:
result = self._convert_to_label_id(result)
filename = self.img_infos[idx]['filename']
@ -85,49 +88,49 @@ class CityscapesDataset(CustomDataset):
output.putpalette(palette)
output.save(png_filename)
result_files.append(png_filename)
prog_bar.update()
return result_files
def format_results(self, results, imgfile_prefix=None, to_label_id=True):
def format_results(self,
results,
imgfile_prefix,
to_label_id=True,
indices=None):
"""Format the results into dir (standard format for Cityscapes
evaluation).
Args:
results (list): Testing results of the dataset.
imgfile_prefix (str | None): The prefix of images files. It
imgfile_prefix (str): The prefix of images files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created.
Default: None.
"a/b/prefix".
to_label_id (bool): whether convert output to label_id for
submission. Default: False
indices (list[int], optional): Indices of input results,
if not set, all the indices of the dataset will be used.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a list containing
the image paths, tmp_dir is the temporal directory created
for saving json/png files when img_prefix is not specified.
"""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(self), (
'The length of results is not equal to the dataset len: '
f'{len(results)} != {len(self)}')
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
if imgfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
imgfile_prefix = tmp_dir.name
else:
tmp_dir = None
result_files = self.results2img(results, imgfile_prefix, to_label_id)
result_files = self.results2img(results, imgfile_prefix, to_label_id,
indices)
return result_files, tmp_dir
return result_files
def evaluate(self,
results,
metric='mIoU',
logger=None,
imgfile_prefix=None,
efficient_test=False):
imgfile_prefix=None):
"""Evaluation in Cityscapes/default protocol.
Args:
@ -158,7 +161,7 @@ class CityscapesDataset(CustomDataset):
if len(metrics) > 0:
eval_results.update(
super(CityscapesDataset,
self).evaluate(results, metrics, logger, efficient_test))
self).evaluate(results, metrics, logger))
return eval_results
@ -184,12 +187,7 @@ class CityscapesDataset(CustomDataset):
msg = '\n' + msg
print_log(msg, logger=logger)
result_files, tmp_dir = self.format_results(results, imgfile_prefix)
if tmp_dir is None:
result_dir = imgfile_prefix
else:
result_dir = tmp_dir.name
result_dir = imgfile_prefix
eval_results = dict()
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
@ -212,7 +210,4 @@ class CityscapesDataset(CustomDataset):
eval_results.update(
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
if tmp_dir is not None:
tmp_dir.cleanup()
return eval_results

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import warnings
from collections import OrderedDict
from functools import reduce
@ -10,7 +10,7 @@ from mmcv.utils import print_log
from prettytable import PrettyTable
from torch.utils.data import Dataset
from mmseg.core import eval_metrics
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose
@ -226,21 +226,55 @@ class CustomDataset(Dataset):
self.pre_pipeline(results)
return self.pipeline(results)
def format_results(self, results, **kwargs):
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""Place holder to format result to dataset specific output."""
raise NotImplementedError
def get_gt_seg_maps(self, efficient_test=False):
def get_gt_seg_maps(self, efficient_test=None):
"""Get ground truth segmentation maps for evaluation."""
gt_seg_maps = []
if efficient_test is not None:
warnings.warn(
'DeprecationWarning: ``efficient_test`` has been deprecated '
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
'friendly by default. ')
for img_info in self.img_infos:
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
if efficient_test:
gt_seg_map = seg_map
else:
gt_seg_map = mmcv.imread(
seg_map, flag='unchanged', backend='pillow')
gt_seg_maps.append(gt_seg_map)
return gt_seg_maps
gt_seg_map = mmcv.imread(
seg_map, flag='unchanged', backend='pillow')
yield gt_seg_map
def pre_eval(self, preds, indices):
"""Collect eval result from each iteration.
Args:
preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
after argmax, shape (N, H, W).
indices (list[int] | int): the prediction related ground truth
indices.
Returns:
list[torch.Tensor]: (area_intersect, area_union, area_prediction,
area_ground_truth).
"""
# In order to be compatible with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
pre_eval_results = []
for pred, index in zip(preds, indices):
seg_map = osp.join(self.ann_dir,
self.img_infos[index]['ann']['seg_map'])
seg_map = mmcv.imread(seg_map, flag='unchanged', backend='pillow')
pre_eval_results.append(
intersect_and_union(pred, seg_map, len(self.CLASSES),
self.ignore_index, self.label_map,
self.reduce_zero_label))
return pre_eval_results
def get_classes_and_palette(self, classes=None, palette=None):
"""Get class names of current dataset.
@ -305,16 +339,13 @@ class CustomDataset(Dataset):
return palette
def evaluate(self,
results,
metric='mIoU',
logger=None,
efficient_test=False,
**kwargs):
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
results (list[tuple[torch.Tensor]] | list[str]): Per-image pre-eval
results or predicted segmentation maps used to compute the
evaluation metric.
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
@ -323,28 +354,37 @@ class CustomDataset(Dataset):
Returns:
dict[str, float]: Default metrics.
"""
if isinstance(metric, str):
metric = [metric]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))
eval_results = {}
gt_seg_maps = self.get_gt_seg_maps(efficient_test)
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
ret_metrics = eval_metrics(
results,
gt_seg_maps,
num_classes,
self.ignore_index,
metric,
label_map=self.label_map,
reduce_zero_label=self.reduce_zero_label)
eval_results = {}
# test a list of files
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
# reset generator
gt_seg_maps = self.get_gt_seg_maps()
ret_metrics = eval_metrics(
results,
gt_seg_maps,
num_classes,
self.ignore_index,
metric,
label_map=self.label_map,
reduce_zero_label=self.reduce_zero_label)
# test a list of pre_eval_results
else:
ret_metrics = pre_eval_to_metrics(results, metric)
# Because dataset.CLASSES is required for pre-eval.
if self.CLASSES is None:
class_names = tuple(range(num_classes))
else:
@ -396,7 +436,4 @@ class CustomDataset(Dataset):
for idx, name in enumerate(class_names)
})
if mmcv.is_list_of(results, str):
for file_name in results:
os.remove(file_name)
return eval_results
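A hedged sketch of the per-image contract this refactor gives `CustomDataset`: `pre_eval` reduces one prediction against the ground truth looked up by index, and `evaluate` accepts the accumulated tuples directly. Paths, class names, and the prediction source are placeholders; in practice each `pred` must match the resolution of its annotation:

```python
import numpy as np

from mmseg.datasets import CustomDataset

# Placeholder split layout; point img_dir/ann_dir at any CustomDataset-style data.
dataset = CustomDataset(
    pipeline=[],
    img_dir='data/my_dataset/img_dir/val',
    ann_dir='data/my_dataset/ann_dir/val',
    img_suffix='.jpg',
    seg_map_suffix='.png',
    classes=('background', 'foreground'),
    palette=[[0, 0, 0], [255, 255, 255]])

pre_eval_results = []
for idx in range(len(dataset)):
    # In the real test loop `pred` comes from the model; random values stand in
    # here and must share the (H, W) of the corresponding ground-truth map.
    pred = np.random.randint(0, 2, size=(512, 512))
    pre_eval_results.extend(dataset.pre_eval(pred, indices=idx))

# evaluate() detects the pre-eval tuples and aggregates them directly instead
# of loading every ground-truth map next to every full prediction.
print(dataset.evaluate(pre_eval_results, metric='mIoU'))
```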

(Four binary image files added, not shown; sizes: 1.9 KiB, 1.5 KiB, 1.5 KiB, and 50 KiB.)

View File

@ -0,0 +1,72 @@
import shutil
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mmseg.apis import single_gpu_test
class ExampleDataset(Dataset):
def __getitem__(self, idx):
results = dict(img=torch.tensor([1]), img_metas=dict())
return results
def __len__(self):
return 1
class ExampleModel(nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.test_cfg = None
self.conv = nn.Conv2d(3, 3, 3)
def forward(self, img, img_metas, return_loss=False, **kwargs):
return img
def test_single_gpu():
test_dataset = ExampleDataset()
data_loader = DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False,
)
model = ExampleModel()
# Test efficient test compatibility (will be deprecated)
results = single_gpu_test(model, data_loader, efficient_test=True)
assert len(results) == 1
pred = np.load(results[0])
assert isinstance(pred, np.ndarray)
assert pred.shape == (1, )
assert pred[0] == 1
shutil.rmtree('.efficient_test')
# Test pre_eval
test_dataset.pre_eval = MagicMock(return_value=['success'])
results = single_gpu_test(model, data_loader, pre_eval=True)
assert results == ['success']
# Test format_only
test_dataset.format_results = MagicMock(return_value=['success'])
results = single_gpu_test(model, data_loader, format_only=True)
assert results == ['success']
# efficient_test, pre_eval and format_only are mutually exclusive
with pytest.raises(AssertionError):
single_gpu_test(
model,
data_loader,
efficient_test=True,
format_only=True,
pre_eval=True)

View File

@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
from typing import Generator
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from PIL import Image
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
@ -152,10 +155,16 @@ def test_custom_dataset():
assert isinstance(test_data, dict)
# get gt seg map
gt_seg_maps = train_dataset.get_gt_seg_maps()
gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
assert isinstance(gt_seg_maps, Generator)
gt_seg_maps = list(gt_seg_maps)
assert len(gt_seg_maps) == 5
# evaluation
# format_results not implemented
with pytest.raises(NotImplementedError):
test_dataset.format_results([], '')
# test the old (non-pre-eval) evaluation path
pseudo_results = []
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
@ -180,7 +189,7 @@ def test_custom_dataset():
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
# evaluation with CLASSES
# test the old evaluation path with CLASSES
train_dataset.CLASSES = tuple(['a'] * 7)
eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
assert isinstance(eval_results, dict)
@ -212,6 +221,95 @@ def test_custom_dataset():
assert 'mPrecision' in eval_results
assert 'mRecall' in eval_results
# test evaluation with pre-eval; dataset.CLASSES is necessary here
train_dataset.CLASSES = tuple(['a'] * 7)
pseudo_results = []
for idx in range(len(train_dataset)):
h, w = gt_seg_maps[idx].shape
pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx))
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
assert isinstance(eval_results, dict)
assert 'mDice' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
assert isinstance(eval_results, dict)
assert 'mRecall' in eval_results
assert 'mPrecision' in eval_results
assert 'mFscore' in eval_results
assert 'aAcc' in eval_results
eval_results = train_dataset.evaluate(
pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mDice' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
assert 'mFscore' in eval_results
assert 'mPrecision' in eval_results
assert 'mRecall' in eval_results
def test_ade():
test_dataset = ADE20KDataset(
pipeline=[],
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
assert len(test_dataset) == 5
# Test format_results
pseudo_results = []
for _ in range(len(test_dataset)):
h, w = (2, 2)
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
file_paths = test_dataset.format_results(pseudo_results, '.format_ade')
assert len(file_paths) == len(test_dataset)
temp = np.array(Image.open(file_paths[0]))
assert np.allclose(temp, pseudo_results[0] + 1)
shutil.rmtree('.format_ade')
def test_cityscapes():
test_dataset = CityscapesDataset(
pipeline=[],
img_dir=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
assert len(test_dataset) == 1
gt_seg_maps = list(test_dataset.get_gt_seg_maps())
# Test format_results
pseudo_results = []
for idx in range(len(test_dataset)):
h, w = gt_seg_maps[idx].shape
pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w)))
file_paths = test_dataset.format_results(pseudo_results, '.format_city')
assert len(file_paths) == len(test_dataset)
temp = np.array(Image.open(file_paths[0]))
assert np.allclose(temp,
test_dataset._convert_to_label_id(pseudo_results[0]))
# Test cityscapes evaluate
test_dataset.evaluate(
pseudo_results, metric='cityscapes', imgfile_prefix='.format_city')
shutil.rmtree('.format_city')
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',

View File

@ -53,6 +53,7 @@ def test_iter_eval_hook():
EvalHook(data_loader)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
@ -64,7 +65,7 @@ def test_iter_eval_hook():
# test EvalHook
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = EvalHook(data_loader, by_epoch=False)
eval_hook = EvalHook(data_loader, by_epoch=False, efficient_test=True)
runner = mmcv.runner.IterBasedRunner(
model=model,
optimizer=optimizer,
@ -90,6 +91,7 @@ def test_epoch_eval_hook():
EvalHook(data_loader, by_epoch=True)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
@ -117,8 +119,9 @@ def multi_gpu_test(model,
data_loader,
tmpdir=None,
gpu_collect=False,
efficient_test=False):
results = single_gpu_test(model, data_loader)
pre_eval=False):
# Pre eval is set by default when training.
results = single_gpu_test(model, data_loader, pre_eval=True)
return results
@ -137,6 +140,7 @@ def test_dist_eval_hook():
DistEvalHook(data_loader)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
@ -148,7 +152,8 @@ def test_dist_eval_hook():
# test DistEvalHook
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = DistEvalHook(data_loader, by_epoch=False)
eval_hook = DistEvalHook(
data_loader, by_epoch=False, efficient_test=True)
runner = mmcv.runner.IterBasedRunner(
model=model,
optimizer=optimizer,
@ -175,6 +180,7 @@ def test_dist_eval_hook_epoch():
DistEvalHook(data_loader)
test_dataset = ExampleDataset()
test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()

View File

@ -2,6 +2,7 @@
import argparse
import os
import os.path as osp
import shutil
import warnings
from typing import Any, Iterable
@ -234,24 +235,61 @@ def main():
model.CLASSES = dataset.CLASSES
model.PALETTE = dataset.PALETTE
efficient_test = False
if args.eval_options is not None:
efficient_test = args.eval_options.get('efficient_test', False)
# clean gpu memory when starting a new evaluation.
torch.cuda.empty_cache()
eval_kwargs = {} if args.eval_options is None else args.eval_options
# Deprecated
efficient_test = eval_kwargs.get('efficient_test', False)
if efficient_test:
warnings.warn(
'``efficient_test=True`` does not take effect in tools/test.py; '
'evaluation and result formatting are CPU memory efficient by '
'default')
eval_on_format_results = (
args.eval is not None and 'cityscapes' in args.eval)
if eval_on_format_results:
assert len(args.eval) == 1, 'eval on format results is not ' \
'applicable for metrics other than ' \
'cityscapes'
if args.format_only or eval_on_format_results:
if 'imgfile_prefix' in eval_kwargs:
tmpdir = eval_kwargs['imgfile_prefix']
else:
tmpdir = '.format_cityscapes'
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
mmcv.mkdir_or_exist(tmpdir)
else:
tmpdir = None
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
efficient_test, args.opacity)
results = single_gpu_test(
model,
data_loader,
args.show,
args.show_dir,
False,
args.opacity,
pre_eval=args.eval is not None and not eval_on_format_results,
format_only=args.format_only or eval_on_format_results,
format_args=eval_kwargs)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
warnings.warn(
'The behavior of ``args.out`` has been changed since MMSeg '
'v0.16, the pickled outputs could be segmentation maps as '
'np.ndarray, pre-eval results, or file paths for '
'``dataset.format_results()``.')
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
kwargs = {} if args.eval_options is None else args.eval_options
if args.format_only:
dataset.format_results(outputs, **kwargs)
mmcv.dump(results, args.out)
if args.eval:
dataset.evaluate(outputs, args.eval, **kwargs)
dataset.evaluate(results, args.eval, **eval_kwargs)
if tmpdir is not None and eval_on_format_results:
# remove the tmp dir used for Cityscapes evaluation
shutil.rmtree(tmpdir)
if __name__ == '__main__':
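Since the warning above notes that `args.out` may now hold per-image pre-eval tuples rather than raw segmentation maps, a hedged sketch of rebuilding metrics from such a dump offline (the pickle path is a placeholder, and this only applies when the test ran in pre-eval mode):

```python
import mmcv

from mmseg.core import pre_eval_to_metrics

# Placeholder path: whatever was passed as --out while pre-eval mode was active.
results = mmcv.load('work_dirs/results.pkl')

# Each entry is an (intersect, union, pred_area, label_area) tuple per image,
# so metrics can be recomputed without re-running inference.
print(pre_eval_to_metrics(results, metrics=['mIoU']))
```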

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import shutil
import warnings
import mmcv
import torch
@ -134,32 +136,76 @@ def main():
print('"PALETTE" not found in meta, use dataset.PALETTE instead')
model.PALETTE = dataset.PALETTE
efficient_test = False
if args.eval_options is not None:
efficient_test = args.eval_options.get('efficient_test', False)
# clean gpu memory when starting a new evaluation.
torch.cuda.empty_cache()
eval_kwargs = {} if args.eval_options is None else args.eval_options
# Deprecated
efficient_test = eval_kwargs.get('efficient_test', False)
if efficient_test:
warnings.warn(
'``efficient_test=True`` does not take effect in tools/test.py; '
'evaluation and result formatting are CPU memory efficient by '
'default')
eval_on_format_results = (
args.eval is not None and 'cityscapes' in args.eval)
if eval_on_format_results:
assert len(args.eval) == 1, 'eval on format results is not ' \
'applicable for metrics other than ' \
'cityscapes'
if args.format_only or eval_on_format_results:
if 'imgfile_prefix' in eval_kwargs:
tmpdir = eval_kwargs['imgfile_prefix']
else:
tmpdir = '.format_cityscapes'
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
mmcv.mkdir_or_exist(tmpdir)
else:
tmpdir = None
if not distributed:
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
efficient_test, args.opacity)
results = single_gpu_test(
model,
data_loader,
args.show,
args.show_dir,
False,
args.opacity,
pre_eval=args.eval is not None and not eval_on_format_results,
format_only=args.format_only or eval_on_format_results,
format_args=eval_kwargs)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect, efficient_test)
results = multi_gpu_test(
model,
data_loader,
args.tmpdir,
args.gpu_collect,
False,
pre_eval=args.eval is not None and not eval_on_format_results,
format_only=args.format_only or eval_on_format_results,
format_args=eval_kwargs)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
warnings.warn(
'The behavior of ``args.out`` has been changed since MMSeg '
'v0.16, the pickled outputs could be segmentation maps as '
'np.ndarray, pre-eval results, or file paths for '
'``dataset.format_results()``.')
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
kwargs = {} if args.eval_options is None else args.eval_options
if args.format_only:
dataset.format_results(outputs, **kwargs)
mmcv.dump(results, args.out)
if args.eval:
dataset.evaluate(outputs, args.eval, **kwargs)
dataset.evaluate(results, args.eval, **eval_kwargs)
if tmpdir is not None and eval_on_format_results:
# remove the tmp dir used for Cityscapes evaluation
shutil.rmtree(tmpdir)
if __name__ == '__main__':