"""Browse a dataset: overlay segmentation maps on images and show/save them."""
import argparse
import os
import warnings
from pathlib import Path

import mmcv
import numpy as np
from mmcv import Config, DictAction

from mmseg.datasets.builder import build_dataset


def parse_args():
    """Parse the command-line arguments of the dataset browser.

    Returns:
        argparse.Namespace: Parsed arguments (`config`, `show_origin`,
            `skip_type`, `output_dir`, `show`, `show_interval`, `opacity`,
            `cfg_options`).
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--show-origin',
        default=False,
        action='store_true',
        help='if True, omit all augmentation in pipeline,'
        ' show origin image and seg map')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipeline,if `show-origin` is true, '
        'all pipeline except `Load` will be skipped')
    parser.add_argument(
        '--output-dir',
        default='./output',
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='the opacity of semantic map')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def imshow_semantic(img,
                    seg,
                    class_names,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
    """Blend a colorized segmentation map `seg` over `img`.

    Args:
        img: The image to draw on; it is loaded through `mmcv.imread`
            (presumably a file path or an already-loaded BGR array --
            confirm against the dataset pipeline output).
        seg: The label map to draw over `img`. Its first two dimensions
            are used as height and width, and it must blend with `img`
            element-wise (assumes matching height/width -- TODO confirm).
        class_names (list[str]): Names of the classes; only the length is
            used, to size and validate the palette.
        palette (list[list[int]] | np.ndarray | None): One RGB color per
            class. If None, a random palette is generated. Default: None.
        win_name (str): The window name. Default: ''.
        show (bool): Whether to show the image. Ignored (forced to False)
            when `out_file` is given. Default: False.
        wait_time (int): Value of waitKey param. Default: 0.
        out_file (str | None): The filename to write the image to.
            Default: None.
        opacity (float): Opacity of the painted segmentation map. Must be
            in (0, 1]. Default: 0.5.

    Returns:
        np.ndarray | None: The blended uint8 image when neither `show` nor
            `out_file` is set (a warning is emitted); otherwise None.
    """
    img = mmcv.imread(img)
    img = img.copy()
    if palette is None:
        # np.random.randint excludes the upper bound: values are in [0, 255).
        palette = np.random.randint(0, 255, size=(len(class_names), 3))
    palette = np.array(palette)
    # Check the rank first so a 1-D palette fails this assertion instead of
    # raising IndexError on `palette.shape[1]` below.
    assert len(palette.shape) == 2
    assert palette.shape[0] == len(class_names)
    assert palette.shape[1] == 3
    assert 0 < opacity <= 1.0

    # Pixels whose label matches no palette index stay black.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)

    # if out_file specified, do not show image in window
    if out_file is not None:
        show = False

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    if not (show or out_file):
        warnings.warn('show==False and out_file is not specified, only '
                      'result image will be returned')
        return img


def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
    """Filter the `pipeline` of a dataset config in place.

    Args:
        _data_cfg: Dataset config that owns a `pipeline` entry.
        skip_type (list[str]): Transform types to drop from the pipeline.
        show_origin (bool): If True, keep only transforms whose type
            contains 'Load' and ignore `skip_type`.
    """
    if show_origin is True:
        # only keep pipeline of Loading data and ann
        _data_cfg['pipeline'] = [
            x for x in _data_cfg.pipeline if 'Load' in x['type']
        ]
    else:
        _data_cfg['pipeline'] = [
            x for x in _data_cfg.pipeline if x['type'] not in skip_type
        ]


def _unwrap_dataset_cfg(data_cfg):
    """Follow nested `dataset` entries down to the config to be filtered.

    Descent stops at a config without a `dataset` key, or at a
    'MultiImageMixDataset', whose own pipeline is the one to filter.
    """
    while ('dataset' in data_cfg
           and data_cfg['type'] != 'MultiImageMixDataset'):
        data_cfg = data_cfg['dataset']
    return data_cfg


def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
    """Load a config and strip unwanted transforms from its train pipeline.

    Args:
        config_path (str): Path of the config file to load.
        skip_type (list[str]): Transform types to drop from the pipeline.
        cfg_options (dict | None): Overrides merged into the loaded config.
        show_origin (bool): If True, keep only the 'Load' transforms.
            Default: False.

    Returns:
        Config: The loaded config, with `cfg.data.train` pipelines filtered
            in place.

    Raises:
        ValueError: If a train dataset config has no `pipeline` entry.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    train_data_cfg = cfg.data.train

    # `data.train` may be one dataset config or a list of them; handle both
    # through the same path so they are validated identically.
    if isinstance(train_data_cfg, list):
        data_cfgs = train_data_cfg
    else:
        data_cfgs = [train_data_cfg]

    for data_cfg in data_cfgs:
        data_cfg = _unwrap_dataset_cfg(data_cfg)
        if 'pipeline' not in data_cfg:
            raise ValueError(
                'no `pipeline` found in the train dataset config')
        _retrieve_data_cfg(data_cfg, skip_type, show_origin)

    return cfg


def main():
    """Build the train dataset and draw every sample's segmentation map."""
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                            args.show_origin)
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))
    for item in dataset:
        # NOTE(review): `--output-dir` defaults to './output', so `filename`
        # is always set unless the default is changed, and `imshow_semantic`
        # disables `show` whenever `out_file` is given -- `--show` therefore
        # has no effect at the moment. Confirm the intended behaviour.
        if args.output_dir is not None:
            filename = os.path.join(args.output_dir,
                                    Path(item['filename']).name)
        else:
            filename = None
        imshow_semantic(
            item['img'],
            item['gt_semantic_seg'],
            dataset.CLASSES,
            dataset.PALETTE,
            show=args.show,
            wait_time=args.show_interval,
            out_file=filename,
            opacity=args.opacity,
        )
        progress_bar.update()


if __name__ == '__main__':
    main()