168 lines
5.3 KiB
Python
168 lines
5.3 KiB
Python
|
import argparse
|
|||
|
import os
|
|||
|
import warnings
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
import mmcv
|
|||
|
import numpy as np
|
|||
|
from mmcv import Config
|
|||
|
|
|||
|
from mmseg.datasets.builder import build_dataset
|
|||
|
|
|||
|
|
|||
|
def parse_args():
    """Parse the command-line arguments of the dataset browsing tool.

    Arguments are read from ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``config``, ``show_origin``, ``skip_type``, ``output_dir``,
            ``show``, ``show_interval`` and ``opacity``.
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--show-origin',
        default=False,
        action='store_true',
        help='if True, omit all augmentation in pipeline,'
        ' show origin image and seg map')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipeline,if `show-origin` is true, '
        'all pipeline except `Load` will be skipped')
    parser.add_argument(
        '--output-dir',
        default='./output',
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='the opacity of semantic map')
    return parser.parse_args()
|
|||
|
|
|||
|
|
|||
|
def imshow_semantic(img,
                    seg,
                    class_names,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
    """Draw the segmentation map `seg` over `img`.

    Args:
        img (str or ndarray): The image to be displayed. It is loaded
            through ``mmcv.imread`` and copied, so the caller's data is
            not modified.
        seg (ndarray): The semantic segmentation map to draw over `img`.
            Its first two dimensions give the size of the color overlay;
            presumably they match the height and width of `img`, since
            the two are blended element-wise (confirm against callers).
        class_names (list[str]): Names of each class. Only its length is
            used, to size and validate the palette.
        palette (list[list[int]] | np.ndarray | None): The palette of the
            segmentation map, one 3-channel color per class. If None is
            given, a random palette will be generated. Default: None.
        win_name (str): The window name. Default: ''.
        show (bool): Whether to show the image. Ignored when `out_file`
            is given. Default: False.
        wait_time (int): Value of waitKey param. Default: 0.
        out_file (str or None): The filename to write the image.
            Default: None.
        opacity (float): Opacity of the painted segmentation map.
            Must be in (0, 1] range. Default: 0.5.

    Returns:
        ndarray or None: The blended image, returned only when neither
            `show` nor `out_file` is set; otherwise None.
    """
    img = mmcv.imread(img)
    img = img.copy()
    if palette is None:
        palette = np.random.randint(0, 255, size=(len(class_names), 3))
    palette = np.array(palette)
    # Check the rank first so that a malformed (e.g. 1-D) palette fails
    # with an AssertionError instead of an IndexError on ``shape[1]``.
    assert len(palette.shape) == 2
    assert palette.shape[0] == len(class_names)
    assert palette.shape[1] == 3
    assert 0 < opacity <= 1.0

    # Pixels whose label matches no palette index keep the zero color.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)

    # if out_file specified, do not show image in window
    if out_file is not None:
        show = False

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    if not (show or out_file):
        warnings.warn('show==False and out_file is not specified, only '
                      'result image will be returned')
        return img
|
|||
|
|
|||
|
|
|||
|
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
    """Filter the ``pipeline`` entry of a dataset config in place.

    Args:
        _data_cfg (dict): Dataset config holding a ``'pipeline'`` list
            whose items each carry a ``'type'`` key. It is modified in
            place.
        skip_type (list[str]): Pipeline step types to drop when
            `show_origin` is false.
        show_origin (bool): If true, keep only the steps whose type
            contains ``'Load'`` and ignore `skip_type`.
    """
    # Read the pipeline by key, the same way it is written back, so the
    # function works with any mapping and not only attribute-style ones.
    pipeline = _data_cfg['pipeline']
    if show_origin:
        # only keep pipeline of Loading data and ann
        _data_cfg['pipeline'] = [
            step for step in pipeline if 'Load' in step['type']
        ]
    else:
        _data_cfg['pipeline'] = [
            step for step in pipeline if step['type'] not in skip_type
        ]
|
|||
|
|
|||
|
|
|||
|
def retrieve_data_cfg(config_path, skip_type, show_origin=False):
    """Load a config file and filter the pipelines of its training data.

    ``cfg.data.train`` may be a single dataset config or a list of them.
    A config that nests another one under ``'dataset'`` has the nested
    config's pipeline filtered.

    Args:
        config_path (str): Path of the config file to load.
        skip_type (list[str]): Pipeline step types to drop.
        show_origin (bool): If true, keep only the loading steps.
            Default: False.

    Returns:
        Config: The loaded config with the training pipelines filtered
            in place.

    Raises:
        ValueError: If an item of a list-style ``cfg.data.train`` has
            neither a ``'pipeline'`` nor a ``'dataset'`` key.
    """
    cfg = Config.fromfile(config_path)
    train_data_cfg = cfg.data.train
    if isinstance(train_data_cfg, list):
        for _data_cfg in train_data_cfg:
            if 'pipeline' in _data_cfg:
                _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
            elif 'dataset' in _data_cfg:
                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
                                   show_origin)
            else:
                raise ValueError(
                    'each item of `data.train` must contain either a '
                    '`pipeline` or a `dataset` key')
    elif 'dataset' in train_data_cfg:
        # Unlike the list case above, a wrapped dataset takes priority
        # here even if the wrapper defines a pipeline of its own.
        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
    else:
        _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
    return cfg
|
|||
|
|
|||
|
|
|||
|
def main():
    """Render every training sample with its segmentation map overlaid.

    Builds the training dataset from the config given on the command
    line, with its pipeline filtered by ``retrieve_data_cfg``, and passes
    each sample to ``imshow_semantic``. When an output directory is set,
    each result is written there under the base name of the sample's
    ``filename``.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))
    for item in dataset:
        # NOTE(review): `--output-dir` defaults to './output', so this is
        # never None when set from the command line; `imshow_semantic`
        # then disables the window, so `--show` appears to have no effect
        # here. Confirm the intended behavior before changing it.
        if args.output_dir is not None:
            filename = os.path.join(args.output_dir,
                                    Path(item['filename']).name)
        else:
            filename = None
        imshow_semantic(
            item['img'],
            item['gt_semantic_seg'],
            dataset.CLASSES,
            dataset.PALETTE,
            show=args.show,
            wait_time=args.show_interval,
            out_file=filename,
            opacity=args.opacity,
        )
        progress_bar.update()
|
|||
|
|
|||
|
|
|||
|
# Run the browser only when the file is executed as a script.
if __name__ == '__main__':
    main()
|