[Feature] Support browse_dataset.py to visualize original dataset (#1503)

* update browse dataset

* enhance browse_dataset

* update docs and fix original mode

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
pull/1634/head
Xinyu Wang 2022-12-17 01:04:23 +10:30 committed by GitHub
parent f6da8715b9
commit c38618bf51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 467 additions and 42 deletions

View File

@ -80,6 +80,12 @@ data/icdar2015
└── textdet_train.json
```
Once your dataset has been prepared, you can use the [browse_dataset.py](./useful_tools.md#dataset-visualization-tool) to visualize the dataset and check if the annotations are correct.
```bash
python tools/analysis_tools/browse_dataset.py configs/textdet/_base_/datasets/icdar2015.py
```
## Dataset Configuration
### Single Dataset Training

View File

@ -4,26 +4,91 @@
### Dataset Visualization Tool
MMOCR provides a dataset visualization tool `tools/analysis_tools/browse_datasets.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config and the tool will automatically plots the images transformed by corresponding data pipelines with the GT labels. The following example demonstrates how to use the tool to visualize the training data used by the "DBNet_R50_icdar2015" model.
MMOCR provides a dataset visualization tool `tools/analysis_tools/browse_datasets.py` to help users troubleshoot possible dataset-related problems. You just need to specify the path to the training config (usually stored in `configs/textdet/dbnet/xxx.py`) or the dataset config (usually stored in `configs/textdet/_base_/datasets/xxx.py`), and the tool will automatically plots the transformed (or original) images and labels.
```Bash
# Example: Visualizing the training data used by dbnet_r50dcn_v2_fpnc_1200e_icadr2015
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
#### Usage
```bash
python tools/visualizations/browse_dataset.py \
${CONFIG_FILE} \
[-o, --output-dir ${OUTPUT_DIR}] \
[-p, --phase ${DATASET_PHASE}] \
[-m, --mode ${DISPLAY_MODE}] \
[-t, --task ${DATASET_TASK}] \
[-n, --show-number ${NUMBER_IMAGES_DISPLAY}] \
[-i, --show-interval ${SHOW_INTERRVAL}] \
[--cfg-options ${CFG_OPTIONS}]
```
The visualization results will be like:
| ARGS | Type | Description |
| ------------------- | ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ |
| config | str | (required) Path to the config. |
| -o, --output-dir | str | If GUI is not available, specifying an output path to save the visualization results. |
| -p, --phase | str | Phase of dataset to visualize. Use "train", "test" or "val" if you just want to visualize the default split. It's also possible to be a dataset variable name, which might be useful when a dataset split has multiple variants in the config. |
| -m, --mode | `original`, `transformed`, `pipeline` | Display mode: display original pictures or transformed pictures or comparison pictures. `original` only visualizes the original dataset & annotations; `transformed` shows the resulting images processed through all the transforms; `pipeline` shows all the intermediate images. Defaults to "transformed". |
| -t, --task | `auto`, `textdet`, `textrecog` | Specify the task type of the dataset. If `auto`, the task type will be inferred from the config. If the script is unable to infer the task type, you need to specify it manually. Defaults to `auto`. |
| -n, --show-number | int | The number of samples to visualized. If not specified, display all images in the dataset. |
| -i, --show-interval | float | Interval of visualization (s), defaults to 2. |
| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) |
#### Examples
The following example demonstrates how to use the tool to visualize the training data used by the "DBNet_R50_icdar2015" model.
```Bash
# Example: Visualizing the training data used by dbnet_r50dcn_v2_fpnc_1200e_icadr2015 model
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py
```
By default, the visualization mode is "transformed", and you will see the images & annotations being transformed by the pipeline:
<center class="half">
<img src="https://user-images.githubusercontent.com/24622904/187611542-01e9aa94-fc12-4756-964b-a0e472522a3a.jpg" width="250"/><img src="https://user-images.githubusercontent.com/24622904/187611555-3f5ea616-863d-4538-884f-bccbebc2f7e7.jpg" width="250"/><img src="https://user-images.githubusercontent.com/24622904/187611581-88be3970-fbfe-4f62-8cdf-7a8a7786af29.jpg" width="250"/>
</center>
Based on this tool, users can easily verify if the annotation of a custom dataset is correct. Also, you can verify if the data augmentation strategies are running as you expected by modifying `train_pipeline` in the configuration file. The optional parameters of `browse_dataset.py` are as follows.
If you just want to visualize the original dataset, simply set the mode to "original":
| ARGS | Type | Description |
| --------------- | ----- | ------------------------------------------------------------------------------------- |
| config | str | (required) Path to the config. |
| --output-dir | str | If GUI is not available, specifying an output path to save the visualization results. |
| --show-interval | float | Interval of visualization (s), defaults to 2. |
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m original
```
<div align=center><img src="https://user-images.githubusercontent.com/22607038/206646570-382d0f26-908a-4ab4-b1a7-5cc31fa70c5f.jpg" style=" width: auto; height: 40%; "></div>
Or, to visualize the entire pipeline:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py -m pipeline
```
<div align=center><img src="https://user-images.githubusercontent.com/22607038/206637571-287640c0-1f55-453f-a2fc-9f9734b9593f.jpg" style=" width: auto; height: 40%; "></div>
In addition, users can also visualize the original images and their corresponding labels of the dataset by specifying the path to the dataset config file, for example:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py
```
Some datasets might have multiple variants. For example, the test split of `icdar2015` textrecog dataset has two variants, which the [base dataset config](/configs/textrecog/_base_/datasets/icdar2015.py) defines as follows:
```python
icdar2015_textrecog_test = dict(
ann_file='textrecog_test.json',
# ...
)
icdar2015_1811_textrecog_test = dict(
ann_file='textrecog_test_1811.json',
# ...
)
```
In this case, you can specify the variant name to visualize the corresponding dataset:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textrecog/_base_/datasets/icdar2015.py -p icdar2015_1811_textrecog_test
```
Based on this tool, users can easily verify if the annotation of a custom dataset is correct.
### Offline Evaluation Tool

View File

@ -80,6 +80,12 @@ data/icdar2015
└── textdet_train.json
```
数据准备完毕以后,你也可以通过使用我们提供的数据集浏览工具 [browse_dataset.py](./useful_tools.md#数据集可视化工具) 来可视化数据集的标签是否被正确生成,例如:
```bash
python tools/analysis_tools/browse_dataset.py configs/textdet/_base_/datasets/icdar2015.py
```
## 数据集配置文件
### 单数据集训练

View File

@ -4,10 +4,16 @@
### 数据集可视化工具
MMOCR 提供了数据集可视化工具 `tools/analysis_tools/browse_datasets.py` 以辅助用户排查可能遇到的数据集相关的问题。用户只需要指定所使用的训练配置文件路径该工具即可自动将经过数据流水线data pipeline处理过的图像及其对应的真实标签绘制出来。例如以下命令演示了如何使用该工具对 "DBNet_R50_icdar2015" 模型使用的训练数据进行可视化操作:
```{note}
本工具的中文文档已经过时,请以英文文档为准。如果您有兴趣参与本节中文文档的翻译,欢迎通过 [Issue: Documentation](https://github.com/open-mmlab/mmocr/issues/new?assignees=&labels=docs&template=4-documentation.yml&title=%5BDocs%5D+) 及时告知我们:)
```
MMOCR 提供了数据集可视化工具 `tools/analysis_tools/browse_datasets.py` 以辅助用户排查可能遇到的数据集相关的问题。用户只需要指定所使用的训练配置文件(通常存放在如 `configs/textdet/dbnet/xxx.py` 文件中)或数据集配置(通常存放在 `configs/textdet/_base_/datasets/xxx.py` 文件中路径。该工具将依据输入的配置文件类型自动将经过数据流水线data pipeline处理过的图像及其对应的标签或原始图片及其对应的标签绘制出来。
例如,以下命令演示了如何使用该工具对 "DBNet_R50_icdar2015" 模型使用的经过数据变换的训练数据进行可视化操作:
```Bash
# 示例:可视化 dbnet_r50dcn_v2_fpnc_1200e_icadr2015 使用的训练数据
# 示例:可视化 dbnet_r50dcn_v2_fpnc_1200e_icadr2015 模型使用的训练数据
python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
```
@ -17,6 +23,12 @@ python tools/analysis_tools/browse_dataset.py configs/textdet/dbnet/dbnet_r50dcn
<img src="https://user-images.githubusercontent.com/24622904/187611542-01e9aa94-fc12-4756-964b-a0e472522a3a.jpg" width="250"/><img src="https://user-images.githubusercontent.com/24622904/187611555-3f5ea616-863d-4538-884f-bccbebc2f7e7.jpg" width="250"/><img src="https://user-images.githubusercontent.com/24622904/187611581-88be3970-fbfe-4f62-8cdf-7a8a7786af29.jpg" width="250"/>
</center>
另外,用户也可以通过输入数据集配置文件路径来可视化数据集的原始图像及其对应的标签:
```Bash
python tools/analysis_tools/browse_dataset.py configs/textdet/_base_/datasets/icdar2015.py
```
基于此工具,用户可以方便地验证自定义数据集的标注格式是否正确;也可以通过修改配置文件中的 `train_pipeline` 来验证不同的数据增强策略组合是否符合自己的预期。`browse_dataset.py` 的可选参数如下:
| 参数 | 类型 | 说明 |

View File

@ -258,6 +258,9 @@ class KIELocalVisualizer(BaseLocalVisualizer):
if out_file is not None:
mmcv.imwrite(cat_images[..., ::-1], out_file)
self.set_image(cat_images)
return self.get_image()
def draw_arrows(self,
x_data: Union[np.ndarray, torch.Tensor],
y_data: Union[np.ndarray, torch.Tensor],

View File

@ -167,3 +167,6 @@ class TextDetLocalVisualizer(BaseLocalVisualizer):
if out_file is not None:
mmcv.imwrite(cat_images[..., ::-1], out_file)
self.set_image(cat_images)
return self.get_image()

View File

@ -118,11 +118,13 @@ class TextRecogLocalVisualizer(BaseLocalVisualizer):
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
cat_images = [image]
if draw_gt and data_sample is not None and 'gt_text' in data_sample:
if (draw_gt and data_sample is not None and 'gt_text' in data_sample
and 'item' in data_sample.gt_text):
gt_text = data_sample.gt_text.item
cat_images.append(self._draw_instances(image, gt_text))
if (draw_pred and data_sample is not None
and 'pred_text' in data_sample):
and 'pred_text' in data_sample
and 'item' in data_sample.pred_text):
pred_text = data_sample.pred_text.item
cat_images.append(self._draw_instances(image, pred_text))
cat_images = self._cat_image(cat_images, axis=0)
@ -134,3 +136,6 @@ class TextRecogLocalVisualizer(BaseLocalVisualizer):
if out_file is not None:
mmcv.imwrite(cat_images[..., ::-1], out_file)
self.set_image(cat_images)
return self.get_image()

View File

@ -133,3 +133,6 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
if out_file is not None:
mmcv.imwrite(cat_images[..., ::-1], out_file)
self.set_image(cat_images)
return self.get_image()

View File

@ -1,33 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import sys
from typing import Optional, Tuple
import mmengine
import cv2
import mmcv
import numpy as np
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer
from mmocr.registry import DATASETS, VISUALIZERS
from mmocr.utils import register_all_modules
# TODO: Support for printing the change in key of results
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='Train config file path')
parser.add_argument('config', help='Path to model or dataset config.')
parser.add_argument(
'--phase',
'-p',
default='train',
type=str,
help='Phase of dataset to visualize. Use "train", "test" or "val" if '
"you just want to visualize the default split. It's also possible to "
'be a dataset variable name, which might be useful when a dataset '
'split has multiple variants in the config.')
parser.add_argument(
'--mode',
'-m',
default='transformed',
type=str,
choices=['original', 'transformed', 'pipeline'],
help='Display mode: display original pictures or '
'transformed pictures or comparison pictures. "original" '
'only visualizes the original dataset & annotations; '
'"transformed" shows the resulting images processed through all the '
'transforms; "pipeline" shows all the intermediate images. '
'Defaults to "transformed".')
parser.add_argument(
'--output-dir',
'-o',
default=None,
type=str,
help='If there is no display interface, you can save it')
help='If there is no display interface, you can save it.')
parser.add_argument(
'--task',
'-t',
default='auto',
choices=['auto', 'textdet', 'textrecog'],
type=str,
help='Specify the task type of the dataset. If "auto", the task type '
'will be inferred from the config. If the script is unable to infer '
'the task type, you need to specify it manually. Defaults to "auto".')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--show-number',
'-n',
type=int,
default=sys.maxsize,
help='number of images selected to visualize, '
'must bigger than 0. if the number is bigger than length '
'of dataset, show all the images in dataset; '
'default "sys.maxsize", show all images in dataset')
parser.add_argument(
'--show-interval',
'-i',
type=float,
default=2,
help='The interval of show (s)')
default=3,
help='the interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='Override some settings in the used config, the key-value pair '
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
@ -37,40 +86,313 @@ def parse_args():
return args
def main():
def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear
scales according the short edge length. You can also specify the minimum
scale and the maximum scale to limit the linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_scale (int): The minimum scale. Defaults to 0.3.
max_scale (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
def make_grid(imgs, names):
"""Concat list of pictures into a single big picture, align height here."""
visualizer = Visualizer.get_current_instance()
ori_shapes = [img.shape[:2] for img in imgs]
max_height = int(max(img.shape[0] for img in imgs) * 1.1)
min_width = min(img.shape[1] for img in imgs)
horizontal_gap = min_width // 10
img_scale = _get_adaptive_scale((max_height, min_width))
texts = []
text_positions = []
start_x = 0
for i, img in enumerate(imgs):
pad_height = (max_height - img.shape[0]) // 2
pad_width = horizontal_gap // 2
# make border
imgs[i] = cv2.copyMakeBorder(
img,
pad_height,
max_height - img.shape[0] - pad_height + int(img_scale * 30 * 2),
pad_width,
pad_width,
cv2.BORDER_CONSTANT,
value=(255, 255, 255))
texts.append(f'{"execution: "}{i}\n{names[i]}\n{ori_shapes[i]}')
text_positions.append(
[start_x + img.shape[1] // 2 + pad_width, max_height])
start_x += img.shape[1] + horizontal_gap
display_img = np.concatenate(imgs, axis=1)
visualizer.set_image(display_img)
img_scale = _get_adaptive_scale(display_img.shape[:2])
visualizer.draw_texts(
texts,
positions=np.array(text_positions),
font_sizes=img_scale * 7,
colors='black',
horizontal_alignments='center',
font_families='monospace')
return visualizer.get_image()
class InspectCompose(Compose):
"""Compose multiple transforms sequentially.
And record "img" field of all results in one list.
"""
def __init__(self, transforms, intermediate_imgs):
super().__init__(transforms=transforms)
self.intermediate_imgs = intermediate_imgs
def __call__(self, data):
if 'img' in data:
self.intermediate_imgs.append({
'name': 'original',
'img': data['img'].copy()
})
self.ptransforms = [
self.transforms[i] for i in range(len(self.transforms) - 1)
]
for t in self.ptransforms:
data = t(data)
# Keep the same meta_keys in the PackDetInputs
self.transforms[-1].meta_keys = [key for key in data]
data_sample = self.transforms[-1](data)
if data is None:
return None
if 'img' in data:
self.intermediate_imgs.append({
'name':
t.__class__.__name__,
'dataset_sample':
data_sample['data_samples']
})
return data
def infer_dataset_task(task: str,
dataset_cfg: Config,
var_name: Optional[str] = None) -> str:
"""Try to infer the dataset's task type from the config and the variable
name."""
if task != 'auto':
return task
if dataset_cfg.pipeline is not None:
if dataset_cfg.pipeline[-1].type == 'PackTextDetInputs':
return 'textdet'
elif dataset_cfg.pipeline[-1].type == 'PackTextRecogInputs':
return 'textrecog'
if var_name is not None:
if 'det' in var_name:
return 'textdet'
elif 'rec' in var_name:
return 'textrecog'
raise ValueError(
'Unable to infer the task type from dataset pipeline '
'or variable name. Please specify the task type with --task argument '
'explicitly.')
def obtain_dataset_cfg(cfg: Config, phase: str, mode: str, task: str) -> Tuple:
"""Obtain dataset and visualizer from config. Two modes are supported:
1. Model Config Mode:
In this mode, the input config should be a complete model config, which
includes a dataset within pipeline and a visualizer.
2. Dataset Config Mode:
In this mode, the input config should be a complete dataset config,
which only includes basic dataset information, and it may does not
contain a visualizer and dataset pipeline.
Examples:
Typically, the model config files are stored in
`configs/textdet/dbnet/xxx.py` and should look like:
>>> train_dataloader = dict(
>>> batch_size=16,
>>> num_workers=8,
>>> persistent_workers=True,
>>> sampler=dict(type='DefaultSampler', shuffle=True),
>>> dataset=icdar2015_textdet_train)
while the dataset config files are stored in
`configs/textdet/_base_/datasets/xxx.py` and should be like:
>>> icdar2015_textdet_train = dict(
>>> type='OCRDataset',
>>> data_root=ic15_det_data_root,
>>> ann_file='textdet_train.json',
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=None)
Args:
cfg (Config): Config object.
phase (str): The dataset phase to visualize.
mode (str): Script mode.
task (str): The current task type.
Returns:
Tuple: Tuple of (dataset, visualizer).
"""
default_cfgs = dict(
textdet=dict(
visualizer=dict(
type='TextDetLocalVisualizer',
name='visualizer',
vis_backends=[dict(type='LocalVisBackend')]),
pipeline=[
dict(
type='LoadImageFromFile',
file_client_args=dict(backend='disk'),
color_type='color_ignore_orientation'),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True,
),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape'))
]),
textrecog=dict(
visualizer=dict(
type='TextRecogLocalVisualizer',
name='visualizer',
vis_backends=[dict(type='LocalVisBackend')]),
pipeline=[
dict(
type='LoadImageFromFile',
file_client_args=dict(backend='disk'),
ignore_empty=True,
min_size=2),
dict(type='LoadOCRAnnotations', with_text=True),
dict(
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape',
'valid_ratio'))
]),
)
# Model config mode
dataloader_name = f'{phase}_dataloader'
if dataloader_name in cfg:
dataset = cfg.get(dataloader_name).dataset
visualizer = cfg.visualizer
if mode == 'original':
default_cfg = default_cfgs[infer_dataset_task(task, dataset)]
dataset.pipeline = default_cfg['pipeline']
return dataset, visualizer
# Dataset config mode
for key in cfg.keys():
if key.endswith(phase) and cfg[key]['type'].endswith('Dataset'):
dataset = cfg[key]
default_cfg = default_cfgs[infer_dataset_task(
task, dataset, key.lower())]
visualizer = default_cfg['visualizer']
dataset['pipeline'] = default_cfg['pipeline'] if dataset[
'pipeline'] is None else dataset['pipeline']
return dataset, visualizer
raise ValueError(
f'Unable to find "{phase}_dataloader" or any dataset variable ending '
f'with "{phase}". Please check your config file or --phase argument '
'and try again. More details can be found in the docstring of '
'obtain_dataset_cfg function. Or, you may visit the documentation via '
'https://mmocr.readthedocs.io/en/dev-1.x/user_guides/useful_tools.html#dataset-visualization-tool' # noqa: E501
)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmocr into the registries
# register all modules in mmyolo into the registries
register_all_modules()
dataset = DATASETS.build(cfg.train_dataloader.dataset)
visualizer = VISUALIZERS.build(cfg.visualizer)
dataset_cfg, visualizer_cfg = obtain_dataset_cfg(cfg, args.phase,
args.mode, args.task)
dataset = DATASETS.build(dataset_cfg)
visualizer = VISUALIZERS.build(visualizer_cfg)
visualizer.dataset_meta = dataset.metainfo
progress_bar = mmengine.ProgressBar(len(dataset))
for item in dataset:
img = item['inputs'].permute(1, 2, 0).numpy()
data_sample = item['data_samples'].numpy()
img_path = osp.basename(item['data_samples'].img_path)
intermediate_imgs = []
if dataset_cfg.type == 'ConcatDataset':
for sub_dataset in dataset.datasets:
sub_dataset.pipeline = InspectCompose(
sub_dataset.pipeline.transforms, intermediate_imgs)
else:
dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
intermediate_imgs)
# init visualization image number
assert args.show_number > 0
display_number = min(args.show_number, len(dataset))
progress_bar = ProgressBar(display_number)
# fetching items from dataset is a must for visualization
for i, _ in zip(range(display_number), dataset):
image_i = []
result_i = [result['dataset_sample'] for result in intermediate_imgs]
for k, datasample in enumerate(result_i):
image = datasample.img
image = image[..., [2, 1, 0]] # bgr to rgb
image_show = visualizer.add_datasample(
'result',
image,
datasample,
draw_pred=False,
draw_gt=True,
show=False)
image_i.append(image_show)
if args.mode == 'pipeline':
image = make_grid([result for result in image_i],
[result['name'] for result in intermediate_imgs])
else:
image = image_i[-1]
if hasattr(datasample, 'img_path'):
filename = osp.basename(datasample.img_path)
else:
# some dataset have not image path
filename = f'{i}.jpg'
out_file = osp.join(args.output_dir,
img_path) if args.output_dir is not None else None
filename) if args.output_dir is not None else None
if img.ndim == 3 and img.shape[-1] == 3:
img = img[..., [2, 1, 0]] # bgr to rgb
if out_file is not None:
mmcv.imwrite(image[..., ::-1], out_file)
visualizer.add_datasample(
name=osp.basename(img_path),
image=img,
data_sample=data_sample,
draw_pred=False,
show=not args.not_show,
wait_time=args.show_interval,
out_file=out_file)
if not args.not_show:
visualizer.show(
image, win_name=filename, wait_time=args.show_interval)
intermediate_imgs.clear()
progress_bar.update()