183 lines
6.1 KiB
Python
183 lines
6.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
||
import argparse
|
||
import os
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
import mmcv
|
||
import numpy as np
|
||
from mmcv import Config, DictAction
|
||
|
||
from mmseg.datasets.builder import build_dataset
|
||
|
||
|
||
def parse_args():
    """Parse the command-line arguments of the dataset browser.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``config``, ``show_origin``, ``skip_type``, ``output_dir``,
            ``show``, ``show_interval``, ``opacity`` and ``cfg_options``.
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--show-origin',
        default=False,
        action='store_true',
        help='if True, omit all augmentation in pipeline,'
        ' show origin image and seg map')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipeline, if `show-origin` is true, '
        'all pipeline except `Load` will be skipped')
    parser.add_argument(
        '--output-dir',
        default='./output',
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument(
        '--show',
        default=False,
        action='store_true',
        help='display each sample in a window')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='the opacity of semantic map')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    return parser.parse_args()
|
||
|
||
|
||
def imshow_semantic(img,
                    seg,
                    class_names,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
    """Blend a colorized label map ``seg`` over ``img``.

    Args:
        img (str or np.ndarray): The image to be displayed; anything
            ``mmcv.imread`` accepts (a file path or an already-loaded image).
        seg (np.ndarray): 2-D per-pixel label map of shape (H, W) whose
            height/width match ``img``. Pixels whose label has no palette
            entry are left black.
        class_names (list[str]): Names of each class; only its length is
            used here.
        palette (list[list[int]] | np.ndarray | None): One color per class,
            shape (len(class_names), 3), in RGB order (it is reversed to
            BGR before blending). If None, a random palette is generated.
            Default: None.
        win_name (str): The window name. Default: ''.
        show (bool): Whether to show the image in a window. Forced to False
            when ``out_file`` is given. Default: False.
        wait_time (int): Value of waitKey param. Default: 0.
        out_file (str or None): The filename to write the image.
            Default: None.
        opacity (float): Opacity of painted segmentation map. Must be in
            (0, 1] range. Default: 0.5.

    Returns:
        np.ndarray: The blended image as ``uint8``. It is always returned,
            whether or not it was also shown or written.
    """
    img = mmcv.imread(img)
    img = img.copy()
    if palette is None:
        palette = np.random.randint(0, 255, size=(len(class_names), 3))
    palette = np.array(palette)
    # Check the rank first: a malformed (e.g. 1-D) palette must fail the
    # assertion rather than raise IndexError on ``palette.shape[1]``.
    assert len(palette.shape) == 2
    assert palette.shape[0] == len(class_names)
    assert palette.shape[1] == 3
    assert 0 < opacity <= 1.0
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # Palette colors are RGB; reverse the channels to BGR to match the
    # image (assumed BGR, the mmcv.imread default — confirm for callers
    # that pass pre-loaded arrays).
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)
    # if out_file specified, do not show image in window
    if out_file is not None:
        show = False

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    if not (show or out_file):
        warnings.warn('show==False and out_file is not specified, only '
                      'result image will be returned')
    return img
|
||
|
||
|
||
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
|
||
if show_origin is True:
|
||
# only keep pipeline of Loading data and ann
|
||
_data_cfg['pipeline'] = [
|
||
x for x in _data_cfg.pipeline if 'Load' in x['type']
|
||
]
|
||
else:
|
||
_data_cfg['pipeline'] = [
|
||
x for x in _data_cfg.pipeline if x['type'] not in skip_type
|
||
]
|
||
|
||
|
||
def _unwrap_dataset_cfg(data_cfg):
    """Descend through nested ``dataset`` keys to the config owning the pipeline.

    Wrapper configs that store the wrapped dataset under a ``dataset`` key
    are unwrapped; descent stops at a ``MultiImageMixDataset`` config, whose
    own ``pipeline`` is the one to filter.

    Raises:
        ValueError: If the config reached has no ``pipeline`` key.
    """
    while ('dataset' in data_cfg
           and data_cfg['type'] != 'MultiImageMixDataset'):
        data_cfg = data_cfg['dataset']
    if 'pipeline' not in data_cfg:
        raise ValueError(
            'Cannot find a "pipeline" in the train dataset config '
            '(type={!r})'.format(data_cfg.get('type')))
    return data_cfg


def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
    """Load a config and filter the pipeline(s) of its training dataset(s).

    Args:
        config_path (str): Path of the config file.
        skip_type (list[str]): Pipeline step types to drop.
        cfg_options (dict | None): Overrides merged into the config.
        show_origin (bool): If True, keep only the loading steps.

    Returns:
        Config: The loaded config; the pipelines under ``cfg.data.train``
            are modified in place.

    Raises:
        ValueError: If a training dataset config has no ``pipeline``.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    train_data_cfg = cfg.data.train
    # ``data.train`` is either a single dataset config or a list of them;
    # both cases get the same unwrap-validate-filter treatment.
    if isinstance(train_data_cfg, list):
        train_data_cfgs = train_data_cfg
    else:
        train_data_cfgs = [train_data_cfg]
    for data_cfg in train_data_cfgs:
        _retrieve_data_cfg(
            _unwrap_dataset_cfg(data_cfg), skip_type, show_origin)
    return cfg
|
||
|
||
|
||
def main():
    """Build the training dataset and visualize every sample once.

    Each sample's ``img`` is blended with its ``gt_semantic_seg`` map. The
    result is written to ``--output-dir`` or, when ``--show`` is given,
    displayed in a window instead.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                            args.show_origin)
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))
    # Index explicitly rather than ``for item in dataset``: without
    # ``__iter__`` Python iterates via ``__getitem__`` until IndexError, and
    # a wrapper that maps indices modulo its length (presumably e.g. a
    # repeat-style dataset wrapper — verify) never raises it, so that loop
    # would never terminate.
    for idx in range(len(dataset)):
        item = dataset[idx]
        # imshow_semantic never opens a window when out_file is set, and
        # --output-dir always has a default value, so --show would otherwise
        # have no effect. When showing, do not write files.
        if args.show or args.output_dir is None:
            filename = None
        else:
            filename = os.path.join(args.output_dir,
                                    Path(item['filename']).name)
        imshow_semantic(
            item['img'],
            item['gt_semantic_seg'],
            dataset.CLASSES,
            dataset.PALETTE,
            show=args.show,
            wait_time=args.show_interval,
            out_file=filename,
            opacity=args.opacity,
        )
        progress_bar.update()
|
||
|
||
|
||
if __name__ == '__main__':
    # Run the dataset browser only when executed as a script, not on import.
    main()
|