mmpretrain/tools/visualization/browse_dataset.py

204 lines
6.9 KiB
Python
Raw Normal View History

2022-05-30 03:11:44 +00:00
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import sys
import mmcv
2022-06-16 13:33:19 +00:00
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization.utils import img_from_canvas
2022-05-30 03:11:44 +00:00
from mmpretrain.datasets.builder import build_dataset
from mmpretrain.visualization import UniversalVisualizer, create_figure
2022-05-30 03:11:44 +00:00
def parse_args():
    """Build the command-line interface of the dataset browser and parse it.

    Returns:
        argparse.Namespace: The options parsed from the command line.
    """
    # Long help texts are named up front so the argument declarations below
    # stay short and easy to scan.
    phase_help = (
        'phase of dataset to visualize, accept "train" "test" and "val".'
        ' Defaults to "train".')
    show_number_help = (
        'number of images selected to visualize, must bigger than 0. if '
        'the number is bigger than length of dataset, show all the images in '
        'dataset; default "sys.maxsize", show all images in dataset')
    mode_help = (
        'display mode; display original pictures or transformed pictures'
        ' or comparison pictures. "original" means show images load from disk'
        '; "transformed" means to show images after transformed; "concat" '
        'means show images stitched by "original" and "output" images. '
        '"pipeline" means show all the intermediate images. '
        'Defaults to "transformed".')
    rescale_help = (
        'image rescale factor, which is useful if the output is too '
        'large or too small.')
    channel_order_help = (
        'The channel order of the showing images, could be "BGR" '
        'or "RGB", Defaults to "BGR".')
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    cli = argparse.ArgumentParser(description='Browse a dataset')

    # Positional argument: the config to read.
    cli.add_argument('config', help='train config file path')

    # Output and display switches.
    cli.add_argument(
        '--output-dir',
        '-o',
        type=str,
        default=None,
        help='If there is no display interface, you can save it.')
    cli.add_argument('--not-show', action='store_true', default=False)

    # Which samples to browse and how long each one is shown.
    cli.add_argument(
        '--phase',
        '-p',
        type=str,
        default='train',
        choices=['train', 'test', 'val'],
        help=phase_help)
    cli.add_argument(
        '--show-number',
        '-n',
        type=int,
        default=sys.maxsize,
        help=show_number_help)
    cli.add_argument(
        '--show-interval',
        '-i',
        type=float,
        default=2,
        help='the interval of show (s)')

    # How each sample is rendered.
    cli.add_argument(
        '--mode',
        '-m',
        type=str,
        default='transformed',
        choices=['original', 'transformed', 'concat', 'pipeline'],
        help=mode_help)
    cli.add_argument('--rescale-factor', '-r', type=float, help=rescale_help)
    cli.add_argument(
        '--channel-order',
        '-c',
        default='BGR',
        choices=['BGR', 'RGB'],
        help=channel_order_help)

    # Free-form overrides merged into the loaded config.
    cli.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help=cfg_options_help)

    return cli.parse_args()
2022-06-16 13:33:19 +00:00
def make_grid(imgs, names, rescale_factor=None):
    """Draw several pictures side by side in one row and return the result.

    Every picture gets its own subplot, titled with its name and the
    ``shape[:2]`` it had before any rescaling.

    Args:
        imgs (list): Pictures to draw; each must expose ``.shape``.
        names (list[str]): Subplot titles, indexed in step with ``imgs``.
        rescale_factor (float, optional): When not None, every picture is
            passed through ``mmcv.imrescale`` with this factor before it is
            drawn. Defaults to None.

    Returns:
        The figure canvas converted by ``img_from_canvas``.
    """
    canvas_fig = create_figure()
    row = canvas_fig.add_gridspec(1, len(imgs))

    # Captured before rescaling so the titles report the source sizes.
    source_shapes = [picture.shape[:2] for picture in imgs]

    if rescale_factor is None:
        to_draw = imgs
    else:
        to_draw = [
            mmcv.imrescale(picture, rescale_factor) for picture in imgs
        ]

    for column, picture in enumerate(to_draw):
        axes = canvas_fig.add_subplot(row[0, column])
        axes.axis(False)
        axes.imshow(picture)
        axes.set_title(f'{names[column]}\n{source_shapes[column]}')

    return img_from_canvas(canvas_fig.canvas)
