# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import warnings

import mmcv
import numpy as np
import torch
from mmcv.engine import collect_results_cpu, collect_results_gpu
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info


def np2tmp(array, temp_file_name=None, tmpdir=None):
    """Save ndarray to local numpy file.

    Args:
        array (ndarray): Ndarray to save.
        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
            function will generate a file name with tempfile.NamedTemporaryFile
            to save ndarray. Default: None.
        tmpdir (str): Temporary directory to save Ndarray files. Default: None.

    Returns:
        str: The numpy file name.
    """

    if temp_file_name is None:
        temp_file_name = tempfile.NamedTemporaryFile(
            suffix='.npy', delete=False, dir=tmpdir).name
    np.save(temp_file_name, array)
    return temp_file_name
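

# Illustrative usage sketch for ``np2tmp`` (not part of the public API of this
# module). The array shape, the number of classes and the ``.efficient_test``
# directory name below are arbitrary assumptions; the directory must exist
# before calling.
def _np2tmp_example():
    # Pretend this is a predicted segmentation map for one image.
    seg_map = np.random.randint(0, 19, size=(512, 512), dtype=np.uint8)
    mmcv.mkdir_or_exist('.efficient_test')
    file_name = np2tmp(seg_map, tmpdir='.efficient_test')
    # The saved file can be reloaded lazily later, e.g. during evaluation.
    restored = np.load(file_name)
    assert np.array_equal(seg_map, restored)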


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    efficient_test=False,
                    opacity=0.5,
                    pre_eval=False,
                    format_only=False,
                    format_args={}):
    """Test with a single GPU in progressive mode.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (torch.utils.data.DataLoader): PyTorch data loader.
        show (bool): Whether to show results during inference.
            Default: False.
        out_dir (str, optional): If specified, the visualized results will
            be dumped into this directory.
        efficient_test (bool): Whether to save the results as local numpy
            files to save CPU memory during evaluation. Mutually exclusive
            with ``pre_eval`` and ``format_only``. Default: False.
        opacity (float): Opacity of the painted segmentation map.
            Must be in the (0, 1] range. Default: 0.5.
        pre_eval (bool): Use dataset.pre_eval() to generate pre-results for
            metric evaluation. Mutually exclusive with ``efficient_test``
            and ``format_only``. Default: False.
        format_only (bool): Only format the results for submission, e.g. to
            a test server. Mutually exclusive with ``pre_eval`` and
            ``efficient_test``. Default: False.
        format_args (dict): The args for format_results. Default: {}.

    Returns:
        list: List of evaluation pre-results or list of saved file names.
    """
    if efficient_test:
        warnings.warn(
            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
            'evaluation is CPU memory friendly with pre_eval=True')
        mmcv.mkdir_or_exist('.efficient_test')
    # When none of them is set to True, return segmentation results as
    # a list of np.array.
    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
        'exclusive, only one of them can be true.'

    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # How the data_loader retrieves samples from the dataset:
    # sampler -> batch_sampler -> indices
    # The indices are passed to the dataset fetcher to get data from the
    # dataset: data_fetcher -> collate_fn(dataset[index]) -> data_sample
    # We use batch_sampler to get the correct data indices.
    loader_indices = data_loader.batch_sampler

    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

        if show or out_dir:
            img_tensor = data['img'][0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for img, img_meta in zip(imgs, img_metas):
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None

                model.module.show_result(
                    img_show,
                    result,
                    palette=dataset.PALETTE,
                    show=show,
                    out_file=out_file,
                    opacity=opacity)

        if efficient_test:
            result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]

        if format_only:
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)
        if pre_eval:
            # TODO: adapt samples_per_gpu > 1.
            # only samples_per_gpu=1 valid now
            result = dataset.pre_eval(result, indices=batch_indices)
            results.extend(result)
        else:
            results.extend(result)

        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()

    return results
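

# Illustrative sketch of how ``single_gpu_test`` is typically driven (a
# minimal example, not the canonical tool script). The config/checkpoint
# arguments and the mmseg builder helpers are assumptions based on a common
# mmseg setup; imports are kept local so this module gains no import-time
# dependencies.
def _single_gpu_test_example(cfg, checkpoint_path):
    from mmcv.parallel import MMDataParallel
    from mmcv.runner import load_checkpoint

    from mmseg.datasets import build_dataloader, build_dataset
    from mmseg.models import build_segmentor

    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=2,
        dist=False,
        shuffle=False)
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])
    # ``pre_eval=True`` is the CPU-memory-friendly path recommended by the
    # deprecation warning above.
    results = single_gpu_test(model, data_loader, pre_eval=True)
    return dataset.evaluate(results, metric='mIoU')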


def multi_gpu_test(model,
                   data_loader,
                   tmpdir=None,
                   gpu_collect=False,
                   efficient_test=False,
                   pre_eval=False,
                   format_only=False,
                   format_args={}):
    """Test model with multiple GPUs in progressive mode.

    This method tests the model with multiple GPUs and collects the results
    in one of two modes: GPU or CPU. When 'gpu_collect=True', it encodes
    results to GPU tensors and uses GPU communication for result collection.
    In CPU mode, it saves the results of each GPU to 'tmpdir' and lets the
    rank 0 worker collect them.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (torch.utils.data.DataLoader): PyTorch data loader.
        tmpdir (str): Path of the directory used to save the temporary
            results from different GPUs in CPU mode. The same path is used
            for the efficient test. Default: None.
        gpu_collect (bool): Option to use either GPU or CPU to collect
            results. Default: False.
        efficient_test (bool): Whether to save the results as local numpy
            files to save CPU memory during evaluation. Mutually exclusive
            with ``pre_eval`` and ``format_only``. Default: False.
        pre_eval (bool): Use dataset.pre_eval() to generate pre-results for
            metric evaluation. Mutually exclusive with ``efficient_test``
            and ``format_only``. Default: False.
        format_only (bool): Only format the results for submission, e.g. to
            a test server. Mutually exclusive with ``pre_eval`` and
            ``efficient_test``. Default: False.
        format_args (dict): The args for format_results. Default: {}.

    Returns:
        list: List of evaluation pre-results or list of saved file names.
    """
    if efficient_test:
        warnings.warn(
            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
            'evaluation is CPU memory friendly with pre_eval=True')
        mmcv.mkdir_or_exist('.efficient_test')
    # When none of them is set to True, return segmentation results as
    # a list of np.array.
    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
        'exclusive, only one of them can be true.'

    model.eval()
    results = []
    dataset = data_loader.dataset
    # How the data_loader retrieves samples from the dataset:
    # sampler -> batch_sampler -> indices
    # The indices are passed to the dataset fetcher to get data from the
    # dataset: data_fetcher -> collate_fn(dataset[index]) -> data_sample
    # We use batch_sampler to get the correct data indices.

    # The batch_sampler is based on DistributedSampler, so the indices only
    # point to the data samples of the current rank.
    loader_indices = data_loader.batch_sampler

    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))

    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        if efficient_test:
            result = [np2tmp(_, tmpdir='.efficient_test') for _ in result]

        if format_only:
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)
        if pre_eval:
            # TODO: adapt samples_per_gpu > 1.
            # only samples_per_gpu=1 valid now
            result = dataset.pre_eval(result, indices=batch_indices)

        results.extend(result)

        if rank == 0:
            batch_size = len(result) * world_size
            for _ in range(batch_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results
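

# Illustrative sketch of a distributed evaluation loop built around
# ``multi_gpu_test`` (a minimal example, not the canonical tool script). The
# launcher choice, the mmseg builder helpers and the ``.dist_test`` tmpdir
# name are assumptions based on a common mmseg setup; imports are kept local
# so this module gains no import-time dependencies.
def _multi_gpu_test_example(cfg, checkpoint_path):
    from mmcv.parallel import MMDistributedDataParallel
    from mmcv.runner import init_dist, load_checkpoint

    from mmseg.datasets import build_dataloader, build_dataset
    from mmseg.models import build_segmentor

    # Assumes the process was started by a PyTorch distributed launcher.
    init_dist('pytorch', backend='nccl')
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=2,
        dist=True,
        shuffle=False)
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)
    # ``gpu_collect=False`` gathers per-rank results through ``tmpdir``; set
    # it to True to gather over GPU communication instead.
    results = multi_gpu_test(
        model, data_loader, tmpdir='.dist_test', gpu_collect=False,
        pre_eval=True)
    rank, _ = get_dist_info()
    if rank == 0:
        return dataset.evaluate(results, metric='mIoU')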