mmsegmentation/mmseg/apis/test.py

234 lines
9.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import warnings
import mmcv
import numpy as np
import torch
from mmcv.engine import collect_results_cpu, collect_results_gpu
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
def np2tmp(array, temp_file_name=None, tmpdir=None):
    """Save ndarray to local numpy file.

    Args:
        array (ndarray): Ndarray to save.
        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
            function will generate a file name with tempfile.NamedTemporaryFile
            to save ndarray. A string name that does not end with '.npy'
            gets that suffix appended, matching where ``np.save`` writes.
            Default: None.
        tmpdir (str): Temporary directory to save Ndarray files. Only used
            when the file name is generated. Default: None.

    Returns:
        str: The numpy file name, i.e. the path the array was written to.
    """
    if temp_file_name is None:
        # Only the generated name is needed, so close the handle right away
        # instead of leaving it to the garbage collector. ``delete=False``
        # keeps the (still empty) file on disk for ``np.save`` to overwrite.
        with tempfile.NamedTemporaryFile(
                suffix='.npy', delete=False, dir=tmpdir) as tmp_file:
            temp_file_name = tmp_file.name
    elif (isinstance(temp_file_name, str)
          and not temp_file_name.endswith('.npy')):
        # ``np.save`` appends '.npy' to string paths lacking it; apply the
        # same rule here so the returned name is the file that really exists.
        temp_file_name = temp_file_name + '.npy'
    np.save(temp_file_name, array)
    return temp_file_name
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    efficient_test=False,
                    opacity=0.5,
                    pre_eval=False,
                    format_only=False,
                    format_args=None):
    """Test with single GPU by progressive mode.

    Args:
        model (nn.Module): Model to be tested. When ``show`` or ``out_dir``
            is used, the painting is delegated to
            ``model.module.show_result``, so the model must expose a
            ``module`` attribute.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        show (bool): Whether show results during inference. Default: False.
        out_dir (str, optional): If specified, the results will be dumped into
            the directory to save output results.
        efficient_test (bool): Whether save the results as local numpy files
            (under '.efficient_test') to save CPU memory during evaluation.
            Mutually exclusive with pre_eval and format_only.
            Default: False.
        opacity(float): Opacity of painted segmentation map.
            Default 0.5.
            Must be in (0, 1] range.
        pre_eval (bool): Use dataset.pre_eval() function to generate
            pre_results for metric evaluation. Mutually exclusive with
            efficient_test and format_only. Default: False.
        format_only (bool): Only format result for results commit.
            Mutually exclusive with pre_eval and efficient_test.
            Default: False.
        format_args (dict, optional): The args for format_results.
            ``None`` is treated as an empty dict. Default: None.

    Returns:
        list: The accumulated per-batch results: the raw model outputs when
            no mode is enabled, the saved file names for ``efficient_test``,
            the outputs of ``dataset.format_results`` for ``format_only``
            or the outputs of ``dataset.pre_eval`` for ``pre_eval``.
    """
    # Validate the flags before touching the file system.
    # when none of them is set true, return segmentation results as
    # a list of np.array.
    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
        'exclusive, only one of them could be true .'

    # Avoid a mutable default argument shared between calls.
    if format_args is None:
        format_args = {}

    if efficient_test:
        warnings.warn(
            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
            'evaluation is CPU memory friendly with pre_eval=True')
        mmcv.mkdir_or_exist('.efficient_test')

    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # The pipeline about how the data_loader retrieval samples from dataset:
    # sampler -> batch_sampler -> indices
    # The indices are passed to dataset_fetcher to get data from dataset.
    # data_fetcher -> collate_fn(dataset[index]) -> data_sample
    # we use batch_sampler to get correct data idx
    # NOTE(review): the batch sampler is iterated here separately from the
    # loader, so the indices only match the data if the sampler yields the
    # same order on every iteration (i.e. no shuffling) - confirm at caller.
    loader_indices = data_loader.batch_sampler

    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

        if show or out_dir:
            img_tensor = data['img'][0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for img, img_meta in zip(imgs, img_metas):
                # Crop to the 'img_shape' extent, then resize back to the
                # original image size before painting.
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None

                model.module.show_result(
                    img_show,
                    result,
                    palette=dataset.PALETTE,
                    show=show,
                    out_file=out_file,
                    opacity=opacity)

        if efficient_test:
            result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]

        if format_only:
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)

        if pre_eval:
            # TODO: adapt samples_per_gpu > 1.
            # only samples_per_gpu=1 valid now
            result = dataset.pre_eval(result, indices=batch_indices)

        results.extend(result)

        # Advance the bar once per entry produced for this batch.
        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()

    return results
