[Feature] Implement the universal visualizer for multiple tasks. (#1404)

* [Feature] Implement the universal visualizer for multiple tasks.

* Update tools

* Improve according to comments.

* Fix tools docs

* Add --test-cfg option and set default collate function.
pull/1418/head
Ma Zerun 2023-03-09 11:36:54 +08:00 committed by GitHub
parent dbf3df21a3
commit 3472ee5d2c
39 changed files with 922 additions and 706 deletions
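
The core change swaps the classification-only `ClsVisualizer` for a task-agnostic `UniversalVisualizer`. For existing configs the migration is a one-line change of the visualizer type; a minimal sketch, assuming the usual `default_runtime.py` layout:

```python
# Before: classification-only visualizer (removed in this commit)
# visualizer = dict(type='ClsVisualizer', vis_backends=[dict(type='LocalVisBackend')])

# After: task-agnostic visualizer, same backends
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)
```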

View File

@ -22,7 +22,7 @@ from rich.table import Table
from mmpretrain.apis import get_model
from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet
from mmpretrain.utils import register_all_modules
from mmpretrain.visualization import ClsVisualizer
from mmpretrain.visualization import UniversalVisualizer
console = Console()
MMCLS_ROOT = Path(__file__).absolute().parents[2]
@ -166,7 +166,7 @@ def show_summary(summary_data, args):
for model_name, summary in summary_data.items():
row = [model_name]
valid = summary['valid']
color = 'green' if valid == 'PASS' else 'red'
color = {'PASS': 'green', 'CUDA OOM': 'yellow'}.get(valid, 'red')
row.append(f'[{color}]{valid}[/{color}]')
if valid == 'PASS':
row.append(str(summary['resolution']))
@ -248,15 +248,19 @@ def main(args):
result = inference(MMCLS_ROOT / config, checkpoint, tmpdir.name,
args, model_name)
result['valid'] = 'PASS'
except Exception:
import traceback
logger.error(f'"{config}" :\n{traceback.format_exc()}')
result = {'valid': 'FAIL'}
except Exception as e:
if 'CUDA out of memory' in str(e):
logger.error(f'"{config}" :\nCUDA out of memory')
result = {'valid': 'CUDA OOM'}
else:
import traceback
logger.error(f'"{config}" :\n{traceback.format_exc()}')
result = {'valid': 'FAIL'}
summary_data[model_name] = result
# show the results
if args.show:
vis = ClsVisualizer.get_instance('valid')
vis = UniversalVisualizer.get_instance('valid')
vis.set_image(mmcv.imread(args.img))
vis.draw_texts(
texts='\n'.join([f'{k}: {v}' for k, v in result.items()]),
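
The benchmark script now reports out-of-memory failures as their own category and colors them in the summary table. A standalone sketch of the categorization and the rich markup it feeds; the `ok`/`oom`/`crash` helpers below are stand-ins for the real inference call:

```python
from rich.console import Console

def categorize(run) -> str:
    """Map an inference attempt to a benchmark status string."""
    try:
        run()
        return 'PASS'
    except Exception as e:
        # OOM is reported separately instead of as a plain failure.
        if 'CUDA out of memory' in str(e):
            return 'CUDA OOM'
        return 'FAIL'

def ok():
    pass

def oom():
    raise RuntimeError('CUDA out of memory')

def crash():
    raise ValueError('shape mismatch')

console = Console()
for run in (ok, oom, crash):
    valid = categorize(run)
    color = {'PASS': 'green', 'CUDA OOM': 'yellow'}.get(valid, 'red')
    console.print(f'[{color}]{valid}[/{color}]')
```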

View File

@ -36,7 +36,7 @@ env_cfg = dict(
# set visualizer
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ClsVisualizer', vis_backends=vis_backends)
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)
# set log level
log_level = 'INFO'

View File

@ -1,46 +0,0 @@
_base_ = '../_base_/default_runtime.py'
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=50,
in_channels=3,
num_stages=4,
out_indices=(3, ),
norm_cfg=dict(type='BN'),
frozen_stages=-1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
test_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)

View File

@ -1,46 +0,0 @@
_base_ = '../_base_/default_runtime.py'
model = dict(
type='ImageClassifier',
backbone=dict(
type='SwinTransformer',
arch='base',
img_size=192,
out_indices=-1,
drop_path_rate=0.1,
stage_cfgs=dict(block_cfgs=dict(window_size=6))),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False))
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
test_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)

View File

@ -1,48 +0,0 @@
_base_ = '../_base_/default_runtime.py'
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
out_indices=-1,
drop_path_rate=0.1,
out_type='featmap',
final_norm=False),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
)
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
test_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)

View File

@ -169,7 +169,7 @@ scalars. By default, the recorded information will be saved at the `vis_data` fo
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
]
@ -181,7 +181,7 @@ For example, to save them to TensorBoard, simply set them as below:
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend'),
@ -193,7 +193,7 @@ Or save them to WandB as below:
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='WandbVisBackend'),

View File

@ -33,6 +33,8 @@ Inference
:template: callable.rst
ImageClassificationInferencer
ImageRetrievalInferencer
FeatureExtractor
.. autosummary::
:toctree: generated

View File

@ -6,8 +6,9 @@
mmpretrain.visualization
===================================
This package includes visualizer components for classification tasks.
This package includes the visualizer and some helper functions for visualization.
ClsVisualizer
Visualizer
-------------
.. autoclass:: ClsVisualizer
.. autoclass:: UniversalVisualizer
:members:

View File

@ -431,7 +431,7 @@ default_hooks = dict(
)
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
)
```
@ -472,7 +472,7 @@ See the {external+mmengine:doc}`MMEngine tutorial <advanced_tutorials/visualizat
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
# Uncomment the below line to save the log and visualization results to TensorBoard.

View File

@ -2,7 +2,7 @@
## Introduction of the CAM visualization tool
MMClassification provides `tools\visualizations\vis_cam.py` tool to visualize class activation map. Please use `pip install "grad-cam>=1.3.6"` command to install [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam).
MMPretrain provides `tools/visualization/vis_cam.py` tool to visualize class activation map. Please use `pip install "grad-cam>=1.3.6"` command to install [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam).
The supported methods are as follows:
@ -18,7 +18,7 @@ The supported methods are as follows:
**Command**
```bash
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
${IMG} \
${CONFIG_FILE} \
${CHECKPOINT} \
@ -73,7 +73,7 @@ For example, the `backbone.layer4[-1]` is the same as `backbone.layer4.2` since
1. Use different methods to visualize CAM for `ResNet50`, the `target-category` is the predicted result by the given checkpoint, using the default `target-layers`.
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/resnet/resnet50_8xb32_in1k.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
@ -88,7 +88,7 @@ For example, the `backbone.layer4[-1]` is the same as `backbone.layer4.2` since
2. Use different `target-category` to get CAM from the same picture. In `ImageNet` dataset, the category 238 is 'Greater Swiss Mountain dog', the category 281 is 'tabby, tabby cat'.
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/cat-dog.png configs/resnet/resnet50_8xb32_in1k.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
--target-layers 'backbone.layer4.2' \
@ -105,7 +105,7 @@ For example, the `backbone.layer4[-1]` is the same as `backbone.layer4.2` since
3. Use `--eigen-smooth` and `--aug-smooth` to improve visual effects.
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/dog.jpg \
configs/mobilenet_v3/mobilenet-v3-large_8xb128_in1k.py \
https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth \
@ -129,12 +129,12 @@ For ViT-like networks, such as ViT, T2T-ViT and Swin-Transformer, the features a
Besides the flattened features, some ViT-like networks also add extra tokens like the class token in ViT and T2T-ViT, and the distillation token in DeiT. In these networks, the final classification is done on the tokens computed in the last attention block, and therefore, the classification score will not be affected by other features and the gradient of the classification score with respect to them, will be zero. Therefore, you shouldn't use the output of the last attention block as the target layer in these networks.
To exclude these extra tokens, we need know the number of extra tokens. Almost all transformer-based backbones in MMClassification have the `num_extra_tokens` attribute. If you want to use this tool in a new or third-party network that don't have the `num_extra_tokens` attribute, please specify it the `--num-extra-tokens` argument.
To exclude these extra tokens, we need to know the number of extra tokens. Almost all transformer-based backbones in MMPretrain have the `num_extra_tokens` attribute. If you want to use this tool on a new or third-party network that doesn't have the `num_extra_tokens` attribute, please specify it with the `--num-extra-tokens` argument.
1. Visualize CAM for `Swin Transformer`, using default `target-layers`:
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/swin_transformer/swin-tiny_16xb64_in1k.py \
https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth \
@ -144,7 +144,7 @@ To exclude these extra tokens, we need know the number of extra tokens. Almost a
2. Visualize CAM for `Vision Transformer(ViT)`:
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/vision_transformer/vit-base-p16_ft-64xb64_in1k-384.py \
https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth \
@ -155,7 +155,7 @@ To exclude these extra tokens, we need know the number of extra tokens. Almost a
3. Visualize CAM for `T2T-ViT`:
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py \
https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth \

View File

@ -3,7 +3,7 @@
## Introduce the dataset visualization tool
```bash
python tools/visualizations/browse_dataset.py \
python tools/visualization/browse_dataset.py \
${CONFIG_FILE} \
[-o, --output-dir ${OUTPUT_DIR}] \
[-p, --phase ${DATASET_PHASE}] \
@ -42,7 +42,7 @@ python tools/visualizations/browse_dataset.py \
In **'original'** mode:
```shell
python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB
python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB
```
- `--phase val`: Visual validation set, can be simplified to `-p val`;
@ -59,7 +59,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16
In **'transformed'** mode:
```shell
python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2
python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/190994696-737b09d9-d0fb-4593-94a2-4487121e0286.JPEG" style=" width: auto; height: 40%; "></div>
@ -69,7 +69,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_
In **'concat'** mode:
```shell
python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat
python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/190995078-3872feb2-d4e2-4727-a21b-7062d52f7d3e.JPEG" style=" width: auto; height: 40%; "></div>
@ -77,7 +77,7 @@ python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-sm
4. In **'pipeline'** mode
```shell
python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline
python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/190995525-fac0220f-6630-4013-b94a-bc6de4fdff7a.JPEG" style=" width: auto; height: 40%; "></div>

View File

@ -1,8 +1,8 @@
# Torchserve Deployment
In order to serve an `MMClassification` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps:
In order to serve an `MMPretrain` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps:
## 1. Convert model from MMClassification to TorchServe
## 1. Convert model from MMPretrain to TorchServe
```shell
python tools/torchserve/mmpretrain2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \

View File

@ -5,7 +5,7 @@ This tool aims to help the user to check the hyper-parameter scheduler of the op
## Introduce the scheduler visualization tool
```bash
python tools/visualizations/vis_scheduler.py \
python tools/visualization/vis_scheduler.py \
${CONFIG_FILE} \
[-p, --parameter ${PARAMETER_NAME}] \
[-d, --dataset-size ${DATASET_SIZE}] \
@ -38,7 +38,7 @@ Loading annotations maybe consume much time, you can directly specify the size o
You can use the following command to plot the step learning rate schedule used in the config `configs/resnet/resnet50_b16x8_cifar100.py`:
```bash
python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py
python tools/visualization/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/191006713-023f065d-d366-4165-a52e-36176367506e.png" style=" width: auto; height: 40%; "></div>
@ -46,7 +46,7 @@ python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar
When using ImageNet, directly specify the size of ImageNet, as below:
```bash
python tools/visualizations/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg
python tools/visualization/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/191006721-0f680e07-355e-4cd6-889c-86c0cad9acb7.png" style=" width: auto; height: 40%; "></div>

View File

@ -1,6 +1,6 @@
# Verify Dataset
In MMClassification, we also provide a tool `tools/misc/verify_dataset.py` to check whether there exists **broken pictures** in the given dataset.
In MMPretrain, we also provide the tool `tools/misc/verify_dataset.py` to check whether there are any **broken images** in the given dataset.
## Introduce the tool

View File

@ -257,7 +257,7 @@ env_cfg = dict(
# set visualizer
vis_backends = [dict(type='LocalVisBackend')] # use local HDD backend
visualizer = dict(
type='ClsVisualizer', vis_backends=vis_backends, name='visualizer')
type='UniversalVisualizer', vis_backends=vis_backends, name='visualizer')
# set log level
log_level = 'INFO'

View File

@ -162,7 +162,7 @@ Visualizer用于记录训练和测试过程中的各种信息包括日志、
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
]
@ -174,7 +174,7 @@ visualizer = dict(
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend'),
@ -186,7 +186,7 @@ visualizer = dict(
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
dict(type='WandbVisBackend'),

View File

@ -419,7 +419,7 @@ default_hooks = dict(
)
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')],
)
```
@ -459,7 +459,7 @@ env_cfg = dict(
```python
visualizer = dict(
type='ClsVisualizer',
type='UniversalVisualizer',
vis_backends=[
dict(type='LocalVisBackend'),
# 将下行取消注释,即可将日志和可视化结果保存至 TesnorBoard
@ -491,13 +491,13 @@ visualizer = dict(
`mmpretrain.core` 包被重命名为 [`mmpretrain.engine`](mmpretrain.engine)
| 子包 | 变动 |
| :-------------: | :-------------------------------------------------------------------------------------------------------------------- |
| `evaluation` | 移除,使用 [`mmpretrain.evaluation`](mmpretrain.evaluation) |
| `hook` | 移动至 [`mmpretrain.engine.hooks`](mmpretrain.engine.hooks) |
| `optimizers` | 移动至 [`mmpretrain.engine.optimizers`](mmpretrain.engine.optimizers) |
| `utils` | 移除,分布式环境相关的函数统一至 [`mmengine.dist`](mmengine.dist) 包 |
| `visualization` | 移除,其中可视化相关的功能被移动至 [`mmpretrain.visualization.ClsVisualizer`](mmpretrain.visualization.ClsVisualizer) |
| 子包 | 变动 |
| :-------------: | :-------------------------------------------------------------------------------------------------------------------------------- |
| `evaluation` | 移除,使用 [`mmpretrain.evaluation`](mmpretrain.evaluation) |
| `hook` | 移动至 [`mmpretrain.engine.hooks`](mmpretrain.engine.hooks) |
| `optimizers` | 移动至 [`mmpretrain.engine.optimizers`](mmpretrain.engine.optimizers) |
| `utils` | 移除,分布式环境相关的函数统一至 [`mmengine.dist`](mmengine.dist) 包 |
| `visualization` | 移除,其中可视化相关的功能被移动至 [`mmpretrain.visualization.UniversalVisualizer`](mmpretrain.visualization.UniversalVisualizer) |
`hooks` 包中的 `MMClsWandbHook` 尚未实现。

View File

@ -2,7 +2,7 @@
## 类别激活图可视化工具介绍
MMClassification 提供 `tools\visualizations\vis_cam.py` 工具来可视化类别激活图。请使用 `pip install "grad-cam>=1.3.6"` 安装依赖的 [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam)。
MMPretrain 提供 `tools/visualization/vis_cam.py` 工具来可视化类别激活图。请使用 `pip install "grad-cam>=1.3.6"` 安装依赖的 [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam)。
目前支持的方法有:
@ -18,7 +18,7 @@ MMClassification 提供 `tools\visualizations\vis_cam.py` 工具来可视化类
**命令行**
```bash
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
${IMG} \
${CONFIG_FILE} \
${CHECKPOINT} \
@ -71,7 +71,7 @@ python tools/visualizations/vis_cam.py \
1. 使用不同方法可视化 `ResNet50`,默认 `target-category` 为模型检测的结果,使用默认推导的 `target-layers`
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/resnet/resnet50_8xb32_in1k.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
@ -86,7 +86,7 @@ python tools/visualizations/vis_cam.py \
2. 同一张图不同类别的激活图效果图,在 `ImageNet` 数据集中类别238为 'Greater Swiss Mountain dog'类别281为 'tabby, tabby cat'。
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/cat-dog.png configs/resnet/resnet50_8xb32_in1k.py \
https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \
--target-layers 'backbone.layer4.2' \
@ -103,7 +103,7 @@ python tools/visualizations/vis_cam.py \
3. 使用 `--eigen-smooth` 以及 `--aug-smooth` 获取更好的可视化效果。
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/dog.jpg \
configs/mobilenet_v3/mobilenet-v3-large_8xb128_in1k.py \
https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth \
@ -127,12 +127,12 @@ python tools/visualizations/vis_cam.py \
除了特征被展平之外,一些类 ViT 的网络还会添加额外的 tokens。比如 ViT 和 T2T-ViT 中添加了分类 tokenDeiT 中还添加了蒸馏 token。在这些网络中分类计算在最后一个注意力模块之后就已经完成了分类得分也只和这些额外的 tokens 有关,与特征图无关,也就是说,分类得分对这些特征图的导数为 0。因此我们不能使用最后一个注意力模块的输出作为 CAM 绘制的目标层。
另外,为了去除这些额外的 toekns 以获得特征图,我们需要知道这些额外 tokens 的数量。MMClassification 中几乎所有 Transformer-based 的网络都拥有 `num_extra_tokens` 属性。而如果你希望将此工具应用于新的,或者第三方的网络,而且该网络没有指定 `num_extra_tokens` 属性,那么可以使用 `--num-extra-tokens` 参数手动指定其数量。
另外,为了去除这些额外的 tokens 以获得特征图,我们需要知道这些额外 tokens 的数量。MMPretrain 中几乎所有 Transformer-based 的网络都拥有 `num_extra_tokens` 属性。而如果你希望将此工具应用于新的,或者第三方的网络,而且该网络没有指定 `num_extra_tokens` 属性,那么可以使用 `--num-extra-tokens` 参数手动指定其数量。
1. 对 `Swin Transformer` 使用默认 `target-layers` 进行 CAM 可视化:
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/swin_transformer/swin-tiny_16xb64_in1k.py \
https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth \
@ -142,7 +142,7 @@ python tools/visualizations/vis_cam.py \
2. 对 `Vision Transformer(ViT)` 进行 CAM 可视化:
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/vision_transformer/vit-base-p16_ft-64xb64_in1k-384.py \
https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth \
@ -153,7 +153,7 @@ python tools/visualizations/vis_cam.py \
3. 对 `T2T-ViT` 进行 CAM 可视化:
```shell
python tools/visualizations/vis_cam.py \
python tools/visualization/vis_cam.py \
demo/bird.JPEG \
configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py \
https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth \

View File

@ -3,7 +3,7 @@
## 数据集可视化工具简介
```bash
python tools/visualizations/browse_dataset.py \
python tools/visualization/browse_dataset.py \
${CONFIG_FILE} \
[-o, --output-dir ${OUTPUT_DIR}] \
[-p, --phase ${DATASET_PHASE}] \
@ -43,7 +43,7 @@ python tools/visualizations/browse_dataset.py \
使用 **'original'** 模式
```shell
python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB
python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB
```
- `--phase val`: 可视化验证集, 可简化为 `-p val`;
@ -60,7 +60,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16
使用 **'transformed'** 模式:
```shell
python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2
python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/190994696-737b09d9-d0fb-4593-94a2-4487121e0286.JPEG" style=" width: auto; height: 40%; "></div>
@ -70,7 +70,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_
使用 **'concat'** 模式:
```shell
python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat
python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/190995078-3872feb2-d4e2-4727-a21b-7062d52f7d3e.JPEG" style=" width: auto; height: 40%; "></div>
@ -78,7 +78,7 @@ python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-sm
使用 **'pipeline'** 模式:
```shell
python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline
python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/190995525-fac0220f-6630-4013-b94a-bc6de4fdff7a.JPEG" style=" width: auto; height: 40%; "></div>

View File

@ -1,8 +1,8 @@
# TorchServe 部署
为了使用 [`TorchServe`](https://pytorch.org/serve/) 部署一个 `MMClassification` 模型,需要进行以下几步:
为了使用 [`TorchServe`](https://pytorch.org/serve/) 部署一个 `MMPretrain` 模型,需要进行以下几步:
## 1. 转换 MMClassification 模型至 TorchServe
## 1. 转换 MMPretrain 模型至 TorchServe
```shell
python tools/torchserve/mmpretrain2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \

View File

@ -5,7 +5,7 @@
## 工具简介
```bash
python tools/visualizations/vis_scheduler.py \
python tools/visualization/vis_scheduler.py \
${CONFIG_FILE} \
[-p, --parameter ${PARAMETER_NAME}] \
[-d, --dataset-size ${DATASET_SIZE}] \
@ -38,7 +38,7 @@ python tools/visualizations/vis_scheduler.py \
你可以使用如下命令来绘制配置文件 `configs/resnet/resnet50_b16x8_cifar100.py` 将会使用的变化率曲线:
```bash
python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py
python tools/visualization/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/191006713-023f065d-d366-4165-a52e-36176367506e.png" style=" width: auto; height: 40%; "></div>
@ -46,7 +46,7 @@ python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar
当数据集为 ImageNet 时,通过直接指定数据集大小来节约时间,并保存图片:
```bash
python tools/visualizations/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg
python tools/visualization/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg
```
<div align=center><img src="https://user-images.githubusercontent.com/18586273/191006721-0f680e07-355e-4cd6-889c-86c0cad9acb7.png" style=" width: auto; height: 40%; "></div>

View File

@ -1,6 +1,6 @@
# 数据集验证
MMClassification中,`tools/misc/verify_dataset.py` 脚本会检查数据集的所有图片,查看是否有**已经损坏**的图片。
MMPretrain 中,`tools/misc/verify_dataset.py` 脚本会检查数据集的所有图片,查看是否有**已经损坏**的图片。
## 工具介绍

View File

@ -250,7 +250,7 @@ env_cfg = dict(
# 设置可视化工具
vis_backends = [dict(type='LocalVisBackend')] # 使用磁盘(HDD)后端
visualizer = dict(
type='ClsVisualizer', vis_backends=vis_backends, name='visualizer')
type='UniversalVisualizer', vis_backends=vis_backends, name='visualizer')
# 设置日志级别
log_level = 'INFO'

View File

@ -82,7 +82,6 @@ class FeatureExtractor(BaseInferencer):
@torch.no_grad()
def forward(self, inputs: Union[dict, tuple], **kwargs):
"""Feed the inputs to the model."""
inputs = self.model.data_preprocessor(inputs, False)['inputs']
outputs = self.model.extract_feat(inputs, **kwargs)
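
With the data preprocessor call dropped from `forward`, the inferencer is assumed to normalize and batch inputs during `preprocess`, so `forward` receives ready-to-use tensors. A hedged usage sketch; the checkpoint name is just a hub model that should exist, and `demo/bird.JPEG` ships with the repo:

```python
from mmpretrain.apis import FeatureExtractor

# Build the extractor from a hub checkpoint; weights download on first use.
extractor = FeatureExtractor('resnet50_8xb32_in1k', device='cpu')

# One entry per input image; each entry holds one tensor per output stage.
feats = extractor('demo/bird.JPEG')[0]
print([f.shape for f in feats])
```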

View File

@ -57,7 +57,8 @@ class ImageClassificationInferencer(BaseInferencer):
""" # noqa: E501
visualize_kwargs: set = {
'rescale_factor', 'draw_score', 'show', 'show_dir'
'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir',
'wait_time'
}
def __init__(
@ -102,6 +103,8 @@ class ImageClassificationInferencer(BaseInferencer):
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
rescale_factor (float, optional): Rescale the image by the rescale
factor for visualization. This is helpful when the image is too
large or too small for visualization. Defaults to None.
@ -109,6 +112,8 @@ class ImageClassificationInferencer(BaseInferencer):
of prediction categories. Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
@ -148,6 +153,8 @@ class ImageClassificationInferencer(BaseInferencer):
ori_inputs: List[InputType],
preds: List[DataSample],
show: bool = False,
wait_time: int = 0,
resize: Optional[int] = None,
rescale_factor: Optional[float] = None,
draw_score=True,
show_dir=None):
@ -155,10 +162,8 @@ class ImageClassificationInferencer(BaseInferencer):
return None
if self.visualizer is None:
from mmpretrain.visualization import ClsVisualizer
self.visualizer = ClsVisualizer()
if self.classes is not None:
self.visualizer._dataset_meta = dict(classes=self.classes)
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
@ -177,15 +182,18 @@ class ImageClassificationInferencer(BaseInferencer):
else:
out_file = None
self.visualizer.add_datasample(
name,
self.visualizer.visualize_cls(
image,
data_sample,
classes=self.classes,
resize=resize,
show=show,
wait_time=wait_time,
rescale_factor=rescale_factor,
draw_gt=False,
draw_pred=True,
draw_score=draw_score,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
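
A hedged usage sketch of the new visualization arguments (`resize`, `wait_time`) from the caller's side; the checkpoint name is illustrative and `./vis` is an arbitrary output directory:

```python
from mmpretrain.apis import ImageClassificationInferencer

inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k', device='cpu')
results = inferencer(
    'demo/bird.JPEG',
    resize=256,        # short edge resized before drawing
    draw_score=True,   # print the top prediction scores on the image
    show_dir='./vis')  # saves <image stem>.png under ./vis
print(results[0]['pred_label'], results[0]['pred_score'])
```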

View File

@ -56,6 +56,9 @@ class ImageRetrievalInferencer(BaseInferencer):
>>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
""" # noqa: E501
visualize_kwargs: set = {
'draw_score', 'resize', 'show_dir', 'show', 'wait_time'
}
postprocess_kwargs: set = {'topk'}
def __init__(
@ -87,6 +90,10 @@ class ImageRetrievalInferencer(BaseInferencer):
self.prototype_dataset = self._prepare_prototype(
prototype, prototype_vecs, prepare_batch_size)
# An ugly hack to escape from the duplicated arguments check in the
# base class
self.visualize_kwargs.add('topk')
def _prepare_prototype(self, prototype, prototype_vecs=None, batch_size=8):
from mmengine.dataset import DefaultSampler
from torch.utils.data import DataLoader
@ -157,13 +164,14 @@ class ImageRetrievalInferencer(BaseInferencer):
return_datasamples (bool): Whether to return results as
:obj:`DataSample`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
rescale_factor (float, optional): Rescale the image by the rescale
factor for visualization. This is helpful when the image is too
large or too small for visualization. Defaults to None.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
resize (int, optional): Resize the long edge of the image to the
specified length before visualization. Defaults to None.
draw_score (bool): Whether to draw the match scores.
Defaults to True.
show (bool): Whether to display the visualization result in a
window. Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
show_dir (str, optional): If not None, save the visualization
results in the specified directory. Defaults to None.
@ -202,13 +210,51 @@ class ImageRetrievalInferencer(BaseInferencer):
def visualize(self,
ori_inputs: List[InputType],
preds: List[DataSample],
topk: int = 3,
resize: Optional[int] = 224,
show: bool = False,
wait_time: int = 0,
draw_score=True,
show_dir=None):
if not show and show_dir is None:
return None
raise NotImplementedError('Not implemented yet.')
if self.visualizer is None:
from mmpretrain.visualization import UniversalVisualizer
self.visualizer = UniversalVisualizer()
visualization = []
for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
image = imread(input_)
if isinstance(input_, str):
# The image loaded from path is BGR format.
image = image[..., ::-1]
name = Path(input_).stem
else:
name = str(i)
if show_dir is not None:
show_dir = Path(show_dir)
show_dir.mkdir(exist_ok=True)
out_file = str((show_dir / name).with_suffix('.png'))
else:
out_file = None
self.visualizer.visualize_image_retrieval(
image,
data_sample,
self.prototype_dataset,
topk=topk,
resize=resize,
draw_score=draw_score,
show=show,
wait_time=wait_time,
name=name,
out_file=out_file)
visualization.append(self.visualizer.get_image())
if show:
self.visualizer.close()
return visualization
def postprocess(
self,
@ -248,3 +294,49 @@ class ImageRetrievalInferencer(BaseInferencer):
List[str]: a list of model names.
"""
return list_models(pattern=pattern, task='Image Retrieval')
def _dispatch_kwargs(self, **kwargs):
"""Dispatch kwargs to preprocess(), forward(), visualize() and
postprocess() according to the actual demands.
Override this method to allow the same argument for different methods.
Returns:
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
forward, visualize and postprocess respectively.
"""
method_kwargs = set.union(
self.preprocess_kwargs,
self.forward_kwargs,
self.visualize_kwargs,
self.postprocess_kwargs,
)
union_kwargs = method_kwargs | set(kwargs.keys())
if union_kwargs != method_kwargs:
unknown_kwargs = union_kwargs - method_kwargs
raise ValueError(
f'unknown argument {unknown_kwargs} for `preprocess`, '
'`forward`, `visualize` and `postprocess`')
preprocess_kwargs = {}
forward_kwargs = {}
visualize_kwargs = {}
postprocess_kwargs = {}
for key, value in kwargs.items():
if key in self.preprocess_kwargs:
preprocess_kwargs[key] = value
if key in self.forward_kwargs:
forward_kwargs[key] = value
if key in self.visualize_kwargs:
visualize_kwargs[key] = value
if key in self.postprocess_kwargs:
postprocess_kwargs[key] = value
return (
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
)
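
A hedged usage sketch of the retrieval inferencer with the new visualization path; the model name and gallery folder are placeholders, and `topk` is the argument shared between `visualize` and `postprocess` through the overridden `_dispatch_kwargs`:

```python
from mmpretrain.apis import ImageRetrievalInferencer

inferencer = ImageRetrievalInferencer(
    'resnet50-arcface_inshop',   # placeholder: any image-retrieval hub model
    prototype='data/gallery/',   # placeholder: folder of gallery images
    device='cpu')
results = inferencer('demo/query.jpg', topk=3, resize=224, show_dir='./vis')
print(results[0])                # top-k matches with their scores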

View File

@ -30,7 +30,7 @@ class VisualizationHook(Hook):
in the testing process. If None, handle with the backends of the
visualizer. Defaults to None.
**kwargs: other keyword arguments of
:meth:`mmpretrain.visualization.ClsVisualizer.add_datasample`.
:meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`.
"""
def __init__(self,
@ -88,11 +88,11 @@ class VisualizationHook(Hook):
draw_args['out_file'] = join_path(self.out_dir,
f'{sample_name}_{step}.png')
self._visualizer.add_datasample(
sample_name,
self._visualizer.visualize_cls(
image=image,
data_sample=data_sample,
step=step,
name=sample_name,
**self.draw_args,
)
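
The hook forwards its extra keyword arguments to `visualize_cls`, so enabling it stays a config-only change. A sketch, assuming the hook sits under `default_hooks` as in the default runtime:

```python
default_hooks = dict(
    visualization=dict(
        type='VisualizationHook',
        enable=True,      # the hook is disabled by default
        interval=10,      # visualize one out of every 10 validation samples
        out_dir=None,     # None: fall back to the visualizer's own backends
        show=False))
```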

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_visualizer import ClsVisualizer
from .utils import create_figure, get_adaptive_scale
from .visualizer import UniversalVisualizer
__all__ = ['ClsVisualizer']
__all__ = ['UniversalVisualizer', 'get_adaptive_scale', 'create_figure']

View File

@ -1,186 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import mmcv
import numpy as np
from mmengine.dist import master_only
from mmengine.visualization import Visualizer
from mmpretrain.registry import VISUALIZERS
from mmpretrain.structures import DataSample
def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear scales
according the short edge length.
You can also specify the minimum scale and the maximum scale to limit the
linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_size (int): The minimum scale. Defaults to 0.3.
max_size (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
"""Universal Visualizer for classification task.
Args:
name (str): Name of the instance. Defaults to 'visualizer'.
image (np.ndarray, optional): the origin image to draw. The format
should be RGB. Defaults to None.
vis_backends (list, optional): Visual backend config list.
Defaults to None.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
fig_save_cfg (dict): Keyword parameters of figure for saving.
Defaults to empty dict.
fig_show_cfg (dict): Keyword parameters of figure for showing.
Defaults to empty dict.
Examples:
>>> import torch
>>> import mmcv
>>> from pathlib import Path
>>> from mmpretrain.visualization import ClsVisualizer
>>> from mmpretrain.structures import DataSample
>>> # Example image
>>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
>>> # Example annotation
>>> data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
... set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
>>> # Setup the visualizer
>>> vis = ClsVisualizer(
... save_dir="./outputs",
... vis_backends=[dict(type='LocalVisBackend')])
>>> # Set classes names
>>> vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
>>> # Show the example image with annotation in a figure.
>>> # And it will ignore all preset storage backends.
>>> vis.add_datasample('res', img, data_sample, show=True)
>>> # Save the visualization result by the specified storage backends.
>>> vis.add_datasample('res', img, data_sample)
>>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
>>> # Save another visualization result with the same name.
>>> vis.add_datasample('res', img, data_sample, step=1)
>>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
"""
@master_only
def add_datasample(self,
name: str,
image: np.ndarray,
data_sample: Optional[DataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
draw_score: bool = True,
rescale_factor: Optional[float] = None,
show: bool = False,
text_cfg: dict = dict(),
wait_time: float = 0,
out_file: Optional[str] = None,
step: int = 0) -> None:
"""Draw datasample and save to all backends.
- If ``out_file`` is specified, all storage backends are ignored
and save the image to the ``out_file``.
- If ``show`` is True, plot the result image in a window, please
confirm you are able to access the graphical interface.
Args:
name (str): The image identifier.
image (np.ndarray): The image to draw.
data_sample (:obj:`DataSample`, optional): The annotation of the
image. Defaults to None.
draw_gt (bool): Whether to draw ground truth labels.
Defaults to True.
draw_pred (bool): Whether to draw prediction labels.
Defaults to True.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
rescale_factor (float, optional): Rescale the image by the rescale
factor before visualization. Defaults to None.
show (bool): Whether to display the drawn image. Defaults to False.
text_cfg (dict): Extra text setting, which accepts
arguments of :attr:`mmengine.Visualizer.draw_texts`.
Defaults to an empty dict.
wait_time (float): The interval of show (s). Defaults to 0, which
means "forever".
out_file (str, optional): Extra path to save the visualization
result. If specified, the visualizer will only save the result
image to the out_file and ignore its storage backends.
Defaults to None.
step (int): Global step value to record. Defaults to 0.
"""
classes = None
if self.dataset_meta is not None:
classes = self.dataset_meta.get('classes', None)
if rescale_factor is not None:
image = mmcv.imrescale(image, rescale_factor)
texts = []
self.set_image(image)
if draw_gt and 'gt_label' in data_sample:
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
prefix = 'Ground truth: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
if draw_pred and 'pred_label' in data_sample:
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {data_sample.pred_score[i].item():.2f}' for i in idx
]
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
img_scale = _get_adaptive_scale(image.shape[:2])
text_cfg = {
'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
'font_sizes': int(img_scale * 7),
'font_families': 'monospace',
'colors': 'white',
'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
**text_cfg
}
self.draw_texts('\n'.join(texts), **text_cfg)
drawn_img = self.get_image()
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
# save the image to the target file instead of vis_backends
mmcv.imwrite(drawn_img[..., ::-1], out_file)
else:
self.add_image(name, drawn_img, step=step)

View File

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
from matplotlib.figure import Figure
def get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the short edge length of the image. If the
short edge length equals 224, the output is 1.0, and the output scales
linearly with the short edge length.
You can also specify the minimum scale and the maximum scale to limit the
linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_scale (float): The minimum scale. Defaults to 0.3.
max_scale (float): The maximum scale. Defaults to 3.0.
Returns:
float: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
def create_figure(*args, margin=False, **kwargs) -> 'Figure':
"""Create a independent figure.
Different from the :func:`plt.figure`, the figure from this function won't
be managed by matplotlib. And it has
:obj:`matplotlib.backends.backend_agg.FigureCanvasAgg`, and therefore, you
can use the ``canvas`` attribute to get access the drawn image.
Args:
*args: All positional arguments of :class:`matplotlib.figure.Figure`.
margin: Whether to reserve the white edges of the figure.
Defaults to False.
**kwargs: All keyword arguments of :class:`matplotlib.figure.Figure`.
Returns:
matplotlib.figure.Figure: The created figure.
"""
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
figure = Figure(*args, **kwargs)
FigureCanvasAgg(figure)
if not margin:
# remove white edges by set subplot margin
figure.subplots_adjust(left=0, right=1, bottom=0, top=1)
return figure
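
A small sketch of how the two helpers compose: draw on an unmanaged figure, then recover the RGB array through the Agg canvas. The random image and text are placeholders:

```python
import numpy as np
from mmengine.visualization.utils import img_from_canvas
from mmpretrain.visualization import create_figure, get_adaptive_scale

image = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
scale = get_adaptive_scale(image.shape[:2])   # 2.0 for a 448-pixel short edge

figure = create_figure(figsize=(4, 4))        # not tracked by pyplot
ax = figure.add_subplot()
ax.axis(False)
ax.imshow(image)
ax.text(5 * scale, 5 * scale, 'demo', size=int(7 * scale), color='white')

drawn = img_from_canvas(figure.canvas)        # HWC uint8 RGB array
print(drawn.shape)
```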

View File

@ -0,0 +1,324 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.dataset import BaseDataset
from mmengine.dist import master_only
from mmengine.visualization import Visualizer
from mmengine.visualization.utils import img_from_canvas
from mmpretrain.registry import VISUALIZERS
from mmpretrain.structures import DataSample
from .utils import create_figure, get_adaptive_scale
@VISUALIZERS.register_module()
class UniversalVisualizer(Visualizer):
"""Universal Visualizer for multiple tasks.
Args:
name (str): Name of the instance. Defaults to 'visualizer'.
image (np.ndarray, optional): the origin image to draw. The format
should be RGB. Defaults to None.
vis_backends (list, optional): Visual backend config list.
Defaults to None.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
fig_save_cfg (dict): Keyword parameters of figure for saving.
Defaults to empty dict.
fig_show_cfg (dict): Keyword parameters of figure for showing.
Defaults to empty dict.
"""
DEFAULT_TEXT_CFG = {
'family': 'monospace',
'color': 'white',
'bbox': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
'verticalalignment': 'top',
'horizontalalignment': 'left',
}
@master_only
def visualize_cls(self,
image: np.ndarray,
data_sample: DataSample,
classes: Optional[Sequence[str]] = None,
draw_gt: bool = True,
draw_pred: bool = True,
draw_score: bool = True,
resize: Optional[int] = None,
rescale_factor: Optional[float] = None,
text_cfg: dict = dict(),
show: bool = False,
wait_time: float = 0,
out_file: Optional[str] = None,
name: str = '',
step: int = 0) -> None:
"""Visualize image classification result.
This method will draw an text box on the input image to visualize the
information about image classification, like the ground-truth label and
prediction label.
Args:
image (np.ndarray): The image to draw. The format should be RGB.
data_sample (:obj:`DataSample`): The annotation of the image.
classes (Sequence[str], optional): The categories names.
Defaults to None.
draw_gt (bool): Whether to draw ground-truth labels.
Defaults to True.
draw_pred (bool): Whether to draw prediction labels.
Defaults to True.
draw_score (bool): Whether to draw the prediction scores
of prediction categories. Defaults to True.
resize (int, optional): Resize the short edge of the image to the
specified length before visualization. Defaults to None.
rescale_factor (float, optional): Rescale the image by the rescale
factor before visualization. Defaults to None.
text_cfg (dict): Extra text setting, which accepts
arguments of :meth:`mmengine.Visualizer.draw_texts`.
Defaults to an empty dict.
show (bool): Whether to display the drawn image in a window. Please
confirm you are able to access the graphical interface.
Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
out_file (str, optional): Extra path to save the visualization
result. If specified, the visualizer will only save the result
image to the out_file and ignore its storage backends.
Defaults to None.
name (str): The image identifier. It's useful when using the
storage backends of the visualizer to save or display the
image. Defaults to an empty string.
step (int): The global step value. It's useful to record a
series of visualization results for the same image with the
storage backends. Defaults to 0.
Returns:
np.ndarray: The visualization image.
"""
if self.dataset_meta is not None:
classes = classes or self.dataset_meta.get('classes', None)
if resize is not None:
h, w = image.shape[:2]
if w < h:
image = mmcv.imresize(image, (resize, resize * h // w))
else:
image = mmcv.imresize(image, (resize * w // h, resize))
elif rescale_factor is not None:
image = mmcv.imrescale(image, rescale_factor)
texts = []
self.set_image(image)
if draw_gt and 'gt_label' in data_sample:
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
prefix = 'Ground truth: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
if draw_pred and 'pred_label' in data_sample:
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {data_sample.pred_score[i].item():.2f}' for i in idx
]
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
img_scale = get_adaptive_scale(image.shape[:2])
text_cfg = {
'size': int(img_scale * 7),
**self.DEFAULT_TEXT_CFG,
**text_cfg,
}
self.ax_save.text(
img_scale * 5,
img_scale * 5,
'\n'.join(texts),
**text_cfg,
)
drawn_img = self.get_image()
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
# save the image to the target file instead of vis_backends
mmcv.imwrite(drawn_img[..., ::-1], out_file)
else:
self.add_image(name, drawn_img, step=step)
return drawn_img
@master_only
def visualize_image_retrieval(self,
image: np.ndarray,
data_sample: DataSample,
prototype_dataset: BaseDataset,
topk: int = 1,
draw_score: bool = True,
resize: Optional[int] = None,
text_cfg: dict = dict(),
show: bool = False,
wait_time: float = 0,
out_file: Optional[str] = None,
name: Optional[str] = '',
step: int = 0) -> None:
"""Visualize image retrieval result.
This method will draw the input image and the images retrieved from the
prototype dataset.
Args:
image (np.ndarray): The image to draw. The format should be RGB.
data_sample (:obj:`DataSample`): The annotation of the image.
prototype_dataset (:obj:`BaseDataset`): The prototype dataset.
It should have a `get_data_info` method that returns a dict
including `img_path`.
topk (int): Draw the top-k matched images. Defaults to 1.
draw_score (bool): Whether to draw the match scores of the
retrieved images. Defaults to True.
resize (int, optional): Resize the long edge of the image to the
specified length before visualization. Defaults to None.
text_cfg (dict): Extra text setting, which accepts arguments of
:func:`plt.text`. Defaults to an empty dict.
show (bool): Whether to display the drawn image in a window. Please
confirm you are able to access the graphical interface.
Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
out_file (str, optional): Extra path to save the visualization
result. If specified, the visualizer will only save the result
image to the out_file and ignore its storage backends.
Defaults to None.
name (str): The image identifier. It's useful when using the
storage backends of the visualizer to save or display the
image. Defaults to an empty string.
step (int): The global step value. It's useful to record a
series of visualization results for the same image with the
storage backends. Defaults to 0.
Returns:
np.ndarray: The visualization image.
"""
text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg}
if resize is not None:
image = mmcv.imrescale(image, (resize, resize))
match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
figure = create_figure(margin=True)
gs = figure.add_gridspec(2, topk)
query_plot = figure.add_subplot(gs[0, :])
query_plot.axis(False)
query_plot.imshow(image)
for k, (score, sample_idx) in enumerate(zip(match_scores, indices)):
sample = prototype_dataset.get_data_info(sample_idx.item())
value_image = mmcv.imread(sample['img_path'])[..., ::-1]
value_plot = figure.add_subplot(gs[1, k])
value_plot.axis(False)
value_plot.imshow(value_image)
if draw_score:
value_plot.text(
5,
5,
f'{score:.2f}',
**text_cfg,
)
drawn_img = img_from_canvas(figure.canvas)
self.set_image(drawn_img)
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
# save the image to the target file instead of vis_backends
mmcv.imwrite(drawn_img[..., ::-1], out_file)
else:
self.add_image(name, drawn_img, step=step)
return drawn_img
@master_only
def visualize_masked_image(self,
image: np.ndarray,
data_sample: DataSample,
resize: Union[int, Tuple[int]] = 224,
color: Union[str, Tuple[int]] = 'black',
alpha: Union[int, float] = 0.8,
show: bool = False,
wait_time: float = 0,
out_file: Optional[str] = None,
name: str = '',
step: int = 0) -> None:
"""Visualize masked image.
This method will draw an image with its binary mask.
Args:
image (np.ndarray): The image to draw. The format should be RGB.
data_sample (:obj:`DataSample`): The annotation of the image.
resize (int | Tuple[int]): Resize the input image to the specified
shape. Defaults to 224.
color (str | Tuple[int]): The color of the binary mask.
Defaults to "black".
alpha (int | float): The transparency of the mask. Defaults to 0.8.
show (bool): Whether to display the drawn image in a window. Please
confirm you are able to access the graphical interface.
Defaults to False.
wait_time (float): The display time (s). Defaults to 0, which means
"forever".
out_file (str, optional): Extra path to save the visualization
result. If specified, the visualizer will only save the result
image to the out_file and ignore its storage backends.
Defaults to None.
name (str): The image identifier. It's useful when using the
storage backends of the visualizer to save or display the
image. Defaults to an empty string.
step (int): The global step value. It's useful to record a
series of visualization results for the same image with the
storage backends. Defaults to 0.
Returns:
np.ndarray: The visualization image.
"""
if isinstance(resize, int):
resize = (resize, resize)
image = mmcv.imresize(image, resize)
self.set_image(image)
mask = data_sample.mask.float()[None, None, ...]
mask_ = F.interpolate(mask, image.shape[:2], mode='nearest')[0, 0]
self.draw_binary_masks(mask_.bool(), colors=color, alphas=alpha)
drawn_img = self.get_image()
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
# save the image to the target file instead of vis_backends
mmcv.imwrite(drawn_img[..., ::-1], out_file)
else:
self.add_image(name, drawn_img, step=step)
return drawn_img
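
A standalone usage sketch of the new task-specific interface; the image, labels, and class names are synthetic, and `out_file` bypasses the storage backends:

```python
import numpy as np
import torch
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer

vis = UniversalVisualizer()
image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
data_sample = DataSample().set_gt_label(1).set_pred_label(1).set_pred_score(
    torch.tensor([0.1, 0.8, 0.1]))

vis.visualize_cls(
    image,
    data_sample,
    classes=['cat', 'bird', 'dog'],
    rescale_factor=2.,       # enlarge the image before drawing the text box
    out_file='vis_cls.png')  # write directly instead of using vis_backends
```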

View File

@ -10,7 +10,7 @@ from mmpretrain.apis import (ImageClassificationInferencer, ModelHub,
get_model, inference_model)
from mmpretrain.models import MobileNetV3
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer
from mmpretrain.visualization import UniversalVisualizer
MODEL = 'mobilenet-v3-small-050_3rdparty_in1k'
WEIGHT = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/mobilenet-v3-small-050_3rdparty_in1k_20221114-e0b86be1.pth' # noqa: E501
@ -69,22 +69,25 @@ class TestImageClassificationInferencer(TestCase):
with TemporaryDirectory() as tmpdir:
inferencer(img, show_dir=tmpdir)
self.assertIsInstance(inferencer.visualizer, ClsVisualizer)
self.assertIsInstance(inferencer.visualizer, UniversalVisualizer)
self.assertTrue(osp.exists(osp.join(tmpdir, '0.png')))
inferencer.visualizer = MagicMock(wraps=inferencer.visualizer)
inferencer(
img_path, rescale_factor=2., draw_score=False, show_dir=tmpdir)
self.assertTrue(osp.exists(osp.join(tmpdir, 'color.png')))
inferencer.visualizer.add_datasample.assert_called_once_with(
'color',
inferencer.visualizer.visualize_cls.assert_called_once_with(
ANY,
ANY,
classes=inferencer.classes,
resize=None,
show=False,
wait_time=0,
rescale_factor=2.,
draw_gt=False,
draw_pred=True,
draw_score=False,
name='color',
out_file=osp.join(tmpdir, 'color.png'))

View File

@ -10,13 +10,13 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
from mmpretrain.engine import VisualizationHook
from mmpretrain.registry import HOOKS
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer
from mmpretrain.visualization import UniversalVisualizer
class TestVisualizationHook(TestCase):
def setUp(self) -> None:
ClsVisualizer.get_instance('visualizer')
UniversalVisualizer.get_instance('visualizer')
data_sample = DataSample().set_gt_label(1).set_pred_label(2)
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
@ -33,62 +33,68 @@ class TestVisualizationHook(TestCase):
# test enable=False
cfg = dict(type='VisualizationHook', enable=False)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook._draw_samples(1, self.data_batch, self.outputs, step=1)
mock.assert_not_called()
# test enable=True
cfg = dict(type='VisualizationHook', enable=True, show=True)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=0,
show=True)
show=True,
name='color.jpg')
# test samples without path
cfg = dict(type='VisualizationHook', enable=True)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
outputs = [DataSample()] * 10
hook._draw_samples(0, self.data_batch, outputs, step=0)
mock.assert_called_once_with(
'0', image=ANY, data_sample=outputs[0], step=0, show=False)
image=ANY,
data_sample=outputs[0],
step=0,
show=False,
name='0')
# test out_dir
cfg = dict(
type='VisualizationHook', enable=True, out_dir=self.tmpdir.name)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook._draw_samples(0, self.data_batch, self.outputs, step=0)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=0,
show=False,
name='color.jpg',
out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png'))
# test sample idx
cfg = dict(type='VisualizationHook', enable=True, interval=4)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook._draw_samples(1, self.data_batch, self.outputs, step=0)
mock.assert_called_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[2],
step=0,
show=False)
show=False,
name='color.jpg',
)
mock.assert_called_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[6],
step=0,
show=False)
show=False,
name='color.jpg',
)
def test_after_val_iter(self):
runner = MagicMock()
@ -98,42 +104,45 @@ class TestVisualizationHook(TestCase):
runner.epoch = 5
cfg = dict(type='VisualizationHook', enable=True)
hook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=5,
show=False)
show=False,
name='color.jpg',
)
# test iter-based
runner.train_loop = MagicMock(spec=IterBasedTrainLoop)
runner.iter = 300
cfg = dict(type='VisualizationHook', enable=True)
hook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=300,
show=False)
show=False,
name='color.jpg',
)
def test_after_test_iter(self):
runner = MagicMock()
cfg = dict(type='VisualizationHook', enable=True)
hook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
with patch.object(hook._visualizer, 'visualize_cls') as mock:
hook.after_test_iter(runner, 0, self.data_batch, self.outputs)
mock.assert_called_once_with(
'color.jpg',
image=ANY,
data_sample=self.outputs[0],
step=0,
show=False)
show=False,
name='color.jpg',
)
def tearDown(self) -> None:
self.tmpdir.cleanup()
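For reference, one plausible way to enable the hook in a config, mirroring the options exercised by these tests; it assumes the standard default_hooks.visualization slot from the default runtime config, and out_dir is a placeholder:

# Draw every 4th validation/test sample and save the drawn images to out_dir.
default_hooks = dict(
    visualization=dict(
        type='VisualizationHook',
        enable=True,
        interval=4,
        out_dir='./vis_results',  # placeholder directory
    ))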

View File

@ -0,0 +1,200 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
import torch
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer
class TestUniversalVisualizer(TestCase):
def setUp(self) -> None:
super().setUp()
tmpdir = tempfile.TemporaryDirectory()
self.tmpdir = tmpdir
self.vis = UniversalVisualizer(
save_dir=tmpdir.name,
vis_backends=[dict(type='LocalVisBackend')],
)
def test_visualize_cls(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
# Test show
def mock_show(drawn_img, win_name, wait_time):
self.assertFalse((image == drawn_img).all())
self.assertEqual(win_name, 'test_cls')
self.assertEqual(wait_time, 0)
with patch.object(self.vis, 'show', mock_show):
self.vis.visualize_cls(
image=image,
data_sample=data_sample,
show=True,
name='test_cls',
step=1)
# Test storage backend.
save_file = osp.join(self.tmpdir.name,
'vis_data/vis_image/test_cls_1.png')
self.assertTrue(osp.exists(save_file))
# Test out_file
out_file = osp.join(self.tmpdir.name, 'results.png')
self.vis.visualize_cls(
image=image, data_sample=data_sample, out_file=out_file)
self.assertTrue(osp.exists(out_file))
# Test with dataset_meta
self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
def patch_texts(text, *_, **__):
self.assertEqual(
text, '\n'.join([
'Ground truth: 1 (bird)',
'Prediction: 1, 0.80 (bird)',
]))
with patch.object(self.vis, 'draw_texts', patch_texts):
self.vis.visualize_cls(image, data_sample)
# Test without pred_label
def patch_texts(text, *_, **__):
self.assertEqual(text, '\n'.join([
'Ground truth: 1 (bird)',
]))
with patch.object(self.vis, 'draw_texts', patch_texts):
self.vis.visualize_cls(image, data_sample, draw_pred=False)
# Test without gt_label
def patch_texts(text, *_, **__):
self.assertEqual(text, '\n'.join([
'Prediction: 1, 0.80 (bird)',
]))
with patch.object(self.vis, 'draw_texts', patch_texts):
self.vis.visualize_cls(image, data_sample, draw_gt=False)
# Test without score
del data_sample.pred_score
def patch_texts(text, *_, **__):
self.assertEqual(
text, '\n'.join([
'Ground truth: 1 (bird)',
'Prediction: 1 (bird)',
]))
with patch.object(self.vis, 'draw_texts', patch_texts):
self.vis.visualize_cls(image, data_sample)
# Test adaptive font size
def assert_font_size(target_size):
def draw_texts(text, font_sizes, *_, **__):
self.assertEqual(font_sizes, target_size)
return draw_texts
with patch.object(self.vis, 'draw_texts', assert_font_size(7)):
self.vis.visualize_cls(
np.ones((224, 384, 3), np.uint8), data_sample)
with patch.object(self.vis, 'draw_texts', assert_font_size(2)):
self.vis.visualize_cls(
np.ones((10, 384, 3), np.uint8), data_sample)
with patch.object(self.vis, 'draw_texts', assert_font_size(21)):
self.vis.visualize_cls(
np.ones((1000, 1000, 3), np.uint8), data_sample)
# Test rescale image
with patch.object(self.vis, 'draw_texts', assert_font_size(14)):
self.vis.visualize_cls(
np.ones((224, 384, 3), np.uint8),
data_sample,
rescale_factor=2.)
def test_visualize_image_retrieval(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])
class ToyPrototype:
def get_data_info(self, idx):
img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
return {'img_path': img_path, 'sample_idx': idx}
prototype_dataset = ToyPrototype()
# Test show
def mock_show(drawn_img, win_name, wait_time):
if image.shape == drawn_img.shape:
self.assertFalse((image == drawn_img).all())
self.assertEqual(win_name, 'test_retrieval')
self.assertEqual(wait_time, 0)
with patch.object(self.vis, 'show', mock_show):
self.vis.visualize_image_retrieval(
image,
data_sample,
prototype_dataset,
show=True,
name='test_retrieval',
step=1)
# Test storage backend.
save_file = osp.join(self.tmpdir.name,
'vis_data/vis_image/test_retrieval_1.png')
self.assertTrue(osp.exists(save_file))
# Test out_file
out_file = osp.join(self.tmpdir.name, 'results.png')
self.vis.visualize_image_retrieval(
image,
data_sample,
prototype_dataset,
out_file=out_file,
)
self.assertTrue(osp.exists(out_file))
def test_visualize_masked_image(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_mask(
torch.tensor([
[0, 0, 1, 1],
[0, 1, 1, 0],
[1, 1, 0, 0],
[1, 0, 0, 1],
]))
# Test show
def mock_show(drawn_img, win_name, wait_time):
self.assertTupleEqual(drawn_img.shape, (224, 224, 3))
self.assertEqual(win_name, 'test_mask')
self.assertEqual(wait_time, 0)
with patch.object(self.vis, 'show', mock_show):
self.vis.visualize_masked_image(
image, data_sample, show=True, name='test_mask', step=1)
# Test storage backend.
save_file = osp.join(self.tmpdir.name,
'vis_data/vis_image/test_mask_1.png')
self.assertTrue(osp.exists(save_file))
# Test out_file
out_file = osp.join(self.tmpdir.name, 'results.png')
self.vis.visualize_masked_image(image, data_sample, out_file=out_file)
self.assertTrue(osp.exists(out_file))
def tearDown(self):
self.tmpdir.cleanup()
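Outside the test harness, the classification branch can be used directly as sketched below (values mirror this test; the save directory is a placeholder):

import numpy as np
import torch

from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer

vis = UniversalVisualizer(
    save_dir='./vis_out', vis_backends=[dict(type='LocalVisBackend')])
vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}

image = np.ones((224, 224, 3), np.uint8)
data_sample = DataSample().set_gt_label(1).set_pred_label(1) \
    .set_pred_score(torch.tensor([0.1, 0.8, 0.1]))

# Draws "Ground truth: 1 (bird)" and "Prediction: 1, 0.80 (bird)" on the
# image and records it through the configured vis backends.
drawn = vis.visualize_cls(image, data_sample, name='demo', step=0)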

View File

@ -1,133 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch
import numpy as np
import torch
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer
class TestClsVisualizer(TestCase):
def setUp(self) -> None:
super().setUp()
tmpdir = tempfile.TemporaryDirectory()
self.tmpdir = tmpdir
self.vis = ClsVisualizer(
save_dir=tmpdir.name,
vis_backends=[dict(type='LocalVisBackend')],
)
def test_add_datasample(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
# Test show
def mock_show(drawn_img, win_name, wait_time):
self.assertFalse((image == drawn_img).all())
self.assertEqual(win_name, 'test')
self.assertEqual(wait_time, 0)
with patch.object(self.vis, 'show', mock_show):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, show=True)
# Test out_file
out_file = osp.join(self.tmpdir.name, 'results.png')
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, out_file=out_file)
self.assertTrue(osp.exists(out_file))
# Test storage backend.
save_file = osp.join(self.tmpdir.name, 'vis_data/vis_image/test_0.png')
self.assertTrue(osp.exists(save_file))
# Test with dataset_meta
self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
def test_texts(text, *_, **__):
self.assertEqual(
text, '\n'.join([
'Ground truth: 1 (bird)',
'Prediction: 1, 0.80 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample)
# Test without pred_label
def test_texts(text, *_, **__):
self.assertEqual(text, '\n'.join([
'Ground truth: 1 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, draw_pred=False)
# Test without gt_label
def test_texts(text, *_, **__):
self.assertEqual(text, '\n'.join([
'Prediction: 1, 0.80 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample, draw_gt=False)
# Test without score
del data_sample.pred_score
def test_texts(text, *_, **__):
self.assertEqual(
text, '\n'.join([
'Ground truth: 1 (bird)',
'Prediction: 1 (bird)',
]))
with patch.object(self.vis, 'draw_texts', test_texts):
self.vis.add_datasample(
'test', image=image, data_sample=data_sample)
# Test adaptive font size
def assert_font_size(target_size):
def draw_texts(text, font_sizes, *_, **__):
self.assertEqual(font_sizes, target_size)
return draw_texts
with patch.object(self.vis, 'draw_texts', assert_font_size(7)):
self.vis.add_datasample(
'test',
image=np.ones((224, 384, 3), np.uint8),
data_sample=data_sample)
with patch.object(self.vis, 'draw_texts', assert_font_size(2)):
self.vis.add_datasample(
'test',
image=np.ones((10, 384, 3), np.uint8),
data_sample=data_sample)
with patch.object(self.vis, 'draw_texts', assert_font_size(21)):
self.vis.add_datasample(
'test',
image=np.ones((1000, 1000, 3), np.uint8),
data_sample=data_sample)
# Test rescale image
with patch.object(self.vis, 'draw_texts', assert_font_size(14)):
self.vis.add_datasample(
'test',
image=np.ones((224, 384, 3), np.uint8),
rescale_factor=2.,
data_sample=data_sample)
def tearDown(self):
self.tmpdir.cleanup()

View File

@ -10,7 +10,7 @@ from mmengine import DictAction
from mmpretrain.datasets import build_dataset
from mmpretrain.structures import ClsDataSample
from mmpretrain.visualization import ClsVisualizer
from mmpretrain.visualization import UniversalVisualizer
def parse_args():
@ -47,7 +47,7 @@ def parse_args():
def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
full_dir = osp.join(result_dir, folder_name)
vis = ClsVisualizer()
vis = UniversalVisualizer()
vis.dataset_meta = {'classes': dataset.CLASSES}
# save imgs
@ -67,8 +67,8 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None):
raise ValueError('Cannot load images from the dataset infos.')
if rescale_factor is not None:
img = mmcv.imrescale(img, rescale_factor)
vis.add_datasample(
name, img, data_sample, out_file=osp.join(full_dir, name + '.png'))
vis.visualize_cls(
img, data_sample, out_file=osp.join(full_dir, name + '.png'))
for k, v in result.items():
if isinstance(v, torch.Tensor):

View File

@ -69,12 +69,6 @@ def convert_revvit(ckpt):
else:
final_ckpt[k] = v
# add pos embed for cls token
if k == 'backbone.pos_embed':
v = torch.cat([torch.ones_like(v).mean(dim=1, keepdim=True), v],
dim=1)
final_ckpt[k] = v
return final_ckpt

View File

@ -3,18 +3,15 @@ import argparse
import os.path as osp
import sys
import cv2
import mmcv
import numpy as np
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer
from mmengine.visualization.utils import img_from_canvas
from mmpretrain.datasets.builder import build_dataset
from mmpretrain.visualization import ClsVisualizer
from mmpretrain.visualization.cls_visualizer import _get_adaptive_scale
from mmpretrain.visualization import UniversalVisualizer, create_figure
def parse_args():
@ -90,49 +87,20 @@ def parse_args():
def make_grid(imgs, names, rescale_factor=None):
"""Concat list of pictures into a single big picture, align height here."""
vis = Visualizer()
figure = create_figure()
gs = figure.add_gridspec(1, len(imgs))
ori_shapes = [img.shape[:2] for img in imgs]
if rescale_factor is not None:
imgs = [mmcv.imrescale(img, rescale_factor) for img in imgs]
max_height = int(max(img.shape[0] for img in imgs) * 1.1)
min_width = min(img.shape[1] for img in imgs)
horizontal_gap = min_width // 10
img_scale = _get_adaptive_scale((max_height, min_width))
texts = []
text_positions = []
start_x = 0
for i, img in enumerate(imgs):
pad_height = (max_height - img.shape[0]) // 2
pad_width = horizontal_gap // 2
# make border
imgs[i] = cv2.copyMakeBorder(
img,
pad_height,
max_height - img.shape[0] - pad_height + int(img_scale * 30 * 2),
pad_width,
pad_width,
cv2.BORDER_CONSTANT,
value=(255, 255, 255))
subplot = figure.add_subplot(gs[0, i])
subplot.axis(False)
subplot.imshow(img)
subplot.set_title(f'{names[i]}\n{ori_shapes[i]}')
texts.append(f'{names[i]}\n{ori_shapes[i]}')
text_positions.append(
[start_x + img.shape[1] // 2 + pad_width, max_height])
start_x += img.shape[1] + horizontal_gap
display_img = np.concatenate(imgs, axis=1)
vis.set_image(display_img)
img_scale = _get_adaptive_scale(display_img.shape[:2])
vis.draw_texts(
texts,
positions=np.array(text_positions),
font_sizes=img_scale * 7,
colors='black',
horizontal_alignments='center',
font_families='monospace')
return vis.get_image()
return img_from_canvas(figure.canvas)
class InspectCompose(Compose):
@ -148,7 +116,7 @@ class InspectCompose(Compose):
def __call__(self, data):
if 'img' in data:
self.intermediate_imgs.append({
'name': 'original',
'name': 'Original',
'img': data['img'].copy()
})
@ -181,7 +149,7 @@ def main():
# init visualizer
cfg.visualizer.pop('type')
visualizer = ClsVisualizer(**cfg.visualizer)
visualizer = UniversalVisualizer(**cfg.visualizer)
visualizer.dataset_meta = dataset.metainfo
# init visualization image number
@ -220,13 +188,13 @@ def main():
out_file = osp.join(args.output_dir,
filename) if args.output_dir is not None else None
visualizer.add_datasample(
filename,
visualizer.visualize_cls(
image if args.channel_order == 'RGB' else image[..., ::-1],
data_sample,
rescale_factor=rescale_factor,
show=not args.not_show,
wait_time=args.show_interval,
name=filename,
out_file=out_file)
progress_bar.update()
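The rewritten make_grid above replaces manual OpenCV padding with a matplotlib figure; a condensed sketch of that pattern, using create_figure() without arguments as the tool does (the dummy images are placeholders):

import numpy as np
from mmengine.visualization.utils import img_from_canvas
from mmpretrain.visualization import create_figure

# Two dummy images of different sizes, shown side by side.
imgs = [np.zeros((64, 64, 3), np.uint8), np.full((64, 96, 3), 255, np.uint8)]

figure = create_figure()
gs = figure.add_gridspec(1, len(imgs))
for i, img in enumerate(imgs):
    subplot = figure.add_subplot(gs[0, i])
    subplot.axis(False)
    subplot.imshow(img)
    subplot.set_title(f'image {i}\n{img.shape[:2]}')

# Render the figure canvas back to an ndarray, as make_grid returns.
grid = img_from_canvas(figure.canvas)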

View File

@ -2,24 +2,27 @@
import argparse
import os.path as osp
import time
from functools import partial
from typing import Optional
from collections import defaultdict
import matplotlib.pyplot as plt
import mmengine
import numpy as np
import rich.progress as progress
import torch
import torch.nn.functional as F
from mmengine.config import Config, DictAction
from mmengine.dataset import default_collate, worker_init_fn
from mmengine.dist import get_rank
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.runner import Runner
from mmengine.utils import mkdir_or_exist
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from mmpretrain.apis import get_model
from mmpretrain.registry import DATA_SAMPLERS, DATASETS
from mmpretrain.registry import DATASETS
try:
from sklearn.manifold import TSNE
except ImportError as e:
raise ImportError('Please install `sklearn` to calculate '
'TSNE by `pip install scikit-learn`') from e
def parse_args():
@ -27,22 +30,30 @@ def parse_args():
parser.add_argument('config', help='tsne config file path')
parser.add_argument('--checkpoint', default=None, help='checkpoint file')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--test-cfg',
help='tsne config file path to load config of test dataloader.')
parser.add_argument(
'--vis-stage',
choices=['backbone', 'neck', 'pre_logits'],
default='backbone',
help='the visualization stage of the model')
help='The visualization stage of the model')
parser.add_argument(
'--class-idx',
nargs='+',
type=int,
help='The categories used to calculate t-SNE.')
parser.add_argument(
'--max-num-class',
type=int,
default=20,
help='the maximum number of classes to apply t-SNE algorithms, now the'
'function supports maximum 20 classes')
parser.add_argument('--seed', type=int, default=0, help='random seed')
help='The first N categories to apply t-SNE algorithms. '
'Defaults to 20.')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
'--max-num-samples',
type=int,
default=100,
help='The maximum number of samples per category. '
'Higher numbers need a longer time to calculate. Defaults to 100.')
parser.add_argument(
'--cfg-options',
nargs='+',
@ -53,8 +64,15 @@ def parse_args():
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument('--device', help='Device used for inference')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
'--legend',
action='store_true',
help='Show the legend of all categories.')
parser.add_argument(
'--show',
action='store_true',
help='Display the result in a graphical window.')
# t-SNE settings
parser.add_argument(
@ -98,19 +116,12 @@ def parse_args():
return args
def post_process():
pass
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
@ -135,82 +146,76 @@ def main():
log_level=cfg.log_level)
# build the model from a config file and a checkpoint file
model = get_model(cfg, args.checkpoint, device=args.device)
logger.info(f'Model loaded and the output indices of backbone is '
f'{model.backbone.out_indices}.')
device = args.device or get_device()
model = get_model(cfg, args.checkpoint, device=device)
logger.info('Model loaded.')
# build the dataset
tsne_dataloader_cfg = cfg.get('test_dataloader')
tsne_dataset_cfg = tsne_dataloader_cfg.pop('dataset')
if isinstance(tsne_dataset_cfg, dict):
dataset = DATASETS.build(tsne_dataset_cfg)
if hasattr(dataset, 'full_init'):
dataset.full_init()
if args.test_cfg is not None:
dataloader_cfg = Config.fromfile(args.test_cfg).get('test_dataloader')
elif 'test_dataloader' not in cfg:
raise ValueError('No `test_dataloader` in the config, you can '
'specify another config file that includes test '
'dataloader settings by the `--test-cfg` option.')
else:
dataloader_cfg = cfg.get('test_dataloader')
dataset = DATASETS.build(dataloader_cfg.pop('dataset'))
classes = dataset.metainfo.get('classes')
if args.class_idx is None:
num_classes = args.max_num_class if classes is None else len(classes)
args.class_idx = list(range(num_classes))[:args.max_num_class]
if classes is not None:
classes = [classes[idx] for idx in args.class_idx]
else:
classes = args.class_idx
# compress the dataset: keep samples whose label is in class_idx, at most max_num_samples per category
subset_idx_list = []
counter = defaultdict(int)
for i in range(len(dataset)):
if dataset.get_data_info(i)['gt_label'] < args.max_num_class:
gt_label = dataset.get_data_info(i)['gt_label']
if (gt_label in args.class_idx
and counter[gt_label] < args.max_num_samples):
subset_idx_list.append(i)
counter[gt_label] += 1
dataset.get_subset_(subset_idx_list)
logger.info(f'Apply t-SNE to visualize {len(subset_idx_list)} samples.')
# build sampler
sampler_cfg = tsne_dataloader_cfg.pop('sampler')
if isinstance(sampler_cfg, dict):
sampler = DATA_SAMPLERS.build(
sampler_cfg, default_args=dict(dataset=dataset, seed=args.seed))
# build dataloader
init_fn: Optional[partial]
if args.seed is not None:
init_fn = partial(
worker_init_fn,
num_workers=tsne_dataloader_cfg.get('num_workers'),
rank=get_rank(),
seed=args.seed)
else:
init_fn = None
tsne_dataloader = DataLoader(
dataset=dataset,
sampler=sampler,
collate_fn=default_collate,
worker_init_fn=init_fn,
**tsne_dataloader_cfg)
dataloader_cfg.dataset = dataset
dataloader_cfg.setdefault('collate_fn', dict(type='default_collate'))
dataloader = Runner.build_dataloader(dataloader_cfg)
results = dict()
features = []
labels = []
progress_bar = mmengine.ProgressBar(len(tsne_dataloader))
for _, data in enumerate(tsne_dataloader):
for data in progress.track(dataloader, description='Calculating...'):
with torch.no_grad():
# preprocess data
data = model.data_preprocessor(data)
batch_inputs, batch_data_samples = \
data['inputs'], data['data_samples']
batch_labels = torch.cat([i.gt_label for i in batch_data_samples])
# extract backbone features
batch_features = model.extract_feat(
batch_inputs, stage=args.vis_stage)
extract_args = {}
if args.vis_stage:
extract_args['stage'] = args.vis_stage
batch_features = model.extract_feat(batch_inputs, **extract_args)
# post process
if args.vis_stage == 'backbone':
if getattr(model.backbone, 'output_cls_token', False) is False:
batch_features = [
F.adaptive_avg_pool2d(inputs, 1).squeeze()
for inputs in batch_features
]
else:
# output_cls_token is True, here t-SNE uses cls_token
batch_features = [feat[-1] for feat in batch_features]
batch_labels = torch.cat([i.gt_label for i in batch_data_samples])
if batch_features[0].ndim == 4:
# For (N, C, H, W) feature
batch_features = [
F.adaptive_avg_pool2d(inputs, 1).squeeze()
for inputs in batch_features
]
# save batch features
features.append(batch_features)
labels.extend(batch_labels.cpu().numpy())
progress_bar.update()
for i in range(len(features[0])):
key = 'feat_' + str(model.backbone.out_indices[i])
@ -238,15 +243,20 @@ def main():
result = tsne_model.fit_transform(val)
res_min, res_max = result.min(0), result.max(0)
res_norm = (result - res_min) / (res_max - res_min)
plt.figure(figsize=(10, 10))
plt.scatter(
_, ax = plt.subplots(figsize=(10, 10))
scatter = ax.scatter(
res_norm[:, 0],
res_norm[:, 1],
alpha=1.0,
s=15,
c=labels,
cmap='tab20')
if args.legend:
legend = ax.legend(scatter.legend_elements()[0], classes)
ax.add_artist(legend)
plt.savefig(f'{tsne_work_dir}{key}.png')
if args.show:
plt.show()
logger.info(f'Save features and results to {tsne_work_dir}')