[Enhance] Enhance vis-pipeline tool. (#604)

* enhance vis-pipeline add intermediate imgs

* enhance vis-pipeline add intermediate imgs

* improve code of vis-pipeline

* modify docs for vis-pipeline

* Use `mmcv.utils.digit_version` instead of `distutils`

* add size info in the bottom

* perform adaptive-resize before concat

* add warning info

* fix docs

* fix lint

* fix comment

* fix docs

Co-authored-by: mzr1996 <mzr1996@163.com>
Ezra-Yu 2022-03-04 14:40:02 +08:00 committed by GitHub
parent 779a06257c
commit d08c2a148a
10 changed files with 237 additions and 140 deletions

Three binary image files removed (44 KiB, 9.2 KiB, and 19 KiB); previews not shown.

View File

@ -1,4 +1,4 @@
# MISCELLANEOUS
# Miscellaneous
<!-- TOC -->

View File

@ -2,7 +2,6 @@
<!-- TOC -->
- [Visualization](#visualization)
- [Pipeline Visualization](#pipeline-visualization)
- [Learning Rate Schedule Visualization](#learning-rate-schedule-visualization)
- [Class Activation Map Visualization](#class-activation-map-visualization)
@ -14,17 +13,18 @@
```bash
python tools/visualizations/vis_pipeline.py \
${CONFIG_FILE} \
--output-dir ${OUTPUT_DIR} \
--phase ${DATASET_PHASE} \
--number ${NUMBER_IMAGES_DISPLAY} \
--skip-type ${SKIP_TRANSFORM_TYPE} \
--mode ${DISPLAY_MODE} \
--show \
--adaptive \
--min-edge-length ${MIN_EDGE_LENGTH} \
--max-edge-length ${MAX_EDGE_LENGTH} \
--bgr2rgb \
--window-size ${WINDOW_SIZE}
[--output-dir ${OUTPUT_DIR}] \
[--phase ${DATASET_PHASE}] \
[--number ${NUMBER_IMAGES_DISPLAY}] \
[--skip-type ${SKIP_TRANSFORM_TYPE}] \
[--mode ${DISPLAY_MODE}] \
[--show] \
[--adaptive] \
[--min-edge-length ${MIN_EDGE_LENGTH}] \
[--max-edge-length ${MAX_EDGE_LENGTH}] \
[--bgr2rgb] \
[--window-size ${WINDOW_SIZE}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**
@ -32,48 +32,57 @@ python tools/visualizations/vis_pipeline.py \
- `config` : The path of a model config file.
- `--output-dir`: The output path for the visualized images. If not specified, it will be set to `''`, which means the images will not be saved.
- `--phase`: Phase of the dataset to visualize, must be one of `[train, val, test]`. If not specified, it will be set to `train`.
- `--number`: The number of samples to visualize. If not specified, display all images in the dataset.
- `--number`: The number of samples to visualize. If not specified, display all images in the dataset.
- `--skip-type`: The pipelines to be skipped. If not specified, it will be set to `['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']`.
- `--mode`: The display mode, can be one of `[original, pipeline, concat]`. If not specified, it will be set to `concat`.
- `--mode`: The display mode, can be one of `[original, transformed, concat, pipeline]`. If not specified, it will be set to `concat`.
- `--show`: If set, display pictures in pop-up windows.
- `--adaptive`: If set, automatically adjust the size of the visualization images.
- `--adaptive`: If set, adaptively resize images for better visualization.
- `--min-edge-length`: The minimum edge length, used when `--adaptive` is set. When any side of the picture is smaller than `${MIN_EDGE_LENGTH}`, the picture will be enlarged while keeping the aspect ratio unchanged, and the short side will be aligned to `${MIN_EDGE_LENGTH}`. If not specified, it will be set to 200.
- `--max-edge-length`: The maximum edge length, used when `--adaptive` is set. When any side of the picture is larger than `${MAX_EDGE_LENGTH}`, the picture will be reduced while keeping the aspect ratio unchanged, and the long side will be aligned to `${MAX_EDGE_LENGTH}`. If not specified, it will be set to 1000.
- `--bgr2rgb`: If set, flip the color channel order of images.
- `--window-size`: The shape of the display window. If not specified, it will be set to `12*7`. If used, it must be in the format `'W*H'`.
- `--cfg-options` : Modifications to the configuration file, refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/en/latest/tutorials/config.html).
```{note}
1. If the `--mode` is not specified, it will be set to `concat` by default, which shows the original pictures and the transformed pictures stitched together; if `--mode` is set to `original`, only the original pictures are shown; if `--mode` is set to `pipeline`, only the transformed pictures are shown.
1. If the `--mode` is not specified, it will be set to `concat` by default, which shows the original pictures and the transformed pictures stitched together; if `--mode` is set to `original`, only the original pictures are shown; if `--mode` is set to `transformed`, only the transformed pictures are shown; if `--mode` is set to `pipeline`, all the intermediate images produced by the pipeline are shown.
2. When the `--adaptive` option is set, images that are too large or too small are resized automatically; use `--min-edge-length` and `--max-edge-length` to control the target size.
```
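For intuition, here is a minimal sketch of the adaptive-resize rule described above: enlarge so that the short edge reaches the minimum length, shrink so that the long edge does not exceed the maximum, and always keep the aspect ratio. It is an illustration only, not the tool's actual implementation; the function name `adaptive_resize` and the use of plain OpenCV are assumptions.

```python
import cv2
import numpy as np


def adaptive_resize(img: np.ndarray, min_edge: int = 200, max_edge: int = 1000) -> np.ndarray:
    """Illustrative re-implementation of the --adaptive behaviour described above."""
    h, w = img.shape[:2]
    scale = 1.0
    if min(h, w) < min_edge:
        # enlarge so that the short edge is aligned to min_edge
        scale = min_edge / min(h, w)
    elif max(h, w) > max_edge:
        # shrink so that the long edge is aligned to max_edge
        scale = max_edge / max(h, w)
    if scale != 1.0:
        img = cv2.resize(img, (int(w * scale), int(h * scale)),
                         interpolation=cv2.INTER_LINEAR)
    return img
```

With these defaults, a 32×32 CIFAR picture would be scaled up to 200×200, which is why small-image datasets benefit from `--adaptive`.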
**Examples**
1. Visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows
```shell
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode pipeline
```
<div align=center><img src="../_static/image/tools/visualization/pipeline-pipeline.jpg" style=" width: auto; height: 40%; "></div>
2. Visualize 10 comparison pictures in the `ImageNet` train set and save them in the `./tmp` folder
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
```
<div align=center><img src="../_static/image/tools/visualization/pipeline-concat.jpg" style=" width: auto; height: 40%; "></div>
3. Visualize 100 original pictures in the `CIFAR100` validation set, then display and save them in the `./tmp` folder
1. In **'original'** mode, visualize 100 original pictures in the `CIFAR100` validation set, then display and save them in the `./tmp` folder
```shell
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_8xb16_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
```
<div align=center><img src="../_static/image/tools/visualization/pipeline-original.jpg" style=" width: auto; height: 40%; "></div>
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117528-1ec2d918-57f8-4ae4-8ca3-a8d31b602f64.jpg" style=" width: auto; height: 40%; "></div>
2. In **'transformed'** mode, visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows
```shell
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode transformed
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117553-8006a4ba-e2fa-4f53-99bc-42a4b06e413f.jpg" style=" width: auto; height: 40%; "></div>
3. In **'concat'** mode, visualize 10 pairs of original and transformed images for comparison in the `ImageNet` train set and save them in the `./tmp` folder
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128259-0a369991-7716-411d-8c27-c6863e6d76ea.JPEG" style=" width: auto; height: 40%; "></div>
4. In **'pipeline'** mode, visualize all the intermediate pictures in the `ImageNet` train set through the pipeline
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --adaptive --mode pipeline --show
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128201-eb97c2aa-a615-4a81-a649-38db1c315d0e.JPEG" style=" width: auto; height: 40%; "></div>
## Learning Rate Schedule Visualization

Three binary image files removed (44 KiB, 9.2 KiB, and 19 KiB); previews not shown.

View File

@ -2,7 +2,6 @@
<!-- TOC -->
- [Visualization](#可视化)
- [Pipeline Visualization](#数据流水线可视化)
- [Learning Rate Schedule Visualization](#学习率策略可视化)
- [Class Activation Map Visualization](#类别激活图可视化)
@ -15,17 +14,18 @@
```bash
python tools/visualizations/vis_pipeline.py \
${CONFIG_FILE} \
--output-dir ${OUTPUT_DIR} \
--phase ${DATASET_PHASE} \
--number ${NUMBER_IMAGES_DISPLAY} \
--skip-type ${SKIP_TRANSFORM_TYPE} \
--mode ${DISPLAY_MODE} \
--show \
--adaptive \
--min-edge-length ${MIN_EDGE_LENGTH} \
--max-edge-length ${MAX_EDGE_LENGTH} \
--bgr2rgb \
--window-size ${WINDOW_SIZE}
[--output-dir ${OUTPUT_DIR}] \
[--phase ${DATASET_PHASE}] \
[--number ${NUMBER_IMAGES_DISPLAY}] \
[--skip-type ${SKIP_TRANSFORM_TYPE}] \
[--mode ${DISPLAY_MODE}] \
[--show] \
[--adaptive] \
[--min-edge-length ${MIN_EDGE_LENGTH}] \
[--max-edge-length ${MAX_EDGE_LENGTH}] \
[--bgr2rgb] \
[--window-size ${WINDOW_SIZE}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**
@ -35,71 +35,80 @@ python tools/visualizations/vis_pipeline.py \
- `--phase`: Phase of the dataset to visualize, must be one of `[train, val, test]`. Defaults to `train`.
- `--number`: The number of samples to visualize. If not specified, all images in the dataset are displayed.
- `--skip-type`: The pipeline transforms to skip. If not specified, defaults to `['ToTensor', 'Normalize', 'ImageToTensor', 'Collect']`.
- `--mode`: The display mode, must be one of `[original, pipeline, concat]`. If not specified, defaults to `concat`.
- `--mode`: The display mode, must be one of `[original, transformed, concat, pipeline]`. If not specified, defaults to `concat`.
- `--show`: Display the visualization pictures in pop-up windows.
- `--adaptive`: Automatically adjust the size of the visualization images.
- `--min-edge-length`: The minimum edge length, effective only when `--adaptive` is set. When any edge of a picture is smaller than `${MIN_EDGE_LENGTH}`, the picture is enlarged with the aspect ratio kept unchanged and the short edge aligned to `${MIN_EDGE_LENGTH}`. Defaults to 200.
- `--max-edge-length`: The maximum edge length, effective only when `--adaptive` is set. When any edge of a picture is larger than `${MAX_EDGE_LENGTH}`, the picture is shrunk with the aspect ratio kept unchanged and the long edge aligned to `${MAX_EDGE_LENGTH}`. Defaults to 1000.
- `--bgr2rgb`: Flip the color channel order of images.
- `--window-size`: The size of the display window. If not specified, defaults to `12*7`. If specified, it must be in the format `'W*H'`.
- `--cfg-options` : Modifications to the configuration file, refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
```{note}
1. If `--mode` is not specified, it defaults to `concat`, which shows the original pictures and the transformed pictures stitched together; if `--mode` is set to `original`, only the original pictures are shown; if `--mode` is set to `pipeline`, only the transformed pictures are shown.
1. If `--mode` is not specified, it defaults to `concat`, which shows the original pictures and the transformed pictures stitched together; if `--mode` is set to `original`, only the original pictures are shown; if `--mode` is set to `transformed`, only the transformed pictures are shown; if `--mode` is set to `pipeline`, all the intermediate images produced by the data pipeline are shown.
2. When `--adaptive` is set, images that are too large or too small are resized automatically; use `--min-edge-length` and `--max-edge-length` to control the target size.
```
**Examples**
1. Visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows:
```shell
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode pipeline
```
<div align=center><img src="../_static/image/tools/visualization/pipeline-pipeline.jpg" style=" width: auto; height: 40%; "></div>
2. Visualize 10 comparison pictures of original and transformed images from the `ImageNet` training set and save them in the `./tmp` folder:
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
```
<div align=center><img src="../_static/image/tools/visualization/pipeline-concat.jpg" style=" width: auto; height: 40%; "></div>
3. Visualize 100 original pictures of the `CIFAR100` validation set, then display and save them in the `./tmp` folder:
1. In **'original'** mode, visualize 100 original pictures of the `CIFAR100` validation set, then display and save them in the `./tmp` folder:
```shell
python ./tools/visualizations/vis_pipeline.py configs/resnet/resnet50_8xb16_cifar100.py --phase val --output-dir tmp --mode original --number 100 --show --adaptive --bgr2rgb
```
<div align=center><img src="../_static/image/tools/visualization/pipeline-original.jpg" style=" width: auto; height: 40%; "></div>
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117528-1ec2d918-57f8-4ae4-8ca3-a8d31b602f64.jpg" style=" width: auto; height: 40%; "></div>
2. In **'transformed'** mode, visualize all the transformed pictures of the `ImageNet` training set and display them in pop-up windows:
```shell
python ./tools/visualizations/vis_pipeline.py ./configs/resnet/resnet50_8xb32_in1k.py --show --mode transformed
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146117553-8006a4ba-e2fa-4f53-99bc-42a4b06e413f.jpg" style=" width: auto; height: 40%; "></div>
3. In **'concat'** mode, visualize 10 pairs of original and transformed pictures of the `ImageNet` training set for comparison and save them in the `./tmp` folder:
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --output-dir tmp --number 10 --adaptive
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128259-0a369991-7716-411d-8c27-c6863e6d76ea.JPEG" style=" width: auto; height: 40%; "></div>
4. In **'pipeline'** mode, visualize all the intermediate images of the `ImageNet` training set produced by the data pipeline:
```shell
python ./tools/visualizations/vis_pipeline.py configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py --phase train --adaptive --mode pipeline --show
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/146128201-eb97c2aa-a615-4a81-a649-38db1c315d0e.JPEG" style=" width: auto; height: 40%; "></div>
## Learning Rate Schedule Visualization
```bash
python tools/visualizations/vis_lr.py \
${CONFIG_FILE} \
--dataset-size ${Dataset_Size} \
--ngpus ${NUM_GPUs}
--save-path ${SAVE_PATH} \
--title ${TITLE} \
--style ${STYLE} \
--window-size ${WINDOW_SIZE}
--cfg-options
[--dataset-size ${Dataset_Size}] \
[--ngpus ${NUM_GPUs}] \
[--save-path ${SAVE_PATH}] \
[--title ${TITLE}] \
[--style ${STYLE}] \
[--window-size ${WINDOW_SIZE}] \
[--cfg-options ${CFG_OPTIONS}]
```
**Description of all arguments**
- `config` : The path of the model config file.
- `dataset-size` : The size of the dataset. If specified, `build_dataset` is skipped and this value is used as the dataset size; by default it is the size of the dataset built by `build_dataset`.
- `ngpus` : The number of GPUs to use.
- `save-path` : The path to save the visualization figure; not saved by default.
- `title` : The title of the figure; defaults to the config file name.
- `style` : The plot style; defaults to `whitegrid`.
- `window-size`: The size of the display window. If not specified, defaults to `12*7`. If specified, it must be in the format `'W*H'`.
- `cfg-options` : Modifications to the configuration file, refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
- `--dataset-size` : The size of the dataset. If specified, `build_dataset` is skipped and this value is used as the dataset size; by default it is the size of the dataset built by `build_dataset`.
- `--ngpus` : The number of GPUs to use.
- `--save-path` : The path to save the visualization figure; not saved by default.
- `--title` : The title of the figure; defaults to the config file name.
- `--style` : The plot style; defaults to `whitegrid`.
- `--window-size`: The size of the display window. If not specified, defaults to `12*7`. If specified, it must be in the format `'W*H'`.
- `--cfg-options` : Modifications to the configuration file, refer to [Tutorial 1: Learn about Configs](https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html).
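As a rough illustration of what such a plot contains, the sketch below computes and draws a per-iteration learning-rate curve for a simple step decay policy. It is not how `vis_lr.py` works internally; the base learning rate, milestones, and the use of matplotlib here are all assumptions.

```python
import matplotlib.pyplot as plt


def step_lr(base_lr, milestones, gamma, total_iters):
    """Learning rate at every iteration under a simple step decay policy."""
    lrs, lr = [], base_lr
    for it in range(total_iters):
        if it in milestones:
            lr *= gamma  # decay when a milestone iteration is reached
        lrs.append(lr)
    return lrs


# hypothetical schedule: start at 0.1, decay by 10x at iterations 3000 and 6000
plt.plot(step_lr(base_lr=0.1, milestones={3000, 6000}, gamma=0.1, total_iters=9000))
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.title('Step LR schedule (illustrative)')
plt.savefig('lr_schedule.png')
```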
```{note}

View File

@ -1,18 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import itertools
import os
import re
import sys
import warnings
from pathlib import Path
from typing import List
import cv2
import mmcv
import numpy as np
from mmcv import Config, DictAction, ProgressBar
from mmcls.core import visualization as vis
from mmcls.datasets.builder import build_dataset
from mmcls.datasets.pipelines import Compose
from mmcls.datasets.builder import PIPELINES, build_dataset, build_from_cfg
from mmcls.models.utils import to_2tuple
# text style
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
def parse_args():
@ -48,11 +57,12 @@ def parse_args():
'--mode',
default='concat',
type=str,
choices=['original', 'pipeline', 'concat'],
choices=['original', 'transformed', 'concat', 'pipeline'],
help='display mode; display original pictures or transformed pictures'
' or comparison pictures. "original" means show images load from disk;'
' "pipeline" means to show images after pipeline; "concat" means show '
'images stitched by "original" and "pipeline" images. Default concat.')
' or comparison pictures. "original" means show images load from disk'
'; "transformed" means to show images after transformed; "concat" '
'means show images stitched by "original" and "output" images. '
'"pipeline" means show all the intermediate images. Default concat.')
parser.add_argument(
'--show',
default=False,
@ -71,7 +81,7 @@ def parse_args():
'"--adaptive" is true. Default 200.')
parser.add_argument(
'--max-edge-length',
default=1000,
default=800,
type=int,
help='the max edge length when visualizing images, used when '
'"--adaptive" is true. Default 1000.')
@ -130,7 +140,7 @@ def retrieve_data_cfg(config_path, skip_type, cfg_options, phase):
return cfg
def build_dataset_pipeline(cfg, phase):
def build_dataset_pipelines(cfg, phase):
"""build dataset and pipeline from config.
Separate the pipeline except 'LoadImageFromFile' step if
@ -144,43 +154,103 @@ def build_dataset_pipeline(cfg, phase):
origin_pipeline = data_cfg.pipeline
data_cfg.pipeline = loadimage_pipeline
dataset = build_dataset(data_cfg)
pipeline = Compose(origin_pipeline)
pipelines = {
pipeline_cfg['type']: build_from_cfg(pipeline_cfg, PIPELINES)
for pipeline_cfg in origin_pipeline
}
return dataset, pipeline
return dataset, pipelines
def put_img(board, img, center):
"""put a image into a big board image with the anchor center."""
center_x, center_y = center
img_h, img_w, _ = img.shape
xmin, ymin = int(center_x - img_w // 2), int(center_y - img_h // 2)
board[ymin:ymin + img_h, xmin:xmin + img_w, :] = img
def prepare_imgs(args, imgs: List[np.ndarray], steps=None):
"""prepare the showing picture."""
ori_shapes = [img.shape for img in imgs]
# adaptive adjustment to rescale pictures
if args.adaptive:
for i, img in enumerate(imgs):
imgs[i] = adaptive_size(img, args.min_edge_length,
args.max_edge_length)
else:
# if the src image is too large or too small,
# emit a warning suggesting the "--adaptive" flag.
for ori_h, ori_w, _ in ori_shapes:
if (args.min_edge_length > ori_h or args.min_edge_length > ori_w
or args.max_edge_length < ori_h
or args.max_edge_length < ori_w):
msg = red_text
msg += 'The visualization picture is too small or too large to'
msg += ' put text information on it, please add '
msg += bright_style + red_text + white_background
msg += '"--adaptive"'
msg += reset_style + red_text
msg += ' to adaptively rescale the showing pictures'
msg += reset_style
warnings.warn(msg)
if len(imgs) == 1:
return imgs[0]
else:
return concat_imgs(imgs, steps, ori_shapes)
def concat_imgs(imgs, steps, ori_shapes):
"""Concat list of pictures into a single big picture, align height here."""
show_shapes = [img.shape for img in imgs]
show_heights = [shape[0] for shape in show_shapes]
show_widths = [shape[1] for shape in show_shapes]
max_height = max(show_heights)
text_height = 20
font_size = 0.5
pic_horizontal_gap = min(show_widths) // 10
for i, img in enumerate(imgs):
cur_height = show_heights[i]
pad_height = max_height - cur_height
pad_top, pad_bottom = to_2tuple(pad_height // 2)
# handle instance that the pad_height is an odd number
if pad_height % 2 == 1:
pad_top = pad_top + 1
pad_bottom += text_height * 3 # keep pxs to put step information text
pad_left, pad_right = to_2tuple(pic_horizontal_gap)
# make border
img = cv2.copyMakeBorder(
img,
pad_top,
pad_bottom,
pad_left,
pad_right,
cv2.BORDER_CONSTANT,
value=(255, 255, 255))
# put transform phase information in the bottom
imgs[i] = cv2.putText(
img=img,
text=steps[i],
org=(pic_horizontal_gap, max_height + text_height // 2),
fontFace=cv2.FONT_HERSHEY_TRIPLEX,
fontScale=font_size,
color=(255, 0, 0),
lineType=1)
# put image size information in the bottom
imgs[i] = cv2.putText(
img=img,
text=str(ori_shapes[i]),
org=(pic_horizontal_gap, max_height + int(text_height * 1.5)),
fontFace=cv2.FONT_HERSHEY_TRIPLEX,
fontScale=font_size,
color=(255, 0, 0),
lineType=1)
# Height alignment for concatenating
board = np.concatenate(imgs, axis=1)
return board
def concat(left_img, right_img):
"""Concat two pictures into a single big picture, accepts two images with
diffenert shapes."""
GAP = 10
left_h, left_w, _ = left_img.shape
right_h, right_w, _ = right_img.shape
# create a big board to contain images with shape (board_h, board_w*2+10)
board_h, board_w = max(left_h, right_h), max(left_w, right_w)
board = np.ones([board_h, 2 * board_w + GAP, 3], np.uint8) * 255
put_img(board, left_img, (int(board_w // 2), int(board_h // 2)))
put_img(board, right_img,
(int(board_w // 2) + board_w + GAP // 2, int(board_h // 2)))
return board
def adaptive_size(mode, image, min_edge_length, max_edge_length):
"""rescale image if image is too small to put text like cifra."""
def adaptive_size(image, min_edge_length, max_edge_length, src_shape=None):
"""rescale image if image is too small to put text like cifar."""
assert min_edge_length >= 0 and max_edge_length >= 0
assert max_edge_length >= min_edge_length
image_h, image_w, *_ = image.shape
image_w = image_w // 2 if mode == 'concat' else image_w
src_shape = image.shape if src_shape is None else src_shape
image_h, image_w, _ = src_shape
if image_h < min_edge_length or image_w < min_edge_length:
image = mmcv.imrescale(
@ -191,49 +261,58 @@ def adaptive_size(mode, image, min_edge_length, max_edge_length):
return image
def get_display_img(item, pipeline, mode, bgr2rgb):
def get_display_img(args, item, pipelines):
"""get image to display."""
if bgr2rgb:
# src pictures could be in RGB or BGR order due to different backends.
if args.bgr2rgb:
item['img'] = mmcv.bgr2rgb(item['img'])
src_image = item['img'].copy()
# get transformed picture
if mode in ['pipeline', 'concat']:
item = pipeline(item)
trans_image = item['img']
trans_image = np.ascontiguousarray(trans_image, dtype=np.uint8)
pipeline_images = [src_image]
# get intermediate images through pipelines
if args.mode in ['transformed', 'concat', 'pipeline']:
for pipeline in pipelines.values():
item = pipeline(item)
trans_image = copy.deepcopy(item['img'])
trans_image = np.ascontiguousarray(trans_image, dtype=np.uint8)
pipeline_images.append(trans_image)
# concatenate the images to be shown according to the mode
if args.mode == 'original':
image = prepare_imgs(args, [src_image], ['src'])
elif args.mode == 'transformed':
image = prepare_imgs(args, [pipeline_images[-1]], ['transformed'])
elif args.mode == 'concat':
steps = ['src', 'transformed']
image = prepare_imgs(args, [pipeline_images[0], pipeline_images[-1]],
steps)
elif args.mode == 'pipeline':
steps = ['src'] + list(pipelines.keys())
image = prepare_imgs(args, pipeline_images, steps)
if mode == 'concat':
image = concat(src_image, trans_image)
elif mode == 'original':
image = src_image
elif mode == 'pipeline':
image = trans_image
return image
def main():
args = parse_args()
wind_w, wind_h = args.window_size.split('*')
wind_w, wind_h = int(wind_w), int(wind_h)
wind_w, wind_h = int(wind_w), int(wind_h) # showing windows size
cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
args.phase)
dataset, pipeline = build_dataset_pipeline(cfg, args.phase)
dataset, pipelines = build_dataset_pipelines(cfg, args.phase)
CLASSES = dataset.CLASSES
display_number = min(args.number, len(dataset))
progressBar = ProgressBar(display_number)
with vis.ImshowInfosContextManager(fig_size=(wind_w, wind_h)) as manager:
for i, item in enumerate(itertools.islice(dataset, display_number)):
image = get_display_img(item, pipeline, args.mode, args.bgr2rgb)
if args.adaptive:
image = adaptive_size(args.mode, image, args.min_edge_length,
args.max_edge_length)
image = get_display_img(args, item, pipelines)
# dist_path is None as default, means not save pictures
# dist_path is None as default, means not saving pictures
dist_path = None
if args.output_dir:
# some datasets do not have filename, such as cifar, use id
# some datasets don't have filenames, such as cifar
src_path = item.get('filename', '{}.jpg'.format(i))
dist_path = os.path.join(args.output_dir, Path(src_path).name)