memory efficient test (#330)

* memory efficient test

* implement efficient test

* merge

* Add document and docstring

* fix unit test

* add memory usage report
yamengxi 2021-01-10 15:47:31 +08:00 committed by GitHub
parent 8ed47abd23
commit ce46d70d20
6 changed files with 187 additions and 67 deletions


@@ -25,6 +25,7 @@ Optional arguments:
 - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `mIoU` is available for all datasets. Cityscapes could be evaluated by `cityscapes` as well as the standard `mIoU` metric.
 - `--show`: If specified, segmentation results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that a GUI is available in your environment, otherwise you may encounter an error like `cannot connect to X server`.
 - `--show-dir`: If specified, segmentation results will be plotted on the images and saved to the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option.
+- `--eval-options`: Optional parameters during evaluation. When `efficient_test=True`, it will save intermediate results to local files to save CPU memory. Make sure that you have enough local storage space (more than 20GB).

 Examples:

@@ -86,3 +87,15 @@ Assume that you have already downloaded the checkpoints to the directory `checkpoints/`:
    You will get png files under the `./pspnet_test_results` directory.
    You may run `zip -r results.zip pspnet_test_results/` and submit the zip file to the [evaluation server](https://www.cityscapes-dataset.com/submit/).
+
+6. CPU memory efficient test: test DeepLabV3+ on Cityscapes (without saving the test results) and evaluate the mIoU.
+
+   ```shell
+   python tools/test.py \
+       configs/deeplabv3plus/deeplabv3plus_r18-d8_512x1024_80k_cityscapes.py \
+       deeplabv3plus_r18-d8_512x1024_80k_cityscapes_20201226_080942-cff257fe.pth \
+       --eval-options efficient_test=True \
+       --eval mIoU
+   ```
+
+   Using `pmap` to view the CPU memory footprint, the run used 2.25GB of CPU memory with `efficient_test=True` and 11.06GB with `efficient_test=False`. This optional parameter can save a lot of memory.
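A rough back-of-the-envelope check makes that gap plausible. The numbers below are an assumption-laden sketch (taking Cityscapes val as 500 images and each prediction as a 1024x2048 array of int64 class ids when held in memory), not something measured in this commit:

```python
import numpy as np

# Assumed sizes: 500 validation images, each prediction kept in RAM as a
# 1024 x 2048 int64 array when efficient_test is disabled.
num_images = 500
bytes_per_map = 1024 * 2048 * np.dtype(np.int64).itemsize

total_gb = num_images * bytes_per_map / 1024**3
print(f'{total_gb:.1f} GB')  # ~7.8 GB of predictions held in memory
```

That estimate is in the same ballpark as the roughly 8.8GB difference reported above, which is consistent with the predictions themselves dominating the CPU memory footprint.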

mmseg/apis/test.py

@@ -4,21 +4,48 @@ import shutil
 import tempfile

 import mmcv
+import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.image import tensor2imgs
 from mmcv.runner import get_dist_info


-def single_gpu_test(model, data_loader, show=False, out_dir=None):
+def np2tmp(array, temp_file_name=None):
+    """Save ndarray to local numpy file.
+
+    Args:
+        array (ndarray): Ndarray to save.
+        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
+            function will generate a file name with tempfile.NamedTemporaryFile
+            to save ndarray. Default: None.
+
+    Returns:
+        str: The numpy file name.
+    """
+
+    if temp_file_name is None:
+        temp_file_name = tempfile.NamedTemporaryFile(
+            suffix='.npy', delete=False).name
+    np.save(temp_file_name, array)
+    return temp_file_name
+
+
+def single_gpu_test(model,
+                    data_loader,
+                    show=False,
+                    out_dir=None,
+                    efficient_test=False):
     """Test with single GPU.

     Args:
         model (nn.Module): Model to be tested.
-        data_loader (nn.Dataloader): Pytorch data loader.
+        data_loader (utils.data.Dataloader): Pytorch data loader.
         show (bool): Whether show results during inference. Default: False.
-        out_dir (str, optional): If specified, the results will be dumped
-            into the directory to save output results.
+        out_dir (str, optional): If specified, the results will be dumped into
+            the directory to save output results.
+        efficient_test (bool): Whether save the results as local numpy files to
+            save CPU memory during evaluation. Default: False.

     Returns:
         list: The prediction results.

@@ -31,10 +58,6 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
     for i, data in enumerate(data_loader):
         with torch.no_grad():
             result = model(return_loss=False, **data)
-        if isinstance(result, list):
-            results.extend(result)
-        else:
-            results.append(result)

         if show or out_dir:
             img_tensor = data['img'][0]

@@ -61,13 +84,26 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
                     show=show,
                     out_file=out_file)

+        if isinstance(result, list):
+            if efficient_test:
+                result = [np2tmp(_) for _ in result]
+            results.extend(result)
+        else:
+            if efficient_test:
+                result = np2tmp(result)
+            results.append(result)
+
         batch_size = data['img'][0].size(0)
         for _ in range(batch_size):
             prog_bar.update()
     return results


-def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+def multi_gpu_test(model,
+                   data_loader,
+                   tmpdir=None,
+                   gpu_collect=False,
+                   efficient_test=False):
     """Test model with multiple gpus.

     This method tests model with multiple gpus and collects the results

@@ -78,10 +114,12 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
     Args:
         model (nn.Module): Model to be tested.
-        data_loader (nn.Dataloader): Pytorch data loader.
+        data_loader (utils.data.Dataloader): Pytorch data loader.
         tmpdir (str): Path of directory to save the temporary results from
             different gpus under cpu mode.
         gpu_collect (bool): Option to use either gpu or cpu to collect results.
+        efficient_test (bool): Whether save the results as local numpy files to
+            save CPU memory during evaluation. Default: False.

     Returns:
         list: The prediction results.

@@ -96,9 +134,14 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
     for i, data in enumerate(data_loader):
         with torch.no_grad():
             result = model(return_loss=False, rescale=True, **data)
+
         if isinstance(result, list):
+            if efficient_test:
+                result = [np2tmp(_) for _ in result]
             results.extend(result)
         else:
+            if efficient_test:
+                result = np2tmp(result)
             results.append(result)

         if rank == 0:
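The pattern introduced here is that, with `efficient_test=True`, the results list holds `.npy` file paths instead of arrays, and downstream code loads each file back on demand. A minimal standalone sketch of that round trip (it mimics `np2tmp` rather than importing mmseg, and the helper names are illustrative):

```python
import tempfile

import numpy as np


def np2tmp_sketch(array):
    # Mimic np2tmp: dump the array and return only the temporary file name.
    file_name = tempfile.NamedTemporaryFile(suffix='.npy', delete=False).name
    np.save(file_name, array)
    return file_name


def load_if_path(result):
    # Consumers accept either an in-memory ndarray or a saved .npy path.
    if isinstance(result, str):
        result = np.load(result)
    return result


pred = np.random.randint(0, 19, size=(4, 4))
path = np2tmp_sketch(pred)
assert np.array_equal(load_if_path(path), load_if_path(pred))
```

Keeping only the path means the per-image cost in RAM drops from a full prediction map to a short string, at the price of one extra disk read per image during evaluation.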

mmseg/core/evaluation/metrics.py

@@ -1,24 +1,49 @@
+import mmcv
 import numpy as np


-def intersect_and_union(pred_label, label, num_classes, ignore_index):
+def intersect_and_union(pred_label,
+                        label,
+                        num_classes,
+                        ignore_index,
+                        label_map=dict(),
+                        reduce_zero_label=False):
     """Calculate intersection and Union.

     Args:
-        pred_label (ndarray): Prediction segmentation map
-        label (ndarray): Ground truth segmentation map
-        num_classes (int): Number of categories
+        pred_label (ndarray): Prediction segmentation map.
+        label (ndarray): Ground truth segmentation map.
+        num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
+        label_map (dict): Mapping old labels to new labels. The parameter will
+            work only when label is str. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. The parameter will
+            work only when label is str. Default: False.

     Returns:
         ndarray: The intersection of prediction and ground truth histogram
-            on all classes
+            on all classes.
         ndarray: The union of prediction and ground truth histogram on all
-            classes
+            classes.
         ndarray: The prediction histogram on all classes.
         ndarray: The ground truth histogram on all classes.
     """

+    if isinstance(pred_label, str):
+        pred_label = np.load(pred_label)
+
+    if isinstance(label, str):
+        label = mmcv.imread(label, flag='unchanged', backend='pillow')
+    # modify if custom classes
+    if label_map is not None:
+        for old_id, new_id in label_map.items():
+            label[label == old_id] = new_id
+    if reduce_zero_label:
+        # avoid using underflow conversion
+        label[label == 0] = 255
+        label = label - 1
+        label[label == 254] = 255
+
     mask = (label != ignore_index)
     pred_label = pred_label[mask]
     label = label[mask]

@@ -34,20 +59,27 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
     return area_intersect, area_union, area_pred_label, area_label


-def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
+def total_intersect_and_union(results,
+                              gt_seg_maps,
+                              num_classes,
+                              ignore_index,
+                              label_map=dict(),
+                              reduce_zero_label=False):
     """Calculate Total Intersection and Union.

     Args:
-        results (list[ndarray]): List of prediction segmentation maps
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
-        num_classes (int): Number of categories
+        results (list[ndarray]): List of prediction segmentation maps.
+        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
         ndarray: The intersection of prediction and ground truth histogram
-            on all classes
+            on all classes.
         ndarray: The union of prediction and ground truth histogram on all
-            classes
+            classes.
         ndarray: The prediction histogram on all classes.
         ndarray: The ground truth histogram on all classes.
     """

@@ -61,7 +93,7 @@ def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
     for i in range(num_imgs):
         area_intersect, area_union, area_pred_label, area_label = \
             intersect_and_union(results[i], gt_seg_maps[i], num_classes,
-                                ignore_index=ignore_index)
+                                ignore_index, label_map, reduce_zero_label)
         total_area_intersect += area_intersect
         total_area_union += area_union
         total_area_pred_label += area_pred_label

@@ -70,21 +102,29 @@ def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
         total_area_pred_label, total_area_label


-def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
+def mean_iou(results,
+             gt_seg_maps,
+             num_classes,
+             ignore_index,
+             nan_to_num=None,
+             label_map=dict(),
+             reduce_zero_label=False):
     """Calculate Mean Intersection and Union (mIoU)

     Args:
-        results (list[ndarray]): List of prediction segmentation maps
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
-        num_classes (int): Number of categories
+        results (list[ndarray]): List of prediction segmentation maps.
+        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         nan_to_num (int, optional): If specified, NaN values will be replaced
             by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
         float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, )
-        ndarray: Per category IoU, shape (num_classes, )
+        ndarray: Per category accuracy, shape (num_classes, ).
+        ndarray: Per category IoU, shape (num_classes, ).
     """

     all_acc, acc, iou = eval_metrics(

@@ -93,7 +133,9 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
         num_classes=num_classes,
         ignore_index=ignore_index,
         metrics=['mIoU'],
-        nan_to_num=nan_to_num)
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label)
     return all_acc, acc, iou


@@ -101,21 +143,25 @@ def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
-             nan_to_num=None):
+             nan_to_num=None,
+             label_map=dict(),
+             reduce_zero_label=False):
     """Calculate Mean Dice (mDice)

     Args:
-        results (list[ndarray]): List of prediction segmentation maps
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
-        num_classes (int): Number of categories
+        results (list[ndarray]): List of prediction segmentation maps.
+        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         nan_to_num (int, optional): If specified, NaN values will be replaced
             by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
         float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, )
-        ndarray: Per category dice, shape (num_classes, )
+        ndarray: Per category accuracy, shape (num_classes, ).
+        ndarray: Per category dice, shape (num_classes, ).
     """

     all_acc, acc, dice = eval_metrics(

@@ -124,7 +170,9 @@ def mean_dice(results,
         num_classes=num_classes,
         ignore_index=ignore_index,
         metrics=['mDice'],
-        nan_to_num=nan_to_num)
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label)
     return all_acc, acc, dice


@@ -133,20 +181,24 @@ def eval_metrics(results,
                  num_classes,
                  ignore_index,
                  metrics=['mIoU'],
-                 nan_to_num=None):
+                 nan_to_num=None,
+                 label_map=dict(),
+                 reduce_zero_label=False):
     """Calculate evaluation metrics

     Args:
-        results (list[ndarray]): List of prediction segmentation maps
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
-        num_classes (int): Number of categories
+        results (list[ndarray]): List of prediction segmentation maps.
+        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
+        num_classes (int): Number of categories.
         ignore_index (int): Index that will be ignored in evaluation.
         metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
         nan_to_num (int, optional): If specified, NaN values will be replaced
             by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.

     Returns:
         float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, )
-        ndarray: Per category evaluation metrics, shape (num_classes, )
+        ndarray: Per category accuracy, shape (num_classes, ).
+        ndarray: Per category evaluation metrics, shape (num_classes, ).
     """

     if isinstance(metrics, str):

@@ -156,8 +208,9 @@ def eval_metrics(results,
         raise KeyError('metrics {} is not supported'.format(metrics))
     total_area_intersect, total_area_union, total_area_pred_label, \
         total_area_label = total_intersect_and_union(results, gt_seg_maps,
-                                                     num_classes,
-                                                     ignore_index=ignore_index)
+                                                     num_classes, ignore_index,
+                                                     label_map,
+                                                     reduce_zero_label)
     all_acc = total_area_intersect.sum() / total_area_label.sum()
     acc = total_area_intersect / total_area_label
     ret_metrics = [all_acc, acc]
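To make the new `label_map` and `reduce_zero_label` handling concrete, here is a small self-contained NumPy sketch of the remapping performed above; the toy label values and the `{3: 2}` mapping are made up for illustration:

```python
import numpy as np

label = np.array([[0, 1, 2],
                  [2, 0, 3]], dtype=np.int64)

# label_map: rewrite selected old class ids to new ones, e.g. 3 -> 2.
label_map = {3: 2}
for old_id, new_id in label_map.items():
    label[label == old_id] = new_id

# reduce_zero_label: class 0 becomes "ignore" (255) and every other id
# shifts down by one; 0 is set to 255 first to avoid an underflow on
# unsigned dtypes, and 254 is folded back into 255 afterwards.
label[label == 0] = 255
label = label - 1
label[label == 254] = 255

print(label)
# [[255   0   1]
#  [  1 255   1]]
```

Moving this remapping into `intersect_and_union` is what allows `get_gt_seg_maps` to hand over raw file paths when `efficient_test=True`: the ground truth is only read and remapped once it is actually needed for the metric.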

mmseg/datasets/cityscapes.py

@@ -38,6 +38,8 @@ class CityscapesDataset(CustomDataset):
     @staticmethod
     def _convert_to_label_id(result):
         """Convert trainId to id for cityscapes."""
+        if isinstance(result, str):
+            result = np.load(result)
         import cityscapesscripts.helpers.labels as CSLabels
         result_copy = result.copy()
         for trainId, label in CSLabels.trainId2label.items():

@@ -123,7 +125,8 @@ class CityscapesDataset(CustomDataset):
                  results,
                  metric='mIoU',
                  logger=None,
-                 imgfile_prefix=None):
+                 imgfile_prefix=None,
+                 efficient_test=False):
         """Evaluation in Cityscapes/default protocol.

         Args:

@@ -154,7 +157,7 @@ class CityscapesDataset(CustomDataset):
         if len(metrics) > 0:
             eval_results.update(
                 super(CityscapesDataset,
-                      self).evaluate(results, metrics, logger))
+                      self).evaluate(results, metrics, logger, efficient_test))

         return eval_results
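As a side note on `_convert_to_label_id`: the values are written into `result_copy` so that an id that has already been rewritten cannot be matched again by a later entry. A hedged sketch of that conversion with a hand-picked subset of the Cityscapes table (the real mapping comes from `cityscapesscripts.helpers.labels`, not this dict):

```python
import numpy as np

# Illustrative subset of trainId -> labelId (road, sidewalk, building, ignore).
TRAINID_TO_LABELID = {0: 7, 1: 8, 2: 11, 255: 0}


def convert_to_label_id(result):
    # Write into a copy so already-converted values are never re-matched.
    result_copy = result.copy()
    for train_id, label_id in TRAINID_TO_LABELID.items():
        result_copy[result == train_id] = label_id
    return result_copy


pred = np.array([[0, 1],
                 [2, 255]])
print(convert_to_label_id(pred))
# [[ 7  8]
#  [11  0]]
```

The new `isinstance(result, str)` guard above simply loads the saved `.npy` prediction first, so the same conversion works whether the result arrived in memory or as a temporary file path.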

mmseg/datasets/custom.py

@@ -1,3 +1,4 @@
+import os
 import os.path as osp
 from functools import reduce

@@ -226,25 +227,17 @@ class CustomDataset(Dataset):
         """Place holder to format result to dataset specific output."""
         pass

-    def get_gt_seg_maps(self):
+    def get_gt_seg_maps(self, efficient_test=False):
         """Get ground truth segmentation maps for evaluation."""
         gt_seg_maps = []
         for img_info in self.img_infos:
             seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
-            gt_seg_map = mmcv.imread(
-                seg_map, flag='unchanged', backend='pillow')
-            # modify if custom classes
-            if self.label_map is not None:
-                for old_id, new_id in self.label_map.items():
-                    gt_seg_map[gt_seg_map == old_id] = new_id
-            if self.reduce_zero_label:
-                # avoid using underflow conversion
-                gt_seg_map[gt_seg_map == 0] = 255
-                gt_seg_map = gt_seg_map - 1
-                gt_seg_map[gt_seg_map == 254] = 255
+            if efficient_test:
+                gt_seg_map = seg_map
+            else:
+                gt_seg_map = mmcv.imread(
+                    seg_map, flag='unchanged', backend='pillow')
             gt_seg_maps.append(gt_seg_map)
         return gt_seg_maps

     def get_classes_and_palette(self, classes=None, palette=None):

@@ -310,7 +303,12 @@ class CustomDataset(Dataset):
         return palette

-    def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
+    def evaluate(self,
+                 results,
+                 metric='mIoU',
+                 logger=None,
+                 efficient_test=False,
+                 **kwargs):
         """Evaluate the dataset.

         Args:

@@ -330,7 +328,7 @@ class CustomDataset(Dataset):
         if not set(metric).issubset(set(allowed_metrics)):
             raise KeyError('metric {} is not supported'.format(metric))
         eval_results = {}
-        gt_seg_maps = self.get_gt_seg_maps()
+        gt_seg_maps = self.get_gt_seg_maps(efficient_test)
         if self.CLASSES is None:
             num_classes = len(
                 reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))

@@ -340,8 +338,10 @@ class CustomDataset(Dataset):
             results,
             gt_seg_maps,
             num_classes,
-            ignore_index=self.ignore_index,
-            metrics=metric)
+            self.ignore_index,
+            metric,
+            label_map=self.label_map,
+            reduce_zero_label=self.reduce_zero_label)
         class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
         if self.CLASSES is None:
             class_names = tuple(range(num_classes))

@@ -374,4 +374,7 @@ class CustomDataset(Dataset):
         for i in range(1, len(summary_table_data[0])):
             eval_results[summary_table_data[0]
                          [i]] = summary_table_data[1][i] / 100.0
+        if mmcv.is_list_of(results, str):
+            for file_name in results:
+                os.remove(file_name)
         return eval_results
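The cleanup at the end of `evaluate` is what keeps `efficient_test` from leaking temporary files: if the results turn out to be a list of paths, they are deleted once the metrics have been computed. A small standalone sketch of that check (it assumes `mmcv` is installed; `mmcv.is_list_of` just verifies that every element has the given type):

```python
import os
import tempfile

import mmcv
import numpy as np

# Fake "efficient" results: two .npy paths instead of in-memory arrays.
results = []
for _ in range(2):
    file_name = tempfile.NamedTemporaryFile(suffix='.npy', delete=False).name
    np.save(file_name, np.zeros((4, 4), dtype=np.uint8))
    results.append(file_name)

# Same guard as in CustomDataset.evaluate: only path lists get removed.
if mmcv.is_list_of(results, str):
    for file_name in results:
        os.remove(file_name)
```

Guarding on the element type means the normal (non-efficient) path, where results are ndarrays, is left completely untouched.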

tools/test.py

@@ -115,16 +115,21 @@ def main():
     model.CLASSES = checkpoint['meta']['CLASSES']
     model.PALETTE = checkpoint['meta']['PALETTE']

+    efficient_test = False
+    if args.eval_options is not None:
+        efficient_test = args.eval_options.get('efficient_test', False)
+
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
-        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
+        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
+                                  efficient_test)
     else:
         model = MMDistributedDataParallel(
             model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False)
         outputs = multi_gpu_test(model, data_loader, args.tmpdir,
-                                 args.gpu_collect)
+                                 args.gpu_collect, efficient_test)

     rank, _ = get_dist_info()
     if rank == 0:
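Finally, the flag reaches the script through `--eval-options`, which is parsed into a plain dict before the `.get('efficient_test', False)` lookup shown above. A hedged sketch of that parsing step; it assumes the argument is declared with mmcv's `DictAction`, which is how comparable mm-series test scripts usually declare key=value options:

```python
import argparse

from mmcv.utils import DictAction

parser = argparse.ArgumentParser()
# key=value pairs on the command line become a dict on args.eval_options.
parser.add_argument('--eval-options', nargs='+', action=DictAction)
args = parser.parse_args(['--eval-options', 'efficient_test=True'])

efficient_test = False
if args.eval_options is not None:
    efficient_test = args.eval_options.get('efficient_test', False)
print(efficient_test)  # True
```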