2022-06-16 13:33:19 +00:00
class InspectCompose(Compose):
    """Compose multiple transforms sequentially.

    And record "img" field of all results in one list.

    Each record is a dict with two keys: ``'name'`` (``'Original'`` for the
    input, otherwise the class name of the transform that just ran) and
    ``'img'`` (a ``.copy()`` of ``data['img']`` at that point).

    Args:
        transforms: The transforms, forwarded unchanged to ``Compose``.
        intermediate_imgs (list): Caller-owned list the records are appended
            to. It is never replaced, only mutated in place.
    """

    def __init__(self, transforms, intermediate_imgs):
        super().__init__(transforms=transforms)
        self.intermediate_imgs = intermediate_imgs

    def _record(self, name, data):
        """Append a snapshot of ``data['img']`` if that key is present."""
        if 'img' in data:
            self.intermediate_imgs.append({
                'name': name,
                'img': data['img'].copy()
            })

    def __call__(self, data):
        """Run all transforms on ``data`` and record the image after each.

        Args:
            data (dict): The sample to transform.

        Returns:
            The transformed sample, or None as soon as any transform
            returns None.
        """
        # Remember where this sample's records start so they can be rolled
        # back if the sample is rejected.
        first_record = len(self.intermediate_imgs)
        self._record('Original', data)

        for t in self.transforms:
            data = t(data)
            if data is None:
                # The sample was rejected. Drop what was recorded for it;
                # otherwise, if the caller goes on to fetch another sample
                # (presumably what the dataset's refetch logic does - TODO
                # confirm), the leftovers would be mixed with that sample's
                # records.
                del self.intermediate_imgs[first_record:]
                return None
            self._record(t.__class__.__name__, data)

        return data
2022-05-30 03:11:44 +00:00
def main():
    """Browse a dataset: build it from the config and visualize its samples.

    The dataset pipeline is wrapped in ``InspectCompose`` so that the image
    after every transform is available, then up to ``--show-number`` samples
    are drawn with ``UniversalVisualizer.visualize_cls``.

    Raises:
        ValueError: If the config has no usable ``<phase>_dataloader`` or
            that dataloader has no ``dataset``.
        RuntimeError: If the pipeline recorded no image for a sample, so
            there is nothing to display.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope('mmpretrain')  # Use mmpretrain as default scope.

    # Resolve the dataset config step by step so a missing entry gives a
    # readable error instead of an AttributeError on None.
    dataloader_key = args.phase + '_dataloader'
    dataloader_cfg = cfg.get(dataloader_key)
    dataset_cfg = None
    if dataloader_cfg is not None:
        dataset_cfg = dataloader_cfg.get('dataset')
    if dataset_cfg is None:
        raise ValueError(
            f'Cannot find a dataset config at "{dataloader_key}.dataset" in '
            f'{args.config}; choose another --phase or fix the config.')
    dataset = build_dataset(dataset_cfg)

    # Swap in the recording pipeline; it fills ``intermediate_imgs`` every
    # time a sample is loaded.
    intermediate_imgs = []
    dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
                                      intermediate_imgs)

    # init visualizer
    cfg.visualizer.pop('type')
    visualizer = UniversalVisualizer(**cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    # init visualization image number
    display_number = min(args.show_number, len(dataset))
    progress_bar = ProgressBar(display_number)

    for i, item in zip(range(display_number), dataset):
        if not intermediate_imgs:
            # Every display mode needs at least one recorded image.
            raise RuntimeError(
                f'No "img" field was recorded for sample {i}; the dataset '
                'pipeline must expose an "img" key to be browsed.')

        rescale_factor = args.rescale_factor
        if args.mode == 'original':
            image = intermediate_imgs[0]['img']
        elif args.mode == 'transformed':
            image = intermediate_imgs[-1]['img']
        elif args.mode == 'concat':
            ori_image = intermediate_imgs[0]['img']
            trans_image = intermediate_imgs[-1]['img']
            image = make_grid([ori_image, trans_image],
                              ['original', 'transformed'], rescale_factor)
            # make_grid already applied the factor; do not rescale twice.
            rescale_factor = None
        else:
            image = make_grid([result['img'] for result in intermediate_imgs],
                              [result['name'] for result in intermediate_imgs],
                              rescale_factor)
            # make_grid already applied the factor; do not rescale twice.
            rescale_factor = None

        # Reset the shared list so the next sample starts from scratch.
        intermediate_imgs.clear()

        data_sample = item['data_samples'].numpy()

        # get filename from dataset or just use index as filename
        if hasattr(item['data_samples'], 'img_path'):
            filename = osp.basename(item['data_samples'].img_path)
        else:
            # some dataset have not image path
            filename = f'{i}.jpg'

        out_file = None
        if args.output_dir is not None:
            out_file = osp.join(args.output_dir, filename)

        # Unless the images are declared RGB, reverse the last axis before
        # drawing (presumably the visualizer expects RGB - TODO confirm).
        if args.channel_order == 'RGB':
            shown_image = image
        else:
            shown_image = image[..., ::-1]

        visualizer.visualize_cls(
            shown_image,
            data_sample,
            rescale_factor=rescale_factor,
            show=not args.not_show,
            wait_time=args.show_interval,
            name=filename,
            out_file=out_file)
        progress_bar.update()
# Script entry point: browse the dataset only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()