def multi_gpu_test(model,
                   data_loader,
                   tmpdir=None,
                   gpu_collect=False,
                   efficient_test=False,
                   pre_eval=False,
                   format_only=False,
                   format_args=None):
    """Test model with multiple gpus by progressive mode.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and use gpu communication for results
    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
    and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode. Files written for
            ``efficient_test`` go to '.efficient_test' regardless of this
            argument. Default: None.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files
            (under '.efficient_test') to save CPU memory during evaluation.
            Mutually exclusive with pre_eval and format_only.
            Default: False.
        pre_eval (bool): Use dataset.pre_eval() function to generate
            pre_results for metric evaluation. Mutually exclusive with
            efficient_test and format_only. Default: False.
        format_only (bool): Only format result for results commit.
            Mutually exclusive with pre_eval and efficient_test.
            Default: False.
        format_args (dict, optional): The args for format_results.
            ``None`` is treated as an empty dict. Default: None.

    Returns:
        list: Whatever ``collect_results_gpu`` / ``collect_results_cpu``
            returns for the per-rank results, which are the raw model
            outputs when no mode is enabled, the saved file names for
            ``efficient_test``, the outputs of ``dataset.format_results``
            for ``format_only`` or the outputs of ``dataset.pre_eval`` for
            ``pre_eval``.
    """
    # Validate the flags before touching the file system.
    # when none of them is set true, return segmentation results as
    # a list of np.array.
    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
        'exclusive, only one of them could be true .'

    # Avoid a mutable default argument shared between calls.
    if format_args is None:
        format_args = {}

    if efficient_test:
        warnings.warn(
            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
            'evaluation is CPU memory friendly with pre_eval=True')
        mmcv.mkdir_or_exist('.efficient_test')

    model.eval()
    results = []
    dataset = data_loader.dataset
    dataset_size = len(dataset)
    # The pipeline about how the data_loader retrieval samples from dataset:
    # sampler -> batch_sampler -> indices
    # The indices are passed to dataset_fetcher to get data from dataset.
    # data_fetcher -> collate_fn(dataset[index]) -> data_sample
    # we use batch_sampler to get correct data idx

    # batch_sampler based on DistributedSampler, the indices only point to data
    # samples of related machine.
    loader_indices = data_loader.batch_sampler

    rank, world_size = get_dist_info()
    # Only rank 0 reports progress; the bar covers the whole dataset.
    if rank == 0:
        prog_bar = mmcv.ProgressBar(dataset_size)

    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        if efficient_test:
            result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]

        if format_only:
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)

        if pre_eval:
            # TODO: adapt samples_per_gpu > 1.
            # only samples_per_gpu=1 valid now
            result = dataset.pre_eval(result, indices=batch_indices)

        results.extend(result)

        if rank == 0:
            # Rank 0 extrapolates global progress from its own batch size.
            # Clamp the step so the bar never runs past the dataset size
            # when it is not a multiple of the global batch.
            batch_size = len(result) * world_size
            remaining = dataset_size - prog_bar.completed
            for _ in range(min(batch_size, remaining)):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, dataset_size)
    else:
        results = collect_results_cpu(results, dataset_size, tmpdir)
    return results