From 524425cf935e183c99c47193d3199d0f884070c4 Mon Sep 17 00:00:00 2001 From: sennnnn <58427300+sennnnn@users.noreply.github.com> Date: Mon, 17 May 2021 10:29:28 +0800 Subject: [PATCH] [Feature] Add results2img, format_results for ade dataset (#544) * [Feature] Add results2img, format_results for ade dataset. * clean rebundant code. --- mmseg/datasets/ade.py | 79 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/mmseg/datasets/ade.py b/mmseg/datasets/ade.py index 5913e4377..5daf7e373 100644 --- a/mmseg/datasets/ade.py +++ b/mmseg/datasets/ade.py @@ -1,3 +1,10 @@ +import os.path as osp +import tempfile + +import mmcv +import numpy as np +from PIL import Image + from .builder import DATASETS from .custom import CustomDataset @@ -82,3 +89,75 @@ class ADE20KDataset(CustomDataset): seg_map_suffix='.png', reduce_zero_label=True, **kwargs) + + def results2img(self, results, imgfile_prefix, to_label_id): + """Write the segmentation results to images. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + prog_bar = mmcv.ProgressBar(len(self)) + for idx in range(len(self)): + result = results[idx] + + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + # The index range of official requirement is from 0 to 150. + # But the index range of output is from 0 to 149. + # That is because we set reduce_zero_label=True. + result = result + 1 + + output = Image.fromarray(result.astype(np.uint8)) + output.save(png_filename) + result_files.append(png_filename) + + prog_bar.update() + + return result_files + + def format_results(self, results, imgfile_prefix=None, to_label_id=True): + """Format the results into dir (standard format for ade20k evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str | None): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". If not specified, a temp file will be created. + Default: None. + to_label_id (bool): whether convert output to label_id for + submission. Default: False + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(self), ( + 'The length of results is not equal to the dataset len: ' + f'{len(results)} != {len(self)}') + + if imgfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + imgfile_prefix = tmp_dir.name + else: + tmp_dir = None + + result_files = self.results2img(results, imgfile_prefix, to_label_id) + return result_files, tmp_dir