[Feature] Add pipeline visualization tools. (#406)

* add vis

* add tool vis-pipeline

* add docs

* Update docs

* pre-commit

* enhance English expression

* Add `BaseImshowContextmanager` and `ImshowInfosContextManager` to reuse
matplotlib figure.

* Use context manager to implement `imshow_infos`

* Add unit tests.

* More general base context manager.

* unit tests for context manager.

* Improve docstring.

* Fix context manager exit cannot close figure when matplotlib>=3.4.0

* Fix unit tests

* fix lint

* fix lint

* add adaptive

* add adaptive

* update adaptive

* add GAP

* improve doc and docstring

* add visualization in doc index

* Update doc

* Update doc

* Update doc

* Update doc

* Update doc

* Update doc

* update docs and docstring

* add progressbar

* add progressbar

* add images

* add images

* Delete .DS_Store

* replace images

* replace images and modify rgb2bgr

* add picture size

* mv pictures

* update img display

* add doc_zh-CN images

* Update vis_pipeline.py

* Update visualization.md

* Update visualization.md

* fix lint

* Improve docs.

Co-authored-by: mzr1996 <mzr1996@163.com>
pull/400/merge
Ezra-Yu 2021-10-20 10:28:21 +08:00 committed by GitHub
parent 2932f9d8a3
commit 9dbe58bf8e
13 changed files with 430 additions and 0 deletions

docs/_static/image/concat.JPEG (new binary file, mode 100644, 44 KiB)
docs/_static/image/original.JPEG (new binary file, mode 100644, 9.2 KiB)
docs/_static/image/pipeline.JPEG (new binary file, mode 100644, 19 KiB)

@ -34,6 +34,7 @@ You can switch between Chinese and English documents in the lower-left corner of
tools/onnx2tensorrt.md
tools/pytorch2torchscript.md
tools/model_serving.md
tools/visualization.md
.. toctree::

@ -0,0 +1,81 @@
# Visualization
<!-- TOC -->
- [Visualization](#visualization)
- [Pipeline Visualization](#pipeline-visualization)
- [Usage](#usage)
- [FAQs](#faqs)
<!-- TOC -->
## Pipeline Visualization
### Usage
```bash
python tools/visualizations/vis_pipeline.py \
${CONFIG_FILE} \
--output-dir ${OUTPUT_DIR} \
--phase ${DATASET_PHASE} \
--number ${NUMBER_IMAGES_DISPLAY} \
--skip-type ${SKIP_TRANSFORM_TYPE} \
--mode ${DISPLAY_MODE} \
--show \
--adaptive \
--min-edge-length ${MIN_EDGE_LENGTH} \
--max-edge-length ${MAX_EDGE_LENGTH} \
--bgr2rgb \
--window-size ${WINDOW_SIZE}
```
**Description of all arguments**
- `config` : The path of a model config file.
- `--output-dir`: The output path for visualized images. If not specified, it will be set to `''`, which means not to save.
- `--phase`: The phase of the dataset to visualize, must be one of `[train, val, test]`. If not specified, it will be set to `train`.
- `--number`: The number of samples to visualize. If not specified, display all images in the dataset.
- `--skip-type`: The pipelines to be skipped. If not specified, it will be set to `['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']`.
- `--mode`: The display mode, can be one of `[original, pipeline, concat]`. If not specified, it will be set to `concat`.
- `--show`: If set, display pictures in pop-up windows.
- `--adaptive`: If set, automatically adjust the size of the visualization images.
- `--min-edge-length`: The minimum edge length, used when `--adaptive` is set. When any side of the picture is smaller than `${MIN_EDGE_LENGTH}`, the picture will be enlarged while keeping the aspect ratio unchanged, and the short side will be aligned to `${MIN_EDGE_LENGTH}`. If not specified, it will be set to 200.
- `--max-edge-length`: The maximum edge length, used when `--adaptive` is set. When any side of the picture is larger than `${MAX_EDGE_LENGTH}`, the picture will be reduced while keeping the aspect ratio unchanged, and the long side will be aligned to `${MAX_EDGE_LENGTH}`. If not specified, it will be set to 1000.
- `--bgr2rgb`: If set, flip the color channel order of images.
- `--window-size`: The shape of the display window. If not specified, it will be set to `12*7`. If used, it must be in the format `'W*H'`.
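The `'W*H'` format expected by `--window-size` can be checked before any figure is created. The following is a minimal sketch of such a check, not the tool's own code; the helper name `parse_window_size` is hypothetical.

```python
import re


def parse_window_size(text):
    """Parse a 'W*H' string such as '12*7' into a (width, height) tuple.

    Hypothetical helper for illustration; it rejects anything that is not
    exactly two positive integers separated by '*'.
    """
    match = re.fullmatch(r'(\d+)\*(\d+)', text)
    if match is None:
        raise ValueError("'window-size' must be in format 'W*H'.")
    width, height = int(match.group(1)), int(match.group(2))
    if width == 0 or height == 0:
        raise ValueError('window width and height must be positive.')
    return width, height


if __name__ == '__main__':
    print(parse_window_size('12*7'))
```

Using `re.fullmatch` rather than `re.match` means that trailing characters, as in `12*7px`, are rejected instead of being silently accepted.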
```{note}
1. `--mode` defaults to `concat`, which stitches each original picture and its transformed picture side by side; `original` shows only the original pictures, and `pipeline` shows only the transformed pictures.
2. When `--adaptive` is set, pictures that are too large or too small are rescaled automatically; use `--min-edge-length` and `--max-edge-length` to set the size limits.
```
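The rescaling rule described for `--adaptive` can be written down as a scale factor. This is a sketch of the documented behaviour only, assuming the picture is described by its height and width in pixels; the function name `adaptive_scale` is hypothetical and the real tool applies the factor with `mmcv.imrescale`.

```python
def adaptive_scale(height, width, min_edge_length=200, max_edge_length=1000):
    """Return the factor by which a picture would be rescaled.

    Hypothetical helper following the documented rule:
    - if any side is shorter than min_edge_length, enlarge so that the
      short side equals min_edge_length;
    - otherwise, if any side is longer than max_edge_length, shrink so
      that the long side equals max_edge_length;
    - otherwise leave the picture unchanged.
    """
    assert 0 <= min_edge_length <= max_edge_length
    short_side, long_side = min(height, width), max(height, width)
    if short_side < min_edge_length:
        return min_edge_length / short_side
    if long_side > max_edge_length:
        return max_edge_length / long_side
    return 1.0


if __name__ == '__main__':
    # A 32x32 CIFAR picture is enlarged, a 2000x500 picture is reduced.
    print(adaptive_scale(32, 32), adaptive_scale(2000, 500))
```

The aspect ratio is preserved because a single factor is applied to both sides.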
**Examples**
1. Visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows
```shell
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_b32x8_imagenet.py --show --mode pipeline
```
<div align=center><img src="../_static/image/pipeline.JPEG" style=" width: auto; height: 40%; "></div>
2. Visualize 10 comparison pictures in the `ImageNet` train set and save them in the `./tmp` folder
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
```
<div align=center><img src="../_static/image/concat.JPEG" style=" width: auto; height: 40%; "></div>
3. Visualize 100 original pictures in the `CIFAR100` validation set, then display and save them in the `./tmp` folder
```shell
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_b16x8_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
```
<div align=center><img src="../_static/image/original.JPEG" style=" width: auto; height: 40%; "></div>
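The comparison pictures produced by `--mode concat` place the original and the transformed picture on one white board, each centred in its own half. The sketch below only computes that layout from the two picture sizes; it is an illustration under that assumption, the name `concat_layout` is hypothetical, and the offsets may differ by a few pixels from the tool's own arithmetic.

```python
def concat_layout(left_hw, right_hw, gap=10):
    """Compute a side-by-side board layout for two pictures.

    Hypothetical helper. Both arguments are (height, width) tuples.
    Returns the board size as (height, width) and the (top, left) corner
    at which each picture would be pasted.
    """
    left_h, left_w = left_hw
    right_h, right_w = right_hw
    # Each half of the board is large enough for the bigger picture.
    cell_h, cell_w = max(left_h, right_h), max(left_w, right_w)
    board_hw = (cell_h, 2 * cell_w + gap)
    # Centre each picture inside its own cell.
    left_origin = ((cell_h - left_h) // 2, (cell_w - left_w) // 2)
    right_origin = ((cell_h - right_h) // 2,
                    cell_w + gap + (cell_w - right_w) // 2)
    return board_hw, left_origin, right_origin


if __name__ == '__main__':
    print(concat_layout((4, 6), (2, 2)))
```

Because both halves share the same cell size, pictures of different shapes (for example before and after a crop) stay visually aligned.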
## FAQs
- None

@ -142,3 +142,7 @@ More supported backends can be found in [mmcv.fileio.FileClient](https://github.
dict(type='Collect', keys=['img', 'gt_label'])
]
```
## Pipeline visualization
After designing a data pipeline, you can use the [visualization tools](../tools/visualization.md) to check the effect of its transforms on real samples.
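Before pictures are drawn, the visualization tool removes transforms whose output is no longer a viewable image, such as tensor conversion and normalization. The following sketch shows that filtering on a plain list of dicts; the sample pipeline and the helper name `drop_skipped` are illustrative assumptions, while the default skip list is the one used by `--skip-type`.

```python
# Default value of --skip-type in the visualization tool.
DEFAULT_SKIP_TYPES = ['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']


def drop_skipped(pipeline, skip_types=DEFAULT_SKIP_TYPES):
    """Return the pipeline steps whose 'type' is not in skip_types.

    Hypothetical helper; `pipeline` is a list of dicts, as in a config file.
    """
    return [step for step in pipeline if step['type'] not in skip_types]


if __name__ == '__main__':
    # Illustrative pipeline, not taken from a specific config.
    sample_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='RandomResizedCrop', size=224),
        dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img', 'gt_label']),
    ]
    print([step['type'] for step in drop_skipped(sample_pipeline)])
```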

Binary file not shown (44 KiB).
Binary file not shown (9.2 KiB).
Binary file not shown (19 KiB).

@ -34,6 +34,7 @@ You can switch between Chinese and English documents in the lower-left corner of
tools/onnx2tensorrt.md
tools/pytorch2torchscript.md
tools/model_serving.md
tools/visualization.md
.. toctree::

@ -0,0 +1,82 @@
# Visualization
<!-- TOC -->
- [Visualization](#visualization)
- [Data Pipeline Visualization](#data-pipeline-visualization)
- [Usage](#usage)
- [FAQs](#faqs)
<!-- TOC -->
## Data Pipeline Visualization
### Usage
```bash
python tools/visualizations/vis_pipeline.py \
${CONFIG_FILE} \
--output-dir ${OUTPUT_DIR} \
--phase ${DATASET_PHASE} \
--number ${NUMBER_IMAGES_DISPLAY} \
--skip-type ${SKIP_TRANSFORM_TYPE} \
--mode ${DISPLAY_MODE} \
--show \
--adaptive \
--min-edge-length ${MIN_EDGE_LENGTH} \
--max-edge-length ${MAX_EDGE_LENGTH} \
--bgr2rgb \
--window-size ${WINDOW_SIZE}
```
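**Description of all arguments**
- `config` : The path of a model config file.
- `--output-dir`: The folder in which to save the pictures. If not specified, it defaults to `''`, which means nothing is saved.
- `--phase`: The phase of the dataset to visualize, must be one of `[train, val, test]`. Defaults to `train`.
- `--number`: The number of samples to visualize. If not specified, all pictures in the dataset are displayed.
- `--skip-type`: The data pipeline steps to skip. If not specified, it defaults to `['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']`.
- `--mode`: The visualization mode, must be one of `[original, pipeline, concat]`. If not specified, it defaults to `concat`.
- `--show`: Display the visualized pictures in pop-up windows.
- `--adaptive`: Automatically adjust the size of the visualized pictures.
- `--min-edge-length`: The minimum edge length, effective only when `--adaptive` is set. When any side of the picture is smaller than `${MIN_EDGE_LENGTH}`, the picture is enlarged with its aspect ratio unchanged, and the short side is aligned to `${MIN_EDGE_LENGTH}`. Defaults to 200.
- `--max-edge-length`: The maximum edge length, effective only when `--adaptive` is set. When any side of the picture is larger than `${MAX_EDGE_LENGTH}`, the picture is reduced with its aspect ratio unchanged, and the long side is aligned to `${MAX_EDGE_LENGTH}`. Defaults to 1000.
- `--bgr2rgb`: Flip the color channel order of the pictures.
- `--window-size`: The size of the display window. If not specified, it defaults to `12*7`. If specified, it must be in the format `'W*H'`.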
```{note}
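1. If `--mode` is not specified, it defaults to `concat`, which produces pictures stitched from the original picture and the preprocessed picture; if `--mode` is set to `original`, the original pictures are produced; if `--mode` is set to `pipeline`, the preprocessed pictures are produced.
2. When `--adaptive` is set, pictures that are too large or too small are resized automatically; you can use `--min-edge-length` and `--max-edge-length` to set the size limits.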
```
**Examples**
1. Visualize all preprocessed pictures of the `ImageNet` training set and display them in pop-up windows:
```shell
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_b32x8_imagenet.py --show --mode pipeline
```
<div align=center><img src="../_static/image/pipeline.JPEG" style=" width: auto; height: 40%; "></div>
2. Visualize 10 comparison pictures (original next to preprocessed) from the `ImageNet` training set and save them in the `./tmp` folder:
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
```
<div align=center><img src="../_static/image/concat.JPEG" style=" width: auto; height: 40%; "></div>
3. Visualize 100 original pictures from the `CIFAR100` validation set, then display them and save them in the `./tmp` folder:
```shell
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_b16x8_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
```
<div align=center><img src="../_static/image/original.JPEG" style=" width: auto; height: 40%; "></div>
## FAQs
- None

@ -142,3 +142,7 @@ train_pipeline = [
dict(type='Collect', keys=['img', 'gt_label'])
]
```
## Pipeline visualization
After designing the data pipeline, you can use the [visualization tools](../tools/visualization.md) to check its actual effect.

@ -0,0 +1,257 @@
import argparse
import itertools
import os
import re
import sys
from pathlib import Path

import mmcv
import numpy as np
from mmcv import Config, DictAction, ProgressBar

from mmcls.core import visualization as vis
from mmcls.datasets.builder import build_dataset
from mmcls.datasets.pipelines import Compose


def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize a Dataset Pipeline')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='*',
        default=['ToTensor', 'Normalize', 'ImageToTensor', 'Collect'],
        help='the pipelines to skip when visualizing')
    parser.add_argument(
        '--output-dir',
        default='',
        type=str,
        help='folder to save output pictures, if not set, do not save.')
    parser.add_argument(
        '--phase',
        default='train',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to visualize, accept "train" "test" and "val".'
        ' Default train.')
    parser.add_argument(
        '--number',
        type=int,
        default=sys.maxsize,
        help='number of images selected to visualize, must be bigger than 0. '
        'If the number is bigger than the length of the dataset, show all '
        'the images in the dataset; default "sys.maxsize", show all images '
        'in the dataset')
    parser.add_argument(
        '--mode',
        default='concat',
        type=str,
        choices=['original', 'pipeline', 'concat'],
        help='display mode; display original pictures or transformed pictures'
        ' or comparison pictures. "original" means show images loaded from '
        'disk; "pipeline" means to show images after pipeline; "concat" means '
        'show images stitched by "original" and "pipeline" images. Default '
        'concat.')
    parser.add_argument(
        '--show',
        default=False,
        action='store_true',
        help='whether to display images in pop-up window. Default False.')
    parser.add_argument(
        '--adaptive',
        default=False,
        action='store_true',
        help='whether to automatically adjust the visualization image size')
    parser.add_argument(
        '--min-edge-length',
        default=200,
        type=int,
        help='the min edge length when visualizing images, used when '
        '"--adaptive" is true. Default 200.')
    parser.add_argument(
        '--max-edge-length',
        default=1000,
        type=int,
        help='the max edge length when visualizing images, used when '
        '"--adaptive" is true. Default 1000.')
    parser.add_argument(
        '--bgr2rgb',
        default=False,
        action='store_true',
        help='flip the color channel order of images')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--show-options',
        nargs='+',
        action=DictAction,
        help='custom options for display. key-value pair in xxx=yyy. options '
        'in `mmcls.core.visualization.ImshowInfosContextManager.put_img_infos`'
    )
    args = parser.parse_args()

    assert args.number > 0, "'args.number' must be larger than zero."
    if args.window_size != '':
        # fullmatch rejects trailing characters such as '12*7px'
        assert re.fullmatch(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."
    if args.output_dir == '' and not args.show:
        raise ValueError("if '--output-dir' and '--show' are not set, "
                         'nothing will happen when the program is running.')

    if args.show_options is None:
        args.show_options = {}
    return args


def retrieve_data_cfg(config_path, skip_type, cfg_options, phase):
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    data_cfg = cfg.data[phase]
    # unwrap dataset wrappers until the dataset that holds the pipeline
    while 'dataset' in data_cfg:
        data_cfg = data_cfg['dataset']
    data_cfg['pipeline'] = [
        x for x in data_cfg.pipeline if x['type'] not in skip_type
    ]

    return cfg


def build_dataset_pipeline(cfg, phase):
    """Build the dataset and the pipeline from the config.

    If the pipeline starts with 'LoadImageFromFile', keep only that step in
    the dataset and return the remaining steps as a separate pipeline.
    """
    data_cfg = cfg.data[phase]
    loadimage_pipeline = []
    if len(data_cfg.pipeline
           ) != 0 and data_cfg.pipeline[0]['type'] == 'LoadImageFromFile':
        loadimage_pipeline.append(data_cfg.pipeline.pop(0))
    origin_pipeline = data_cfg.pipeline
    data_cfg.pipeline = loadimage_pipeline
    dataset = build_dataset(data_cfg)
    pipeline = Compose(origin_pipeline)

    return dataset, pipeline


def put_img(board, img, center):
    """Put an image into a big board image, centered at the anchor."""
    center_x, center_y = center
    img_h, img_w, _ = img.shape
    xmin, ymin = int(center_x - img_w // 2), int(center_y - img_h // 2)
    board[ymin:ymin + img_h, xmin:xmin + img_w, :] = img
    return board


def concat(left_img, right_img):
    """Concat two pictures into a single big picture, accepts two images with
    different shapes."""
    GAP = 10
    left_h, left_w, _ = left_img.shape
    right_h, right_w, _ = right_img.shape
    # create a big white board with shape (board_h, 2 * board_w + GAP)
    board_h, board_w = max(left_h, right_h), max(left_w, right_w)
    board = np.ones([board_h, 2 * board_w + GAP, 3], np.uint8) * 255

    put_img(board, left_img, (int(board_w // 2), int(board_h // 2)))
    put_img(board, right_img,
            (int(board_w // 2) + board_w + GAP // 2, int(board_h // 2)))
    return board


def adaptive_size(mode, image, min_edge_length, max_edge_length):
    """Rescale the image if it is too small to put text on, like CIFAR."""
    assert min_edge_length >= 0 and max_edge_length >= 0
    assert max_edge_length >= min_edge_length

    image_h, image_w, *_ = image.shape
    # in 'concat' mode the board holds two pictures side by side
    image_w = image_w // 2 if mode == 'concat' else image_w

    if image_h < min_edge_length or image_w < min_edge_length:
        # enlarge so that the short side reaches min_edge_length
        image = mmcv.imrescale(
            image, max(min_edge_length / image_h, min_edge_length / image_w))
    if image_h > max_edge_length or image_w > max_edge_length:
        # shrink so that the long side fits within max_edge_length
        image = mmcv.imrescale(
            image, min(max_edge_length / image_h, max_edge_length / image_w))
    return image


def get_display_img(item, pipeline, mode, bgr2rgb):
    """Get the image to display."""
    if bgr2rgb:
        item['img'] = mmcv.bgr2rgb(item['img'])
    src_image = item['img'].copy()
    # get transformed picture
    if mode in ['pipeline', 'concat']:
        item = pipeline(item)
        trans_image = item['img']
        trans_image = np.ascontiguousarray(trans_image, dtype=np.uint8)

    if mode == 'concat':
        image = concat(src_image, trans_image)
    elif mode == 'original':
        image = src_image
    elif mode == 'pipeline':
        image = trans_image
    return image


def main():
    args = parse_args()
    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                            args.phase)

    dataset, pipeline = build_dataset_pipeline(cfg, args.phase)
    CLASSES = dataset.CLASSES
    display_number = min(args.number, len(dataset))
    progressBar = ProgressBar(display_number)

    with vis.ImshowInfosContextManager(fig_size=(wind_w, wind_h)) as manager:
        for i, item in enumerate(itertools.islice(dataset, display_number)):
            image = get_display_img(item, pipeline, args.mode, args.bgr2rgb)
            if args.adaptive:
                image = adaptive_size(args.mode, image, args.min_edge_length,
                                      args.max_edge_length)

            # dist_path is None by default, which means pictures are not saved
            dist_path = None
            if args.output_dir:
                # some datasets do not have filename, such as cifar, use id
                src_path = item.get('filename', '{}.jpg'.format(i))
                dist_path = os.path.join(args.output_dir, Path(src_path).name)

            infos = dict(label=CLASSES[item['gt_label']])

            manager.put_img_infos(
                image,
                infos,
                font_size=20,
                out_file=dist_path,
                show=args.show,
                **args.show_options)

            progressBar.update()


if __name__ == '__main__':
    main()