From 3472ee5d2c29d7a02745eea6193ed715c1f29ed3 Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Thu, 9 Mar 2023 11:36:54 +0800 Subject: [PATCH] [Feature] Implememnt the universal visualizer for multiple task. (#1404) * [Feature] Implememnt the universal visualizer for multiple task. * Update tools * Improve according to comments. * Fix tools docs * Add --test-cfg option and set default collate function. --- .../benchmark_regression/1-benchmark_valid.py | 18 +- configs/_base_/default_runtime.py | 2 +- configs/tsne/resnet50_imagenet.py | 46 --- configs/tsne/swin-base-w6_imagenet.py | 46 --- configs/tsne/vit-base-p16_imagenet.py | 48 --- docs/en/advanced_guides/runtime.md | 6 +- docs/en/api/apis.rst | 2 + docs/en/api/visualization.rst | 7 +- docs/en/migration.md | 4 +- docs/en/useful_tools/cam_visualization.md | 18 +- docs/en/useful_tools/dataset_visualization.md | 10 +- docs/en/useful_tools/model_serving.md | 4 +- .../useful_tools/scheduler_visualization.md | 6 +- docs/en/useful_tools/verify_dataset.md | 2 +- docs/en/user_guides/config.md | 2 +- docs/zh_CN/advanced_guides/runtime.md | 6 +- docs/zh_CN/migration.md | 18 +- docs/zh_CN/useful_tools/cam_visualization.md | 18 +- .../useful_tools/dataset_visualization.md | 10 +- docs/zh_CN/useful_tools/model_serving.md | 4 +- .../useful_tools/scheduler_visualization.md | 6 +- docs/zh_CN/useful_tools/verify_dataset.md | 2 +- docs/zh_CN/user_guides/config.md | 2 +- mmpretrain/apis/feature_extractor.py | 1 - mmpretrain/apis/image_classification.py | 22 +- mmpretrain/apis/image_retrieval.py | 104 +++++- mmpretrain/engine/hooks/visualization_hook.py | 6 +- mmpretrain/visualization/__init__.py | 5 +- mmpretrain/visualization/cls_visualizer.py | 186 ---------- mmpretrain/visualization/utils.py | 60 ++++ mmpretrain/visualization/visualizer.py | 324 ++++++++++++++++++ tests/test_apis/test_inference.py | 11 +- .../test_hooks/test_visualization_hook.py | 57 +-- tests/test_visualization/test_visualizer.py | 200 +++++++++++ tests/test_visualizations/test_visualizer.py | 133 ------- tools/analysis_tools/analyze_results.py | 8 +- .../model_converters/revvit_to_mmpretrain.py | 6 - tools/visualization/browse_dataset.py | 58 +--- tools/visualization/vis_tsne.py | 160 +++++---- 39 files changed, 922 insertions(+), 706 deletions(-) delete mode 100644 configs/tsne/resnet50_imagenet.py delete mode 100644 configs/tsne/swin-base-w6_imagenet.py delete mode 100644 configs/tsne/vit-base-p16_imagenet.py delete mode 100644 mmpretrain/visualization/cls_visualizer.py create mode 100644 mmpretrain/visualization/utils.py create mode 100644 mmpretrain/visualization/visualizer.py create mode 100644 tests/test_visualization/test_visualizer.py delete mode 100644 tests/test_visualizations/test_visualizer.py diff --git a/.dev_scripts/benchmark_regression/1-benchmark_valid.py b/.dev_scripts/benchmark_regression/1-benchmark_valid.py index 586e7566..c73a4167 100644 --- a/.dev_scripts/benchmark_regression/1-benchmark_valid.py +++ b/.dev_scripts/benchmark_regression/1-benchmark_valid.py @@ -22,7 +22,7 @@ from rich.table import Table from mmpretrain.apis import get_model from mmpretrain.datasets import CIFAR10, CIFAR100, ImageNet from mmpretrain.utils import register_all_modules -from mmpretrain.visualization import ClsVisualizer +from mmpretrain.visualization import UniversalVisualizer console = Console() MMCLS_ROOT = Path(__file__).absolute().parents[2] @@ -166,7 +166,7 @@ def show_summary(summary_data, args): for model_name, summary in summary_data.items(): row = [model_name] valid = summary['valid'] - color = 'green' if valid == 'PASS' else 'red' + color = {'PASS': 'green', 'CUDA OOM': 'yellow'}.get(valid, 'red') row.append(f'[{color}]{valid}[/{color}]') if valid == 'PASS': row.append(str(summary['resolution'])) @@ -248,15 +248,19 @@ def main(args): result = inference(MMCLS_ROOT / config, checkpoint, tmpdir.name, args, model_name) result['valid'] = 'PASS' - except Exception: - import traceback - logger.error(f'"{config}" :\n{traceback.format_exc()}') - result = {'valid': 'FAIL'} + except Exception as e: + if 'CUDA out of memory' in str(e): + logger.error(f'"{config}" :\nCUDA out of memory') + result = {'valid': 'CUDA OOM'} + else: + import traceback + logger.error(f'"{config}" :\n{traceback.format_exc()}') + result = {'valid': 'FAIL'} summary_data[model_name] = result # show the results if args.show: - vis = ClsVisualizer.get_instance('valid') + vis = UniversalVisualizer.get_instance('valid') vis.set_image(mmcv.imread(args.img)) vis.draw_texts( texts='\n'.join([f'{k}: {v}' for k, v in result.items()]), diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 6e66911c..3816d423 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -36,7 +36,7 @@ env_cfg = dict( # set visualizer vis_backends = [dict(type='LocalVisBackend')] -visualizer = dict(type='ClsVisualizer', vis_backends=vis_backends) +visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends) # set log level log_level = 'INFO' diff --git a/configs/tsne/resnet50_imagenet.py b/configs/tsne/resnet50_imagenet.py deleted file mode 100644 index bd4fbad0..00000000 --- a/configs/tsne/resnet50_imagenet.py +++ /dev/null @@ -1,46 +0,0 @@ -_base_ = '../_base_/default_runtime.py' - -model = dict( - type='ImageClassifier', - backbone=dict( - type='ResNet', - depth=50, - in_channels=3, - num_stages=4, - out_indices=(3, ), - norm_cfg=dict(type='BN'), - frozen_stages=-1), - neck=dict(type='GlobalAveragePooling'), - head=dict( - type='LinearClsHead', - num_classes=1000, - in_channels=2048, - loss=dict(type='CrossEntropyLoss', loss_weight=1.0), - topk=(1, 5), - )) - -dataset_type = 'ImageNet' -data_root = 'data/imagenet/' -data_preprocessor = dict( - num_classes=1000, - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - to_rgb=True, -) -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='ResizeEdge', scale=256, edge='short'), - dict(type='CenterCrop', crop_size=224), - dict(type='PackInputs'), -] -test_dataloader = dict( - batch_size=8, - num_workers=4, - dataset=dict( - type=dataset_type, - data_root='data/imagenet', - ann_file='meta/val.txt', - data_prefix='val', - pipeline=test_pipeline), - sampler=dict(type='DefaultSampler', shuffle=False), -) diff --git a/configs/tsne/swin-base-w6_imagenet.py b/configs/tsne/swin-base-w6_imagenet.py deleted file mode 100644 index dd3c8f69..00000000 --- a/configs/tsne/swin-base-w6_imagenet.py +++ /dev/null @@ -1,46 +0,0 @@ -_base_ = '../_base_/default_runtime.py' - -model = dict( - type='ImageClassifier', - backbone=dict( - type='SwinTransformer', - arch='base', - img_size=192, - out_indices=-1, - drop_path_rate=0.1, - stage_cfgs=dict(block_cfgs=dict(window_size=6))), - neck=dict(type='GlobalAveragePooling'), - head=dict( - type='LinearClsHead', - num_classes=1000, - in_channels=1024, - init_cfg=None, # suppress the default init_cfg of LinearClsHead. - loss=dict( - type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), - cal_acc=False)) - -dataset_type = 'ImageNet' -data_root = 'data/imagenet/' -data_preprocessor = dict( - num_classes=1000, - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - to_rgb=True, -) -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='ResizeEdge', scale=256, edge='short'), - dict(type='CenterCrop', crop_size=224), - dict(type='PackInputs'), -] -test_dataloader = dict( - batch_size=8, - num_workers=4, - dataset=dict( - type=dataset_type, - data_root='data/imagenet', - ann_file='meta/val.txt', - data_prefix='val', - pipeline=test_pipeline), - sampler=dict(type='DefaultSampler', shuffle=False), -) diff --git a/configs/tsne/vit-base-p16_imagenet.py b/configs/tsne/vit-base-p16_imagenet.py deleted file mode 100644 index 609ddf3d..00000000 --- a/configs/tsne/vit-base-p16_imagenet.py +++ /dev/null @@ -1,48 +0,0 @@ -_base_ = '../_base_/default_runtime.py' - -model = dict( - type='ImageClassifier', - backbone=dict( - type='VisionTransformer', - arch='base', - img_size=224, - patch_size=16, - out_indices=-1, - drop_path_rate=0.1, - out_type='featmap', - final_norm=False), - neck=None, - head=dict( - type='LinearClsHead', - num_classes=1000, - in_channels=768, - loss=dict( - type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), - init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]), -) - -dataset_type = 'ImageNet' -data_root = 'data/imagenet/' -data_preprocessor = dict( - num_classes=1000, - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - to_rgb=True, -) -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='ResizeEdge', scale=256, edge='short'), - dict(type='CenterCrop', crop_size=224), - dict(type='PackInputs'), -] -test_dataloader = dict( - batch_size=8, - num_workers=4, - dataset=dict( - type=dataset_type, - data_root='data/imagenet', - ann_file='meta/val.txt', - data_prefix='val', - pipeline=test_pipeline), - sampler=dict(type='DefaultSampler', shuffle=False), -) diff --git a/docs/en/advanced_guides/runtime.md b/docs/en/advanced_guides/runtime.md index 3c568c17..60dace1d 100644 --- a/docs/en/advanced_guides/runtime.md +++ b/docs/en/advanced_guides/runtime.md @@ -169,7 +169,7 @@ scalars. By default, the recorded information will be saved at the `vis_data` fo ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), ] @@ -181,7 +181,7 @@ For example, to save them to TensorBoard, simply set them as below: ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend'), @@ -193,7 +193,7 @@ Or save them to WandB as below: ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), dict(type='WandbVisBackend'), diff --git a/docs/en/api/apis.rst b/docs/en/api/apis.rst index fd8c7970..ca8a1a19 100644 --- a/docs/en/api/apis.rst +++ b/docs/en/api/apis.rst @@ -33,6 +33,8 @@ Inference :template: callable.rst ImageClassificationInferencer + ImageRetrievalInferencer + FeatureExtractor .. autosummary:: :toctree: generated diff --git a/docs/en/api/visualization.rst b/docs/en/api/visualization.rst index 757f84ed..85742a1c 100644 --- a/docs/en/api/visualization.rst +++ b/docs/en/api/visualization.rst @@ -6,8 +6,9 @@ mmpretrain.visualization =================================== -This package includes visualizer components for classification tasks. +This package includes visualizer and some helper functions for visualization. -ClsVisualizer +Visualizer ------------- -.. autoclass:: ClsVisualizer +.. autoclass:: UniversalVisualizer + :members: diff --git a/docs/en/migration.md b/docs/en/migration.md index ba27e262..a32c06f9 100644 --- a/docs/en/migration.md +++ b/docs/en/migration.md @@ -431,7 +431,7 @@ default_hooks = dict( ) visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')], ) ``` @@ -472,7 +472,7 @@ See the {external+mmengine:doc}`MMEngine tutorial =1.3.6"` command to install [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam). +MMPretrain provides `tools/visualization/vis_cam.py` tool to visualize class activation map. Please use `pip install "grad-cam>=1.3.6"` command to install [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam). The supported methods are as follows: @@ -18,7 +18,7 @@ The supported methods are as follows: **Command**: ```bash -python tools/visualizations/vis_cam.py \ +python tools/visualization/vis_cam.py \ ${IMG} \ ${CONFIG_FILE} \ ${CHECKPOINT} \ @@ -73,7 +73,7 @@ For example, the `backbone.layer4[-1]` is the same as `backbone.layer4.2` since 1. Use different methods to visualize CAM for `ResNet50`, the `target-category` is the predicted result by the given checkpoint, using the default `target-layers`. ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/resnet/resnet50_8xb32_in1k.py \ https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \ @@ -88,7 +88,7 @@ For example, the `backbone.layer4[-1]` is the same as `backbone.layer4.2` since 2. Use different `target-category` to get CAM from the same picture. In `ImageNet` dataset, the category 238 is 'Greater Swiss Mountain dog', the category 281 is 'tabby, tabby cat'. ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/cat-dog.png configs/resnet/resnet50_8xb32_in1k.py \ https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \ --target-layers 'backbone.layer4.2' \ @@ -105,7 +105,7 @@ For example, the `backbone.layer4[-1]` is the same as `backbone.layer4.2` since 3. Use `--eigen-smooth` and `--aug-smooth` to improve visual effects. ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/dog.jpg \ configs/mobilenet_v3/mobilenet-v3-large_8xb128_in1k.py \ https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth \ @@ -129,12 +129,12 @@ For ViT-like networks, such as ViT, T2T-ViT and Swin-Transformer, the features a Besides the flattened features, some ViT-like networks also add extra tokens like the class token in ViT and T2T-ViT, and the distillation token in DeiT. In these networks, the final classification is done on the tokens computed in the last attention block, and therefore, the classification score will not be affected by other features and the gradient of the classification score with respect to them, will be zero. Therefore, you shouldn't use the output of the last attention block as the target layer in these networks. -To exclude these extra tokens, we need know the number of extra tokens. Almost all transformer-based backbones in MMClassification have the `num_extra_tokens` attribute. If you want to use this tool in a new or third-party network that don't have the `num_extra_tokens` attribute, please specify it the `--num-extra-tokens` argument. +To exclude these extra tokens, we need know the number of extra tokens. Almost all transformer-based backbones in MMPretrain have the `num_extra_tokens` attribute. If you want to use this tool in a new or third-party network that don't have the `num_extra_tokens` attribute, please specify it the `--num-extra-tokens` argument. 1. Visualize CAM for `Swin Transformer`, using default `target-layers`: ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/swin_transformer/swin-tiny_16xb64_in1k.py \ https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth \ @@ -144,7 +144,7 @@ To exclude these extra tokens, we need know the number of extra tokens. Almost a 2. Visualize CAM for `Vision Transformer(ViT)`: ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/vision_transformer/vit-base-p16_ft-64xb64_in1k-384.py \ https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth \ @@ -155,7 +155,7 @@ To exclude these extra tokens, we need know the number of extra tokens. Almost a 3. Visualize CAM for `T2T-ViT`: ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py \ https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth \ diff --git a/docs/en/useful_tools/dataset_visualization.md b/docs/en/useful_tools/dataset_visualization.md index 11ef2842..5bacee94 100644 --- a/docs/en/useful_tools/dataset_visualization.md +++ b/docs/en/useful_tools/dataset_visualization.md @@ -3,7 +3,7 @@ ## Introduce the dataset visualization tool ```bash -python tools/visualizations/browse_dataset.py \ +python tools/visualization/browse_dataset.py \ ${CONFIG_FILE} \ [-o, --output-dir ${OUTPUT_DIR}] \ [-p, --phase ${DATASET_PHASE}] \ @@ -42,7 +42,7 @@ python tools/visualizations/browse_dataset.py \ In **'original'** mode: ```shell -python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB +python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB ``` - `--phase val`: Visual validation set, can be simplified to `-p val`; @@ -59,7 +59,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16 In **'transformed'** mode: ```shell -python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2 +python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2 ```
@@ -69,7 +69,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_ In **'concat'** mode: ```shell -python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat +python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat ```
@@ -77,7 +77,7 @@ python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-sm 4. In **'pipeline'** mode: ```shell -python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline +python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline ```
diff --git a/docs/en/useful_tools/model_serving.md b/docs/en/useful_tools/model_serving.md index 450eb98b..2ae6497f 100644 --- a/docs/en/useful_tools/model_serving.md +++ b/docs/en/useful_tools/model_serving.md @@ -1,8 +1,8 @@ # Torchserve Deployment -In order to serve an `MMClassification` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps: +In order to serve an `MMPretrain` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps: -## 1. Convert model from MMClassification to TorchServe +## 1. Convert model from MMPretrain to TorchServe ```shell python tools/torchserve/mmpretrain2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \ diff --git a/docs/en/useful_tools/scheduler_visualization.md b/docs/en/useful_tools/scheduler_visualization.md index 3d30aada..1ff6ba23 100644 --- a/docs/en/useful_tools/scheduler_visualization.md +++ b/docs/en/useful_tools/scheduler_visualization.md @@ -5,7 +5,7 @@ This tool aims to help the user to check the hyper-parameter scheduler of the op ## Introduce the scheduler visualization tool ```bash -python tools/visualizations/vis_scheduler.py \ +python tools/visualization/vis_scheduler.py \ ${CONFIG_FILE} \ [-p, --parameter ${PARAMETER_NAME}] \ [-d, --dataset-size ${DATASET_SIZE}] \ @@ -38,7 +38,7 @@ Loading annotations maybe consume much time, you can directly specify the size o You can use the following command to plot the step learning rate schedule used in the config `configs/resnet/resnet50_b16x8_cifar100.py`: ```bash -python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py +python tools/visualization/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py ```
@@ -46,7 +46,7 @@ python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar When using ImageNet, directly specify the size of ImageNet, as below: ```bash -python tools/visualizations/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg +python tools/visualization/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg ```
diff --git a/docs/en/useful_tools/verify_dataset.md b/docs/en/useful_tools/verify_dataset.md index be80cbc1..d27948f4 100644 --- a/docs/en/useful_tools/verify_dataset.md +++ b/docs/en/useful_tools/verify_dataset.md @@ -1,6 +1,6 @@ # Verify Dataset -In MMClassification, we also provide a tool `tools/misc/verify_dataset.py` to check whether there exists **broken pictures** in the given dataset. +In MMPretrain, we also provide a tool `tools/misc/verify_dataset.py` to check whether there exists **broken pictures** in the given dataset. ## Introduce the tool diff --git a/docs/en/user_guides/config.md b/docs/en/user_guides/config.md index e0f3f8ba..97bb7c76 100644 --- a/docs/en/user_guides/config.md +++ b/docs/en/user_guides/config.md @@ -257,7 +257,7 @@ env_cfg = dict( # set visualizer vis_backends = [dict(type='LocalVisBackend')] # use local HDD backend visualizer = dict( - type='ClsVisualizer', vis_backends=vis_backends, name='visualizer') + type='UniversalVisualizer', vis_backends=vis_backends, name='visualizer') # set log level log_level = 'INFO' diff --git a/docs/zh_CN/advanced_guides/runtime.md b/docs/zh_CN/advanced_guides/runtime.md index b0f2d9ac..8f5bd57a 100644 --- a/docs/zh_CN/advanced_guides/runtime.md +++ b/docs/zh_CN/advanced_guides/runtime.md @@ -162,7 +162,7 @@ Visualizer用于记录训练和测试过程中的各种信息,包括日志、 ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), ] @@ -174,7 +174,7 @@ visualizer = dict( ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend'), @@ -186,7 +186,7 @@ visualizer = dict( ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), dict(type='WandbVisBackend'), diff --git a/docs/zh_CN/migration.md b/docs/zh_CN/migration.md index b0ac7901..0231007a 100644 --- a/docs/zh_CN/migration.md +++ b/docs/zh_CN/migration.md @@ -419,7 +419,7 @@ default_hooks = dict( ) visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')], ) ``` @@ -459,7 +459,7 @@ env_cfg = dict( ```python visualizer = dict( - type='ClsVisualizer', + type='UniversalVisualizer', vis_backends=[ dict(type='LocalVisBackend'), # 将下行取消注释,即可将日志和可视化结果保存至 TesnorBoard @@ -491,13 +491,13 @@ visualizer = dict( `mmpretrain.core` 包被重命名为 [`mmpretrain.engine`](mmpretrain.engine) -| 子包 | 变动 | -| :-------------: | :-------------------------------------------------------------------------------------------------------------------- | -| `evaluation` | 移除,使用 [`mmpretrain.evaluation`](mmpretrain.evaluation) | -| `hook` | 移动至 [`mmpretrain.engine.hooks`](mmpretrain.engine.hooks) | -| `optimizers` | 移动至 [`mmpretrain.engine.optimizers`](mmpretrain.engine.optimizers) | -| `utils` | 移除,分布式环境相关的函数统一至 [`mmengine.dist`](mmengine.dist) 包 | -| `visualization` | 移除,其中可视化相关的功能被移动至 [`mmpretrain.visualization.ClsVisualizer`](mmpretrain.visualization.ClsVisualizer) | +| 子包 | 变动 | +| :-------------: | :-------------------------------------------------------------------------------------------------------------------------------- | +| `evaluation` | 移除,使用 [`mmpretrain.evaluation`](mmpretrain.evaluation) | +| `hook` | 移动至 [`mmpretrain.engine.hooks`](mmpretrain.engine.hooks) | +| `optimizers` | 移动至 [`mmpretrain.engine.optimizers`](mmpretrain.engine.optimizers) | +| `utils` | 移除,分布式环境相关的函数统一至 [`mmengine.dist`](mmengine.dist) 包 | +| `visualization` | 移除,其中可视化相关的功能被移动至 [`mmpretrain.visualization.UniversalVisualizer`](mmpretrain.visualization.UniversalVisualizer) | `hooks` 包中的 `MMClsWandbHook` 尚未实现。 diff --git a/docs/zh_CN/useful_tools/cam_visualization.md b/docs/zh_CN/useful_tools/cam_visualization.md index 23750255..33e3f4e8 100644 --- a/docs/zh_CN/useful_tools/cam_visualization.md +++ b/docs/zh_CN/useful_tools/cam_visualization.md @@ -2,7 +2,7 @@ ## 类别激活图可视化工具介绍 -MMClassification 提供 `tools\visualizations\vis_cam.py` 工具来可视化类别激活图。请使用 `pip install "grad-cam>=1.3.6"` 安装依赖的 [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam)。 +MMPretrain 提供 `tools/visualization/vis_cam.py` 工具来可视化类别激活图。请使用 `pip install "grad-cam>=1.3.6"` 安装依赖的 [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam)。 目前支持的方法有: @@ -18,7 +18,7 @@ MMClassification 提供 `tools\visualizations\vis_cam.py` 工具来可视化类 **命令行**: ```bash -python tools/visualizations/vis_cam.py \ +python tools/visualization/vis_cam.py \ ${IMG} \ ${CONFIG_FILE} \ ${CHECKPOINT} \ @@ -71,7 +71,7 @@ python tools/visualizations/vis_cam.py \ 1. 使用不同方法可视化 `ResNet50`,默认 `target-category` 为模型检测的结果,使用默认推导的 `target-layers`。 ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/resnet/resnet50_8xb32_in1k.py \ https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \ @@ -86,7 +86,7 @@ python tools/visualizations/vis_cam.py \ 2. 同一张图不同类别的激活图效果图,在 `ImageNet` 数据集中,类别238为 'Greater Swiss Mountain dog',类别281为 'tabby, tabby cat'。 ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/cat-dog.png configs/resnet/resnet50_8xb32_in1k.py \ https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth \ --target-layers 'backbone.layer4.2' \ @@ -103,7 +103,7 @@ python tools/visualizations/vis_cam.py \ 3. 使用 `--eigen-smooth` 以及 `--aug-smooth` 获取更好的可视化效果。 ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/dog.jpg \ configs/mobilenet_v3/mobilenet-v3-large_8xb128_in1k.py \ https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth \ @@ -127,12 +127,12 @@ python tools/visualizations/vis_cam.py \ 除了特征被展平之外,一些类 ViT 的网络还会添加额外的 tokens。比如 ViT 和 T2T-ViT 中添加了分类 token,DeiT 中还添加了蒸馏 token。在这些网络中,分类计算在最后一个注意力模块之后就已经完成了,分类得分也只和这些额外的 tokens 有关,与特征图无关,也就是说,分类得分对这些特征图的导数为 0。因此,我们不能使用最后一个注意力模块的输出作为 CAM 绘制的目标层。 -另外,为了去除这些额外的 toekns 以获得特征图,我们需要知道这些额外 tokens 的数量。MMClassification 中几乎所有 Transformer-based 的网络都拥有 `num_extra_tokens` 属性。而如果你希望将此工具应用于新的,或者第三方的网络,而且该网络没有指定 `num_extra_tokens` 属性,那么可以使用 `--num-extra-tokens` 参数手动指定其数量。 +另外,为了去除这些额外的 toekns 以获得特征图,我们需要知道这些额外 tokens 的数量。MMPretrain 中几乎所有 Transformer-based 的网络都拥有 `num_extra_tokens` 属性。而如果你希望将此工具应用于新的,或者第三方的网络,而且该网络没有指定 `num_extra_tokens` 属性,那么可以使用 `--num-extra-tokens` 参数手动指定其数量。 1. 对 `Swin Transformer` 使用默认 `target-layers` 进行 CAM 可视化: ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/swin_transformer/swin-tiny_16xb64_in1k.py \ https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth \ @@ -142,7 +142,7 @@ python tools/visualizations/vis_cam.py \ 2. 对 `Vision Transformer(ViT)` 进行 CAM 可视化: ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/vision_transformer/vit-base-p16_ft-64xb64_in1k-384.py \ https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth \ @@ -153,7 +153,7 @@ python tools/visualizations/vis_cam.py \ 3. 对 `T2T-ViT` 进行 CAM 可视化: ```shell - python tools/visualizations/vis_cam.py \ + python tools/visualization/vis_cam.py \ demo/bird.JPEG \ configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py \ https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth \ diff --git a/docs/zh_CN/useful_tools/dataset_visualization.md b/docs/zh_CN/useful_tools/dataset_visualization.md index 3ea663ec..09acfb2c 100644 --- a/docs/zh_CN/useful_tools/dataset_visualization.md +++ b/docs/zh_CN/useful_tools/dataset_visualization.md @@ -3,7 +3,7 @@ ## 数据集可视化工具简介 ```bash -python tools/visualizations/browse_dataset.py \ +python tools/visualization/browse_dataset.py \ ${CONFIG_FILE} \ [-o, --output-dir ${OUTPUT_DIR}] \ [-p, --phase ${DATASET_PHASE}] \ @@ -43,7 +43,7 @@ python tools/visualizations/browse_dataset.py \ 使用 **'original'** 模式 : ```shell -python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB +python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet101_8xb16_cifar10.py --phase val --output-dir tmp --mode original --show-number 100 --rescale-factor 10 --channel-order RGB ``` - `--phase val`: 可视化验证集, 可简化为 `-p val`; @@ -60,7 +60,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet101_8xb16 使用 **'transformed'** 模式: ```shell -python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2 +python ./tools/visualization/browse_dataset.py ./configs/resnet/resnet50_8xb32_in1k.py -n 100 -r 2 ```
@@ -70,7 +70,7 @@ python ./tools/visualizations/browse_dataset.py ./configs/resnet/resnet50_8xb32_ 使用 **'concat'** 模式: ```shell -python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat +python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -n 10 -m concat ```
@@ -78,7 +78,7 @@ python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-sm 使用 **'pipeline'** 模式: ```shell -python ./tools/visualizations/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline +python ./tools/visualization/browse_dataset.py configs/swin_transformer/swin-small_16xb64_in1k.py -m pipeline ```
diff --git a/docs/zh_CN/useful_tools/model_serving.md b/docs/zh_CN/useful_tools/model_serving.md index 821c936c..78423def 100644 --- a/docs/zh_CN/useful_tools/model_serving.md +++ b/docs/zh_CN/useful_tools/model_serving.md @@ -1,8 +1,8 @@ # TorchServe 部署 -为了使用 [`TorchServe`](https://pytorch.org/serve/) 部署一个 `MMClassification` 模型,需要进行以下几步: +为了使用 [`TorchServe`](https://pytorch.org/serve/) 部署一个 `MMPretrain` 模型,需要进行以下几步: -## 1. 转换 MMClassification 模型至 TorchServe +## 1. 转换 MMPretrain 模型至 TorchServe ```shell python tools/torchserve/mmpretrain2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \ diff --git a/docs/zh_CN/useful_tools/scheduler_visualization.md b/docs/zh_CN/useful_tools/scheduler_visualization.md index 1824136a..cd4ac825 100644 --- a/docs/zh_CN/useful_tools/scheduler_visualization.md +++ b/docs/zh_CN/useful_tools/scheduler_visualization.md @@ -5,7 +5,7 @@ ## 工具简介 ```bash -python tools/visualizations/vis_scheduler.py \ +python tools/visualization/vis_scheduler.py \ ${CONFIG_FILE} \ [-p, --parameter ${PARAMETER_NAME}] \ [-d, --dataset-size ${DATASET_SIZE}] \ @@ -38,7 +38,7 @@ python tools/visualizations/vis_scheduler.py \ 你可以使用如下命令来绘制配置文件 `configs/resnet/resnet50_b16x8_cifar100.py` 将会使用的变化率曲线: ```bash -python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py +python tools/visualization/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar100.py ```
@@ -46,7 +46,7 @@ python tools/visualizations/vis_scheduler.py configs/resnet/resnet50_b16x8_cifar 当数据集为 ImageNet 时,通过直接指定数据集大小来节约时间,并保存图片: ```bash -python tools/visualizations/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg +python tools/visualization/vis_scheduler.py configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py --dataset-size 1281167 --ngpus 4 --save-path ./repvgg-B3g4_4xb64-lr.jpg ```
diff --git a/docs/zh_CN/useful_tools/verify_dataset.md b/docs/zh_CN/useful_tools/verify_dataset.md index 8906c6d8..d08f0848 100644 --- a/docs/zh_CN/useful_tools/verify_dataset.md +++ b/docs/zh_CN/useful_tools/verify_dataset.md @@ -1,6 +1,6 @@ # 数据集验证 -在MMClassification中,`tools/misc/verify_dataset.py` 脚本会检查数据集的所有图片,查看是否有**已经损坏**的图片。 +在 MMPretrain 中,`tools/misc/verify_dataset.py` 脚本会检查数据集的所有图片,查看是否有**已经损坏**的图片。 ## 工具介绍 diff --git a/docs/zh_CN/user_guides/config.md b/docs/zh_CN/user_guides/config.md index 8aacb4eb..47f38515 100644 --- a/docs/zh_CN/user_guides/config.md +++ b/docs/zh_CN/user_guides/config.md @@ -250,7 +250,7 @@ env_cfg = dict( # 设置可视化工具 vis_backends = [dict(type='LocalVisBackend')] # 使用磁盘(HDD)后端 visualizer = dict( - type='ClsVisualizer', vis_backends=vis_backends, name='visualizer') + type='UniversalVisualizer', vis_backends=vis_backends, name='visualizer') # 设置日志级别 log_level = 'INFO' diff --git a/mmpretrain/apis/feature_extractor.py b/mmpretrain/apis/feature_extractor.py index e15dcec3..513717fc 100644 --- a/mmpretrain/apis/feature_extractor.py +++ b/mmpretrain/apis/feature_extractor.py @@ -82,7 +82,6 @@ class FeatureExtractor(BaseInferencer): @torch.no_grad() def forward(self, inputs: Union[dict, tuple], **kwargs): - """Feed the inputs to the model.""" inputs = self.model.data_preprocessor(inputs, False)['inputs'] outputs = self.model.extract_feat(inputs, **kwargs) diff --git a/mmpretrain/apis/image_classification.py b/mmpretrain/apis/image_classification.py index 4ae689b8..7edbc42f 100644 --- a/mmpretrain/apis/image_classification.py +++ b/mmpretrain/apis/image_classification.py @@ -57,7 +57,8 @@ class ImageClassificationInferencer(BaseInferencer): """ # noqa: E501 visualize_kwargs: set = { - 'rescale_factor', 'draw_score', 'show', 'show_dir' + 'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir', + 'wait_time' } def __init__( @@ -102,6 +103,8 @@ class ImageClassificationInferencer(BaseInferencer): return_datasamples (bool): Whether to return results as :obj:`DataSample`. Defaults to False. batch_size (int): Batch size. Defaults to 1. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. rescale_factor (float, optional): Rescale the image by the rescale factor for visualization. This is helpful when the image is too large or too small for visualization. Defaults to None. @@ -109,6 +112,8 @@ class ImageClassificationInferencer(BaseInferencer): of prediction categories. Defaults to True. show (bool): Whether to display the visualization result in a window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". show_dir (str, optional): If not None, save the visualization results in the specified directory. Defaults to None. @@ -148,6 +153,8 @@ class ImageClassificationInferencer(BaseInferencer): ori_inputs: List[InputType], preds: List[DataSample], show: bool = False, + wait_time: int = 0, + resize: Optional[int] = None, rescale_factor: Optional[float] = None, draw_score=True, show_dir=None): @@ -155,10 +162,8 @@ class ImageClassificationInferencer(BaseInferencer): return None if self.visualizer is None: - from mmpretrain.visualization import ClsVisualizer - self.visualizer = ClsVisualizer() - if self.classes is not None: - self.visualizer._dataset_meta = dict(classes=self.classes) + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() visualization = [] for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): @@ -177,15 +182,18 @@ class ImageClassificationInferencer(BaseInferencer): else: out_file = None - self.visualizer.add_datasample( - name, + self.visualizer.visualize_cls( image, data_sample, + classes=self.classes, + resize=resize, show=show, + wait_time=wait_time, rescale_factor=rescale_factor, draw_gt=False, draw_pred=True, draw_score=draw_score, + name=name, out_file=out_file) visualization.append(self.visualizer.get_image()) if show: diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py index f4556dfe..99507495 100644 --- a/mmpretrain/apis/image_retrieval.py +++ b/mmpretrain/apis/image_retrieval.py @@ -56,6 +56,9 @@ class ImageRetrievalInferencer(BaseInferencer): >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/") """ # noqa: E501 + visualize_kwargs: set = { + 'draw_score', 'resize', 'show_dir', 'show', 'wait_time' + } postprocess_kwargs: set = {'topk'} def __init__( @@ -87,6 +90,10 @@ class ImageRetrievalInferencer(BaseInferencer): self.prototype_dataset = self._prepare_prototype( prototype, prototype_vecs, prepare_batch_size) + # An ugly hack to escape from the duplicated arguments check in the + # base class + self.visualize_kwargs.add('topk') + def _prepare_prototype(self, prototype, prototype_vecs=None, batch_size=8): from mmengine.dataset import DefaultSampler from torch.utils.data import DataLoader @@ -157,13 +164,14 @@ class ImageRetrievalInferencer(BaseInferencer): return_datasamples (bool): Whether to return results as :obj:`DataSample`. Defaults to False. batch_size (int): Batch size. Defaults to 1. - rescale_factor (float, optional): Rescale the image by the rescale - factor for visualization. This is helpful when the image is too - large or too small for visualization. Defaults to None. - draw_score (bool): Whether to draw the prediction scores - of prediction categories. Defaults to True. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + draw_score (bool): Whether to draw the match scores. + Defaults to True. show (bool): Whether to display the visualization result in a window. Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". show_dir (str, optional): If not None, save the visualization results in the specified directory. Defaults to None. @@ -202,13 +210,51 @@ class ImageRetrievalInferencer(BaseInferencer): def visualize(self, ori_inputs: List[InputType], preds: List[DataSample], + topk: int = 3, + resize: Optional[int] = 224, show: bool = False, + wait_time: int = 0, draw_score=True, show_dir=None): if not show and show_dir is None: return None - raise NotImplementedError('Not implemented yet.') + if self.visualizer is None: + from mmpretrain.visualization import UniversalVisualizer + self.visualizer = UniversalVisualizer() + + visualization = [] + for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)): + image = imread(input_) + if isinstance(input_, str): + # The image loaded from path is BGR format. + image = image[..., ::-1] + name = Path(input_).stem + else: + name = str(i) + + if show_dir is not None: + show_dir = Path(show_dir) + show_dir.mkdir(exist_ok=True) + out_file = str((show_dir / name).with_suffix('.png')) + else: + out_file = None + + self.visualizer.visualize_image_retrieval( + image, + data_sample, + self.prototype_dataset, + topk=topk, + resize=resize, + draw_score=draw_score, + show=show, + wait_time=wait_time, + name=name, + out_file=out_file) + visualization.append(self.visualizer.get_image()) + if show: + self.visualizer.close() + return visualization def postprocess( self, @@ -248,3 +294,49 @@ class ImageRetrievalInferencer(BaseInferencer): List[str]: a list of model names. """ return list_models(pattern=pattern, task='Image Retrieval') + + def _dispatch_kwargs(self, **kwargs): + """Dispatch kwargs to preprocess(), forward(), visualize() and + postprocess() according to the actual demands. + + Override this method to allow same argument for different methods. + + Returns: + Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, + forward, visualize and postprocess respectively. + """ + method_kwargs = set.union( + self.preprocess_kwargs, + self.forward_kwargs, + self.visualize_kwargs, + self.postprocess_kwargs, + ) + + union_kwargs = method_kwargs | set(kwargs.keys()) + if union_kwargs != method_kwargs: + unknown_kwargs = union_kwargs - method_kwargs + raise ValueError( + f'unknown argument {unknown_kwargs} for `preprocess`, ' + '`forward`, `visualize` and `postprocess`') + + preprocess_kwargs = {} + forward_kwargs = {} + visualize_kwargs = {} + postprocess_kwargs = {} + + for key, value in kwargs.items(): + if key in self.preprocess_kwargs: + preprocess_kwargs[key] = value + if key in self.forward_kwargs: + forward_kwargs[key] = value + if key in self.visualize_kwargs: + visualize_kwargs[key] = value + if key in self.postprocess_kwargs: + postprocess_kwargs[key] = value + + return ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) diff --git a/mmpretrain/engine/hooks/visualization_hook.py b/mmpretrain/engine/hooks/visualization_hook.py index fdff5cae..64d2230a 100644 --- a/mmpretrain/engine/hooks/visualization_hook.py +++ b/mmpretrain/engine/hooks/visualization_hook.py @@ -30,7 +30,7 @@ class VisualizationHook(Hook): in the testing process. If None, handle with the backends of the visualizer. Defaults to None. **kwargs: other keyword arguments of - :meth:`mmpretrain.visualization.ClsVisualizer.add_datasample`. + :meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`. """ def __init__(self, @@ -88,11 +88,11 @@ class VisualizationHook(Hook): draw_args['out_file'] = join_path(self.out_dir, f'{sample_name}_{step}.png') - self._visualizer.add_datasample( - sample_name, + self._visualizer.visualize_cls( image=image, data_sample=data_sample, step=step, + name=sample_name, **self.draw_args, ) diff --git a/mmpretrain/visualization/__init__.py b/mmpretrain/visualization/__init__.py index 55abb0eb..0dbeecfb 100644 --- a/mmpretrain/visualization/__init__.py +++ b/mmpretrain/visualization/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .cls_visualizer import ClsVisualizer +from .utils import create_figure, get_adaptive_scale +from .visualizer import UniversalVisualizer -__all__ = ['ClsVisualizer'] +__all__ = ['UniversalVisualizer', 'get_adaptive_scale', 'create_figure'] diff --git a/mmpretrain/visualization/cls_visualizer.py b/mmpretrain/visualization/cls_visualizer.py deleted file mode 100644 index 14075f75..00000000 --- a/mmpretrain/visualization/cls_visualizer.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple - -import mmcv -import numpy as np -from mmengine.dist import master_only -from mmengine.visualization import Visualizer - -from mmpretrain.registry import VISUALIZERS -from mmpretrain.structures import DataSample - - -def _get_adaptive_scale(img_shape: Tuple[int, int], - min_scale: float = 0.3, - max_scale: float = 3.0) -> float: - """Get adaptive scale according to image shape. - - The target scale depends on the the short edge length of the image. If the - short edge length equals 224, the output is 1.0. And output linear scales - according the short edge length. - - You can also specify the minimum scale and the maximum scale to limit the - linear scale. - - Args: - img_shape (Tuple[int, int]): The shape of the canvas image. - min_size (int): The minimum scale. Defaults to 0.3. - max_size (int): The maximum scale. Defaults to 3.0. - - Returns: - int: The adaptive scale. - """ - short_edge_length = min(img_shape) - scale = short_edge_length / 224. - return min(max(scale, min_scale), max_scale) - - -@VISUALIZERS.register_module() -class ClsVisualizer(Visualizer): - """Universal Visualizer for classification task. - - Args: - name (str): Name of the instance. Defaults to 'visualizer'. - image (np.ndarray, optional): the origin image to draw. The format - should be RGB. Defaults to None. - vis_backends (list, optional): Visual backend config list. - Defaults to None. - save_dir (str, optional): Save file dir for all storage backends. - If it is None, the backend storage will not save any data. - fig_save_cfg (dict): Keyword parameters of figure for saving. - Defaults to empty dict. - fig_show_cfg (dict): Keyword parameters of figure for showing. - Defaults to empty dict. - - Examples: - >>> import torch - >>> import mmcv - >>> from pathlib import Path - >>> from mmpretrain.visualization import ClsVisualizer - >>> from mmpretrain.structures import DataSample - >>> # Example image - >>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb') - >>> # Example annotation - >>> data_sample = DataSample().set_gt_label(1).set_pred_label(1).\ - ... set_pred_score(torch.tensor([0.1, 0.8, 0.1])) - >>> # Setup the visualizer - >>> vis = ClsVisualizer( - ... save_dir="./outputs", - ... vis_backends=[dict(type='LocalVisBackend')]) - >>> # Set classes names - >>> vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']} - >>> # Show the example image with annotation in a figure. - >>> # And it will ignore all preset storage backends. - >>> vis.add_datasample('res', img, data_sample, show=True) - >>> # Save the visualization result by the specified storage backends. - >>> vis.add_datasample('res', img, data_sample) - >>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists() - >>> # Save another visualization result with the same name. - >>> vis.add_datasample('res', img, data_sample, step=1) - >>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists() - """ - - @master_only - def add_datasample(self, - name: str, - image: np.ndarray, - data_sample: Optional[DataSample] = None, - draw_gt: bool = True, - draw_pred: bool = True, - draw_score: bool = True, - rescale_factor: Optional[float] = None, - show: bool = False, - text_cfg: dict = dict(), - wait_time: float = 0, - out_file: Optional[str] = None, - step: int = 0) -> None: - """Draw datasample and save to all backends. - - - If ``out_file`` is specified, all storage backends are ignored - and save the image to the ``out_file``. - - If ``show`` is True, plot the result image in a window, please - confirm you are able to access the graphical interface. - - Args: - name (str): The image identifier. - image (np.ndarray): The image to draw. - data_sample (:obj:`DataSample`, optional): The annotation of the - image. Defaults to None. - draw_gt (bool): Whether to draw ground truth labels. - Defaults to True. - draw_pred (bool): Whether to draw prediction labels. - Defaults to True. - draw_score (bool): Whether to draw the prediction scores - of prediction categories. Defaults to True. - rescale_factor (float, optional): Rescale the image by the rescale - factor before visualization. Defaults to None. - show (bool): Whether to display the drawn image. Defaults to False. - text_cfg (dict): Extra text setting, which accepts - arguments of :attr:`mmengine.Visualizer.draw_texts`. - Defaults to an empty dict. - wait_time (float): The interval of show (s). Defaults to 0, which - means "forever". - out_file (str, optional): Extra path to save the visualization - result. If specified, the visualizer will only save the result - image to the out_file and ignore its storage backends. - Defaults to None. - step (int): Global step value to record. Defaults to 0. - """ - classes = None - if self.dataset_meta is not None: - classes = self.dataset_meta.get('classes', None) - - if rescale_factor is not None: - image = mmcv.imrescale(image, rescale_factor) - - texts = [] - self.set_image(image) - - if draw_gt and 'gt_label' in data_sample: - idx = data_sample.gt_label.tolist() - class_labels = [''] * len(idx) - if classes is not None: - class_labels = [f' ({classes[i]})' for i in idx] - labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))] - prefix = 'Ground truth: ' - texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) - - if draw_pred and 'pred_label' in data_sample: - idx = data_sample.pred_label.tolist() - score_labels = [''] * len(idx) - class_labels = [''] * len(idx) - if draw_score and 'pred_score' in data_sample: - score_labels = [ - f', {data_sample.pred_score[i].item():.2f}' for i in idx - ] - - if classes is not None: - class_labels = [f' ({classes[i]})' for i in idx] - - labels = [ - str(idx[i]) + score_labels[i] + class_labels[i] - for i in range(len(idx)) - ] - prefix = 'Prediction: ' - texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) - - img_scale = _get_adaptive_scale(image.shape[:2]) - text_cfg = { - 'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32), - 'font_sizes': int(img_scale * 7), - 'font_families': 'monospace', - 'colors': 'white', - 'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'), - **text_cfg - } - self.draw_texts('\n'.join(texts), **text_cfg) - drawn_img = self.get_image() - - if show: - self.show(drawn_img, win_name=name, wait_time=wait_time) - - if out_file is not None: - # save the image to the target file instead of vis_backends - mmcv.imwrite(drawn_img[..., ::-1], out_file) - else: - self.add_image(name, drawn_img, step=step) diff --git a/mmpretrain/visualization/utils.py b/mmpretrain/visualization/utils.py new file mode 100644 index 00000000..91a1d81f --- /dev/null +++ b/mmpretrain/visualization/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from matplotlib.figure import Figure + + +def get_adaptive_scale(img_shape: Tuple[int, int], + min_scale: float = 0.3, + max_scale: float = 3.0) -> float: + """Get adaptive scale according to image shape. + + The target scale depends on the the short edge length of the image. If the + short edge length equals 224, the output is 1.0. And output linear scales + according the short edge length. + + You can also specify the minimum scale and the maximum scale to limit the + linear scale. + + Args: + img_shape (Tuple[int, int]): The shape of the canvas image. + min_size (int): The minimum scale. Defaults to 0.3. + max_size (int): The maximum scale. Defaults to 3.0. + + Returns: + int: The adaptive scale. + """ + short_edge_length = min(img_shape) + scale = short_edge_length / 224. + return min(max(scale, min_scale), max_scale) + + +def create_figure(*args, margin=False, **kwargs) -> 'Figure': + """Create a independent figure. + + Different from the :func:`plt.figure`, the figure from this function won't + be managed by matplotlib. And it has + :obj:`matplotlib.backends.backend_agg.FigureCanvasAgg`, and therefore, you + can use the ``canvas`` attribute to get access the drawn image. + + Args: + *args: All positional arguments of :class:`matplotlib.figure.Figure`. + margin: Whether to reserve the white edges of the figure. + Defaults to False. + **kwargs: All keyword arguments of :class:`matplotlib.figure.Figure`. + + Return: + matplotlib.figure.Figure: The created figure. + """ + from matplotlib.backends.backend_agg import FigureCanvasAgg + from matplotlib.figure import Figure + + figure = Figure(*args, **kwargs) + FigureCanvasAgg(figure) + + if not margin: + # remove white edges by set subplot margin + figure.subplots_adjust(left=0, right=1, bottom=0, top=1) + + return figure diff --git a/mmpretrain/visualization/visualizer.py b/mmpretrain/visualization/visualizer.py new file mode 100644 index 00000000..f84f0a1f --- /dev/null +++ b/mmpretrain/visualization/visualizer.py @@ -0,0 +1,324 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.dataset import BaseDataset +from mmengine.dist import master_only +from mmengine.visualization import Visualizer +from mmengine.visualization.utils import img_from_canvas + +from mmpretrain.registry import VISUALIZERS +from mmpretrain.structures import DataSample +from .utils import create_figure, get_adaptive_scale + + +@VISUALIZERS.register_module() +class UniversalVisualizer(Visualizer): + """Universal Visualizer for multiple tasks. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict. + """ + DEFAULT_TEXT_CFG = { + 'family': 'monospace', + 'color': 'white', + 'bbox': dict(facecolor='black', alpha=0.5, boxstyle='Round'), + 'verticalalignment': 'top', + 'horizontalalignment': 'left', + } + + @master_only + def visualize_cls(self, + image: np.ndarray, + data_sample: DataSample, + classes: Optional[Sequence[str]] = None, + draw_gt: bool = True, + draw_pred: bool = True, + draw_score: bool = True, + resize: Optional[int] = None, + rescale_factor: Optional[float] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize image classification result. + + This method will draw an text box on the input image to visualize the + information about image classification, like the ground-truth label and + prediction label. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + classes (Sequence[str], optional): The categories names. + Defaults to None. + draw_gt (bool): Whether to draw ground-truth labels. + Defaults to True. + draw_pred (bool): Whether to draw prediction labels. + Defaults to True. + draw_score (bool): Whether to draw the prediction scores + of prediction categories. Defaults to True. + resize (int, optional): Resize the short edge of the image to the + specified length before visualization. Defaults to None. + rescale_factor (float, optional): Rescale the image by the rescale + factor before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts + arguments of :meth:`mmengine.Visualizer.draw_texts`. + Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + if self.dataset_meta is not None: + classes = classes or self.dataset_meta.get('classes', None) + + if resize is not None: + h, w = image.shape[:2] + if w < h: + image = mmcv.imresize(image, (resize, resize * h / w)) + else: + image = mmcv.imresize(image, (resize * w / h, resize)) + elif rescale_factor is not None: + image = mmcv.imrescale(image, rescale_factor) + + texts = [] + self.set_image(image) + + if draw_gt and 'gt_label' in data_sample: + idx = data_sample.gt_label.tolist() + class_labels = [''] * len(idx) + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))] + prefix = 'Ground truth: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + if draw_pred and 'pred_label' in data_sample: + idx = data_sample.pred_label.tolist() + score_labels = [''] * len(idx) + class_labels = [''] * len(idx) + if draw_score and 'pred_score' in data_sample: + score_labels = [ + f', {data_sample.pred_score[i].item():.2f}' for i in idx + ] + + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + + labels = [ + str(idx[i]) + score_labels[i] + class_labels[i] + for i in range(len(idx)) + ] + prefix = 'Prediction: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + img_scale = get_adaptive_scale(image.shape[:2]) + text_cfg = { + 'size': int(img_scale * 7), + **self.DEFAULT_TEXT_CFG, + **text_cfg, + } + self.ax_save.text( + img_scale * 5, + img_scale * 5, + '\n'.join(texts), + **text_cfg, + ) + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_image_retrieval(self, + image: np.ndarray, + data_sample: DataSample, + prototype_dataset: BaseDataset, + topk: int = 1, + draw_score: bool = True, + resize: Optional[int] = None, + text_cfg: dict = dict(), + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: Optional[str] = '', + step: int = 0) -> None: + """Visualize image retrieval result. + + This method will draw the input image and the images retrieved from the + prototype dataset. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + prototype_dataset (:obj:`BaseDataset`): The prototype dataset. + It should have `get_data_info` method and return a dict + includes `img_path`. + draw_score (bool): Whether to draw the match scores of the + retrieved images. Defaults to True. + resize (int, optional): Resize the long edge of the image to the + specified length before visualization. Defaults to None. + text_cfg (dict): Extra text setting, which accepts arguments of + :func:`plt.text`. Defaults to an empty dict. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + text_cfg = {**self.DEFAULT_TEXT_CFG, **text_cfg} + if resize is not None: + image = mmcv.imrescale(image, (resize, resize)) + + match_scores, indices = torch.topk(data_sample.pred_score, k=topk) + + figure = create_figure(margin=True) + gs = figure.add_gridspec(2, topk) + query_plot = figure.add_subplot(gs[0, :]) + query_plot.axis(False) + query_plot.imshow(image) + + for k, (score, sample_idx) in enumerate(zip(match_scores, indices)): + sample = prototype_dataset.get_data_info(sample_idx.item()) + value_image = mmcv.imread(sample['img_path'])[..., ::-1] + value_plot = figure.add_subplot(gs[1, k]) + value_plot.axis(False) + value_plot.imshow(value_image) + if draw_score: + value_plot.text( + 5, + 5, + f'{score:.2f}', + **text_cfg, + ) + drawn_img = img_from_canvas(figure.canvas) + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img + + @master_only + def visualize_masked_image(self, + image: np.ndarray, + data_sample: DataSample, + resize: Union[int, Tuple[int]] = 224, + color: Union[str, Tuple[int]] = 'black', + alpha: Union[int, float] = 0.8, + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + name: str = '', + step: int = 0) -> None: + """Visualize masked image. + + This method will draw an image with binary mask. + + Args: + image (np.ndarray): The image to draw. The format should be RGB. + data_sample (:obj:`DataSample`): The annotation of the image. + resize (int | Tuple[int]): Resize the input image to the specified + shape. Defaults to 224. + color (str | Tuple[int]): The color of the binary mask. + Defaults to "black". + alpha (int | float): The transparency of the mask. Defaults to 0.8. + show (bool): Whether to display the drawn image in a window, please + confirm your are able to access the graphical interface. + Defaults to False. + wait_time (float): The display time (s). Defaults to 0, which means + "forever". + out_file (str, optional): Extra path to save the visualization + result. If specified, the visualizer will only save the result + image to the out_file and ignore its storage backends. + Defaults to None. + name (str): The image identifier. It's useful when using the + storage backends of the visualizer to save or display the + image. Defaults to an empty string. + step (int): The global step value. It's useful to record a + series of visualization results for the same image with the + storage backends. Defaults to 0. + + Returns: + np.ndarray: The visualization image. + """ + if isinstance(resize, int): + resize = (resize, resize) + + image = mmcv.imresize(image, resize) + self.set_image(image) + + mask = data_sample.mask.float()[None, None, ...] + mask_ = F.interpolate(mask, image.shape[:2], mode='nearest')[0, 0] + + self.draw_binary_masks(mask_.bool(), colors=color, alphas=alpha) + + drawn_img = self.get_image() + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + # save the image to the target file instead of vis_backends + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + self.add_image(name, drawn_img, step=step) + + return drawn_img diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py index d95c716e..1e9d7cbc 100644 --- a/tests/test_apis/test_inference.py +++ b/tests/test_apis/test_inference.py @@ -10,7 +10,7 @@ from mmpretrain.apis import (ImageClassificationInferencer, ModelHub, get_model, inference_model) from mmpretrain.models import MobileNetV3 from mmpretrain.structures import DataSample -from mmpretrain.visualization import ClsVisualizer +from mmpretrain.visualization import UniversalVisualizer MODEL = 'mobilenet-v3-small-050_3rdparty_in1k' WEIGHT = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/mobilenet-v3-small-050_3rdparty_in1k_20221114-e0b86be1.pth' # noqa: E501 @@ -69,22 +69,25 @@ class TestImageClassificationInferencer(TestCase): with TemporaryDirectory() as tmpdir: inferencer(img, show_dir=tmpdir) - self.assertIsInstance(inferencer.visualizer, ClsVisualizer) + self.assertIsInstance(inferencer.visualizer, UniversalVisualizer) self.assertTrue(osp.exists(osp.join(tmpdir, '0.png'))) inferencer.visualizer = MagicMock(wraps=inferencer.visualizer) inferencer( img_path, rescale_factor=2., draw_score=False, show_dir=tmpdir) self.assertTrue(osp.exists(osp.join(tmpdir, 'color.png'))) - inferencer.visualizer.add_datasample.assert_called_once_with( - 'color', + inferencer.visualizer.visualize_cls.assert_called_once_with( ANY, ANY, + classes=inferencer.classes, + resize=None, show=False, + wait_time=0, rescale_factor=2., draw_gt=False, draw_pred=True, draw_score=False, + name='color', out_file=osp.join(tmpdir, 'color.png')) diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index e68000a7..2fe0ae30 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -10,13 +10,13 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop from mmpretrain.engine import VisualizationHook from mmpretrain.registry import HOOKS from mmpretrain.structures import DataSample -from mmpretrain.visualization import ClsVisualizer +from mmpretrain.visualization import UniversalVisualizer class TestVisualizationHook(TestCase): def setUp(self) -> None: - ClsVisualizer.get_instance('visualizer') + UniversalVisualizer.get_instance('visualizer') data_sample = DataSample().set_gt_label(1).set_pred_label(2) data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'}) @@ -33,62 +33,68 @@ class TestVisualizationHook(TestCase): # test enable=False cfg = dict(type='VisualizationHook', enable=False) hook: VisualizationHook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook._draw_samples(1, self.data_batch, self.outputs, step=1) mock.assert_not_called() # test enable=True cfg = dict(type='VisualizationHook', enable=True, show=True) hook: VisualizationHook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook._draw_samples(0, self.data_batch, self.outputs, step=0) mock.assert_called_once_with( - 'color.jpg', image=ANY, data_sample=self.outputs[0], step=0, - show=True) + show=True, + name='color.jpg') # test samples without path cfg = dict(type='VisualizationHook', enable=True) hook: VisualizationHook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: outputs = [DataSample()] * 10 hook._draw_samples(0, self.data_batch, outputs, step=0) mock.assert_called_once_with( - '0', image=ANY, data_sample=outputs[0], step=0, show=False) + image=ANY, + data_sample=outputs[0], + step=0, + show=False, + name='0') # test out_dir cfg = dict( type='VisualizationHook', enable=True, out_dir=self.tmpdir.name) hook: VisualizationHook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook._draw_samples(0, self.data_batch, self.outputs, step=0) mock.assert_called_once_with( - 'color.jpg', image=ANY, data_sample=self.outputs[0], step=0, show=False, + name='color.jpg', out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png')) # test sample idx cfg = dict(type='VisualizationHook', enable=True, interval=4) hook: VisualizationHook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook._draw_samples(1, self.data_batch, self.outputs, step=0) mock.assert_called_with( - 'color.jpg', image=ANY, data_sample=self.outputs[2], step=0, - show=False) + show=False, + name='color.jpg', + ) mock.assert_called_with( - 'color.jpg', image=ANY, data_sample=self.outputs[6], step=0, - show=False) + show=False, + name='color.jpg', + ) def test_after_val_iter(self): runner = MagicMock() @@ -98,42 +104,45 @@ class TestVisualizationHook(TestCase): runner.epoch = 5 cfg = dict(type='VisualizationHook', enable=True) hook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook.after_val_iter(runner, 0, self.data_batch, self.outputs) mock.assert_called_once_with( - 'color.jpg', image=ANY, data_sample=self.outputs[0], step=5, - show=False) + show=False, + name='color.jpg', + ) # test iter-based runner.train_loop = MagicMock(spec=IterBasedTrainLoop) runner.iter = 300 cfg = dict(type='VisualizationHook', enable=True) hook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook.after_val_iter(runner, 0, self.data_batch, self.outputs) mock.assert_called_once_with( - 'color.jpg', image=ANY, data_sample=self.outputs[0], step=300, - show=False) + show=False, + name='color.jpg', + ) def test_after_test_iter(self): runner = MagicMock() cfg = dict(type='VisualizationHook', enable=True) hook = HOOKS.build(cfg) - with patch.object(hook._visualizer, 'add_datasample') as mock: + with patch.object(hook._visualizer, 'visualize_cls') as mock: hook.after_test_iter(runner, 0, self.data_batch, self.outputs) mock.assert_called_once_with( - 'color.jpg', image=ANY, data_sample=self.outputs[0], step=0, - show=False) + show=False, + name='color.jpg', + ) def tearDown(self) -> None: self.tmpdir.cleanup() diff --git a/tests/test_visualization/test_visualizer.py b/tests/test_visualization/test_visualizer.py new file mode 100644 index 00000000..900e495c --- /dev/null +++ b/tests/test_visualization/test_visualizer.py @@ -0,0 +1,200 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import os.path as osp +import tempfile +from unittest import TestCase +from unittest.mock import patch + +import numpy as np +import torch + +from mmpretrain.structures import DataSample +from mmpretrain.visualization import UniversalVisualizer + + +class TestUniversalVisualizer(TestCase): + + def setUp(self) -> None: + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.tmpdir = tmpdir + self.vis = UniversalVisualizer( + save_dir=tmpdir.name, + vis_backends=[dict(type='LocalVisBackend')], + ) + + def test_visualize_cls(self): + image = np.ones((10, 10, 3), np.uint8) + data_sample = DataSample().set_gt_label(1).set_pred_label(1).\ + set_pred_score(torch.tensor([0.1, 0.8, 0.1])) + + # Test show + def mock_show(drawn_img, win_name, wait_time): + self.assertFalse((image == drawn_img).all()) + self.assertEqual(win_name, 'test_cls') + self.assertEqual(wait_time, 0) + + with patch.object(self.vis, 'show', mock_show): + self.vis.visualize_cls( + image=image, + data_sample=data_sample, + show=True, + name='test_cls', + step=1) + + # Test storage backend. + save_file = osp.join(self.tmpdir.name, + 'vis_data/vis_image/test_cls_1.png') + self.assertTrue(osp.exists(save_file)) + + # Test out_file + out_file = osp.join(self.tmpdir.name, 'results.png') + self.vis.visualize_cls( + image=image, data_sample=data_sample, out_file=out_file) + self.assertTrue(osp.exists(out_file)) + + # Test with dataset_meta + self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']} + + def patch_texts(text, *_, **__): + self.assertEqual( + text, '\n'.join([ + 'Ground truth: 1 (bird)', + 'Prediction: 1, 0.80 (bird)', + ])) + + with patch.object(self.vis, 'draw_texts', patch_texts): + self.vis.visualize_cls(image, data_sample) + + # Test without pred_label + def patch_texts(text, *_, **__): + self.assertEqual(text, '\n'.join([ + 'Ground truth: 1 (bird)', + ])) + + with patch.object(self.vis, 'draw_texts', patch_texts): + self.vis.visualize_cls(image, data_sample, draw_pred=False) + + # Test without gt_label + def patch_texts(text, *_, **__): + self.assertEqual(text, '\n'.join([ + 'Prediction: 1, 0.80 (bird)', + ])) + + with patch.object(self.vis, 'draw_texts', patch_texts): + self.vis.visualize_cls(image, data_sample, draw_gt=False) + + # Test without score + del data_sample.pred_score + + def patch_texts(text, *_, **__): + self.assertEqual( + text, '\n'.join([ + 'Ground truth: 1 (bird)', + 'Prediction: 1 (bird)', + ])) + + with patch.object(self.vis, 'draw_texts', patch_texts): + self.vis.visualize_cls(image, data_sample) + + # Test adaptive font size + def assert_font_size(target_size): + + def draw_texts(text, font_sizes, *_, **__): + self.assertEqual(font_sizes, target_size) + + return draw_texts + + with patch.object(self.vis, 'draw_texts', assert_font_size(7)): + self.vis.visualize_cls( + np.ones((224, 384, 3), np.uint8), data_sample) + + with patch.object(self.vis, 'draw_texts', assert_font_size(2)): + self.vis.visualize_cls( + np.ones((10, 384, 3), np.uint8), data_sample) + + with patch.object(self.vis, 'draw_texts', assert_font_size(21)): + self.vis.visualize_cls( + np.ones((1000, 1000, 3), np.uint8), data_sample) + + # Test rescale image + with patch.object(self.vis, 'draw_texts', assert_font_size(14)): + self.vis.visualize_cls( + np.ones((224, 384, 3), np.uint8), + data_sample, + rescale_factor=2.) + + def test_visualize_image_retrieval(self): + image = np.ones((10, 10, 3), np.uint8) + data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1]) + + class ToyPrototype: + + def get_data_info(self, idx): + img_path = osp.join(osp.dirname(__file__), '../data/color.jpg') + return {'img_path': img_path, 'sample_idx': idx} + + prototype_dataset = ToyPrototype() + + # Test show + def mock_show(drawn_img, win_name, wait_time): + if image.shape == drawn_img.shape: + self.assertFalse((image == drawn_img).all()) + self.assertEqual(win_name, 'test_retrieval') + self.assertEqual(wait_time, 0) + + with patch.object(self.vis, 'show', mock_show): + self.vis.visualize_image_retrieval( + image, + data_sample, + prototype_dataset, + show=True, + name='test_retrieval', + step=1) + + # Test storage backend. + save_file = osp.join(self.tmpdir.name, + 'vis_data/vis_image/test_retrieval_1.png') + self.assertTrue(osp.exists(save_file)) + + # Test out_file + out_file = osp.join(self.tmpdir.name, 'results.png') + self.vis.visualize_image_retrieval( + image, + data_sample, + prototype_dataset, + out_file=out_file, + ) + self.assertTrue(osp.exists(out_file)) + + def test_visualize_masked_image(self): + image = np.ones((10, 10, 3), np.uint8) + data_sample = DataSample().set_mask( + torch.tensor([ + [0, 0, 1, 1], + [0, 1, 1, 0], + [1, 1, 0, 0], + [1, 0, 0, 1], + ])) + + # Test show + def mock_show(drawn_img, win_name, wait_time): + self.assertTupleEqual(drawn_img.shape, (224, 224, 3)) + self.assertEqual(win_name, 'test_mask') + self.assertEqual(wait_time, 0) + + with patch.object(self.vis, 'show', mock_show): + self.vis.visualize_masked_image( + image, data_sample, show=True, name='test_mask', step=1) + + # Test storage backend. + save_file = osp.join(self.tmpdir.name, + 'vis_data/vis_image/test_mask_1.png') + self.assertTrue(osp.exists(save_file)) + + # Test out_file + out_file = osp.join(self.tmpdir.name, 'results.png') + self.vis.visualize_masked_image(image, data_sample, out_file=out_file) + self.assertTrue(osp.exists(out_file)) + + def tearDown(self): + self.tmpdir.cleanup() diff --git a/tests/test_visualizations/test_visualizer.py b/tests/test_visualizations/test_visualizer.py deleted file mode 100644 index f454a0ce..00000000 --- a/tests/test_visualizations/test_visualizer.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Open-MMLab. All rights reserved. -import os.path as osp -import tempfile -from unittest import TestCase -from unittest.mock import patch - -import numpy as np -import torch - -from mmpretrain.structures import DataSample -from mmpretrain.visualization import ClsVisualizer - - -class TestClsVisualizer(TestCase): - - def setUp(self) -> None: - super().setUp() - tmpdir = tempfile.TemporaryDirectory() - self.tmpdir = tmpdir - self.vis = ClsVisualizer( - save_dir=tmpdir.name, - vis_backends=[dict(type='LocalVisBackend')], - ) - - def test_add_datasample(self): - image = np.ones((10, 10, 3), np.uint8) - data_sample = DataSample().set_gt_label(1).set_pred_label(1).\ - set_pred_score(torch.tensor([0.1, 0.8, 0.1])) - - # Test show - def mock_show(drawn_img, win_name, wait_time): - self.assertFalse((image == drawn_img).all()) - self.assertEqual(win_name, 'test') - self.assertEqual(wait_time, 0) - - with patch.object(self.vis, 'show', mock_show): - self.vis.add_datasample( - 'test', image=image, data_sample=data_sample, show=True) - - # Test out_file - out_file = osp.join(self.tmpdir.name, 'results.png') - self.vis.add_datasample( - 'test', image=image, data_sample=data_sample, out_file=out_file) - self.assertTrue(osp.exists(out_file)) - - # Test storage backend. - save_file = osp.join(self.tmpdir.name, 'vis_data/vis_image/test_0.png') - self.assertTrue(osp.exists(save_file)) - - # Test with dataset_meta - self.vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']} - - def test_texts(text, *_, **__): - self.assertEqual( - text, '\n'.join([ - 'Ground truth: 1 (bird)', - 'Prediction: 1, 0.80 (bird)', - ])) - - with patch.object(self.vis, 'draw_texts', test_texts): - self.vis.add_datasample( - 'test', image=image, data_sample=data_sample) - - # Test without pred_label - def test_texts(text, *_, **__): - self.assertEqual(text, '\n'.join([ - 'Ground truth: 1 (bird)', - ])) - - with patch.object(self.vis, 'draw_texts', test_texts): - self.vis.add_datasample( - 'test', image=image, data_sample=data_sample, draw_pred=False) - - # Test without gt_label - def test_texts(text, *_, **__): - self.assertEqual(text, '\n'.join([ - 'Prediction: 1, 0.80 (bird)', - ])) - - with patch.object(self.vis, 'draw_texts', test_texts): - self.vis.add_datasample( - 'test', image=image, data_sample=data_sample, draw_gt=False) - - # Test without score - del data_sample.pred_score - - def test_texts(text, *_, **__): - self.assertEqual( - text, '\n'.join([ - 'Ground truth: 1 (bird)', - 'Prediction: 1 (bird)', - ])) - - with patch.object(self.vis, 'draw_texts', test_texts): - self.vis.add_datasample( - 'test', image=image, data_sample=data_sample) - - # Test adaptive font size - def assert_font_size(target_size): - - def draw_texts(text, font_sizes, *_, **__): - self.assertEqual(font_sizes, target_size) - - return draw_texts - - with patch.object(self.vis, 'draw_texts', assert_font_size(7)): - self.vis.add_datasample( - 'test', - image=np.ones((224, 384, 3), np.uint8), - data_sample=data_sample) - - with patch.object(self.vis, 'draw_texts', assert_font_size(2)): - self.vis.add_datasample( - 'test', - image=np.ones((10, 384, 3), np.uint8), - data_sample=data_sample) - - with patch.object(self.vis, 'draw_texts', assert_font_size(21)): - self.vis.add_datasample( - 'test', - image=np.ones((1000, 1000, 3), np.uint8), - data_sample=data_sample) - - # Test rescale image - with patch.object(self.vis, 'draw_texts', assert_font_size(14)): - self.vis.add_datasample( - 'test', - image=np.ones((224, 384, 3), np.uint8), - rescale_factor=2., - data_sample=data_sample) - - def tearDown(self): - self.tmpdir.cleanup() diff --git a/tools/analysis_tools/analyze_results.py b/tools/analysis_tools/analyze_results.py index b4837ece..a1019a74 100644 --- a/tools/analysis_tools/analyze_results.py +++ b/tools/analysis_tools/analyze_results.py @@ -10,7 +10,7 @@ from mmengine import DictAction from mmpretrain.datasets import build_dataset from mmpretrain.structures import ClsDataSample -from mmpretrain.visualization import ClsVisualizer +from mmpretrain.visualization import UniversalVisualizer def parse_args(): @@ -47,7 +47,7 @@ def parse_args(): def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None): full_dir = osp.join(result_dir, folder_name) - vis = ClsVisualizer() + vis = UniversalVisualizer() vis.dataset_meta = {'classes': dataset.CLASSES} # save imgs @@ -67,8 +67,8 @@ def save_imgs(result_dir, folder_name, results, dataset, rescale_factor=None): raise ValueError('Cannot load images from the dataset infos.') if rescale_factor is not None: img = mmcv.imrescale(img, rescale_factor) - vis.add_datasample( - name, img, data_sample, out_file=osp.join(full_dir, name + '.png')) + vis.visualize_cls( + img, data_sample, out_file=osp.join(full_dir, name + '.png')) for k, v in result.items(): if isinstance(v, torch.Tensor): diff --git a/tools/model_converters/revvit_to_mmpretrain.py b/tools/model_converters/revvit_to_mmpretrain.py index b3908d31..ec9bc0b4 100644 --- a/tools/model_converters/revvit_to_mmpretrain.py +++ b/tools/model_converters/revvit_to_mmpretrain.py @@ -69,12 +69,6 @@ def convert_revvit(ckpt): else: final_ckpt[k] = v - # add pos embed for cls token - if k == 'backbone.pos_embed': - v = torch.cat([torch.ones_like(v).mean(dim=1, keepdim=True), v], - dim=1) - final_ckpt[k] = v - return final_ckpt diff --git a/tools/visualization/browse_dataset.py b/tools/visualization/browse_dataset.py index b2611d5f..97945939 100644 --- a/tools/visualization/browse_dataset.py +++ b/tools/visualization/browse_dataset.py @@ -3,18 +3,15 @@ import argparse import os.path as osp import sys -import cv2 import mmcv -import numpy as np from mmengine.config import Config, DictAction from mmengine.dataset import Compose from mmengine.registry import init_default_scope from mmengine.utils import ProgressBar -from mmengine.visualization import Visualizer +from mmengine.visualization.utils import img_from_canvas from mmpretrain.datasets.builder import build_dataset -from mmpretrain.visualization import ClsVisualizer -from mmpretrain.visualization.cls_visualizer import _get_adaptive_scale +from mmpretrain.visualization import UniversalVisualizer, create_figure def parse_args(): @@ -90,49 +87,20 @@ def parse_args(): def make_grid(imgs, names, rescale_factor=None): """Concat list of pictures into a single big picture, align height here.""" - vis = Visualizer() + figure = create_figure() + gs = figure.add_gridspec(1, len(imgs)) ori_shapes = [img.shape[:2] for img in imgs] if rescale_factor is not None: imgs = [mmcv.imrescale(img, rescale_factor) for img in imgs] - max_height = int(max(img.shape[0] for img in imgs) * 1.1) - min_width = min(img.shape[1] for img in imgs) - horizontal_gap = min_width // 10 - img_scale = _get_adaptive_scale((max_height, min_width)) - - texts = [] - text_positions = [] - start_x = 0 for i, img in enumerate(imgs): - pad_height = (max_height - img.shape[0]) // 2 - pad_width = horizontal_gap // 2 - # make border - imgs[i] = cv2.copyMakeBorder( - img, - pad_height, - max_height - img.shape[0] - pad_height + int(img_scale * 30 * 2), - pad_width, - pad_width, - cv2.BORDER_CONSTANT, - value=(255, 255, 255)) + subplot = figure.add_subplot(gs[0, i]) + subplot.axis(False) + subplot.imshow(img) + subplot.set_title(f'{names[i]}\n{ori_shapes[i]}') - texts.append(f'{names[i]}\n{ori_shapes[i]}') - text_positions.append( - [start_x + img.shape[1] // 2 + pad_width, max_height]) - start_x += img.shape[1] + horizontal_gap - - display_img = np.concatenate(imgs, axis=1) - vis.set_image(display_img) - img_scale = _get_adaptive_scale(display_img.shape[:2]) - vis.draw_texts( - texts, - positions=np.array(text_positions), - font_sizes=img_scale * 7, - colors='black', - horizontal_alignments='center', - font_families='monospace') - return vis.get_image() + return img_from_canvas(figure.canvas) class InspectCompose(Compose): @@ -148,7 +116,7 @@ class InspectCompose(Compose): def __call__(self, data): if 'img' in data: self.intermediate_imgs.append({ - 'name': 'original', + 'name': 'Original', 'img': data['img'].copy() }) @@ -181,7 +149,7 @@ def main(): # init visualizer cfg.visualizer.pop('type') - visualizer = ClsVisualizer(**cfg.visualizer) + visualizer = UniversalVisualizer(**cfg.visualizer) visualizer.dataset_meta = dataset.metainfo # init visualization image number @@ -220,13 +188,13 @@ def main(): out_file = osp.join(args.output_dir, filename) if args.output_dir is not None else None - visualizer.add_datasample( - filename, + visualizer.visualize_cls( image if args.channel_order == 'RGB' else image[..., ::-1], data_sample, rescale_factor=rescale_factor, show=not args.not_show, wait_time=args.show_interval, + name=filename, out_file=out_file) progress_bar.update() diff --git a/tools/visualization/vis_tsne.py b/tools/visualization/vis_tsne.py index 66f2d242..88661b8e 100644 --- a/tools/visualization/vis_tsne.py +++ b/tools/visualization/vis_tsne.py @@ -2,24 +2,27 @@ import argparse import os.path as osp import time -from functools import partial -from typing import Optional +from collections import defaultdict import matplotlib.pyplot as plt -import mmengine import numpy as np +import rich.progress as progress import torch import torch.nn.functional as F from mmengine.config import Config, DictAction -from mmengine.dataset import default_collate, worker_init_fn -from mmengine.dist import get_rank +from mmengine.device import get_device from mmengine.logging import MMLogger +from mmengine.runner import Runner from mmengine.utils import mkdir_or_exist -from sklearn.manifold import TSNE -from torch.utils.data import DataLoader from mmpretrain.apis import get_model -from mmpretrain.registry import DATA_SAMPLERS, DATASETS +from mmpretrain.registry import DATASETS + +try: + from sklearn.manifold import TSNE +except ImportError as e: + raise ImportError('Please install `sklearn` to calculate ' + 'TSNE by `pip install scikit-learn`') from e def parse_args(): @@ -27,22 +30,30 @@ def parse_args(): parser.add_argument('config', help='tsne config file path') parser.add_argument('--checkpoint', default=None, help='checkpoint file') parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--test-cfg', + help='tsne config file path to load config of test dataloader.') parser.add_argument( '--vis-stage', choices=['backbone', 'neck', 'pre_logits'], - default='backbone', - help='the visualization stage of the model') + help='The visualization stage of the model') + parser.add_argument( + '--class-idx', + nargs='+', + type=int, + help='The categories used to calculate t-SNE.') parser.add_argument( '--max-num-class', type=int, default=20, - help='the maximum number of classes to apply t-SNE algorithms, now the' - 'function supports maximum 20 classes') - parser.add_argument('--seed', type=int, default=0, help='random seed') + help='The first N categories to apply t-SNE algorithms. ' + 'Defaults to 20.') parser.add_argument( - '--deterministic', - action='store_true', - help='whether to set deterministic options for CUDNN backend.') + '--max-num-samples', + type=int, + default=100, + help='The maximum number of samples per category. ' + 'Higher number need longer time to calculate. Defaults to 100.') parser.add_argument( '--cfg-options', nargs='+', @@ -53,8 +64,15 @@ def parse_args(): 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') + parser.add_argument('--device', help='Device used for inference') parser.add_argument( - '--device', default='cuda:0', help='Device used for inference') + '--legend', + action='store_true', + help='Show the legend of all categories.') + parser.add_argument( + '--show', + action='store_true', + help='Display the result in a graphical window.') # t-SNE settings parser.add_argument( @@ -98,19 +116,12 @@ def parse_args(): return args -def post_process(): - pass - - def main(): args = parse_args() cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) - # set cudnn_benchmark - if cfg.get('cudnn_benchmark', False): - torch.backends.cudnn.benchmark = True # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None @@ -135,82 +146,76 @@ def main(): log_level=cfg.log_level) # build the model from a config file and a checkpoint file - model = get_model(cfg, args.checkpoint, device=args.device) - logger.info(f'Model loaded and the output indices of backbone is ' - f'{model.backbone.out_indices}.') + device = args.device or get_device() + model = get_model(cfg, args.checkpoint, device=device) + logger.info('Model loaded.') # build the dataset - tsne_dataloader_cfg = cfg.get('test_dataloader') - tsne_dataset_cfg = tsne_dataloader_cfg.pop('dataset') - if isinstance(tsne_dataset_cfg, dict): - dataset = DATASETS.build(tsne_dataset_cfg) - if hasattr(dataset, 'full_init'): - dataset.full_init() + if args.test_cfg is not None: + dataloader_cfg = Config.fromfile(args.test_cfg).get('test_dataloader') + elif 'test_dataloader' not in cfg: + raise ValueError('No `test_dataloader` in the config, you can ' + 'specify another config file that includes test ' + 'dataloader settings by the `--test-cfg` option.') + else: + dataloader_cfg = cfg.get('test_dataloader') + + dataset = DATASETS.build(dataloader_cfg.pop('dataset')) + classes = dataset.metainfo.get('classes') + + if args.class_idx is None: + num_classes = args.max_num_class if classes is None else len(classes) + args.class_idx = list(range(num_classes))[:args.max_num_class] + + if classes is not None: + classes = [classes[idx] for idx in args.class_idx] + else: + classes = args.class_idx # compress dataset, select that the label is less then max_num_class subset_idx_list = [] + counter = defaultdict(int) for i in range(len(dataset)): - if dataset.get_data_info(i)['gt_label'] < args.max_num_class: + gt_label = dataset.get_data_info(i)['gt_label'] + if (gt_label in args.class_idx + and counter[gt_label] < args.max_num_samples): subset_idx_list.append(i) + counter[gt_label] += 1 dataset.get_subset_(subset_idx_list) logger.info(f'Apply t-SNE to visualize {len(subset_idx_list)} samples.') - # build sampler - sampler_cfg = tsne_dataloader_cfg.pop('sampler') - if isinstance(sampler_cfg, dict): - sampler = DATA_SAMPLERS.build( - sampler_cfg, default_args=dict(dataset=dataset, seed=args.seed)) - - # build dataloader - init_fn: Optional[partial] - if args.seed is not None: - init_fn = partial( - worker_init_fn, - num_workers=tsne_dataloader_cfg.get('num_workers'), - rank=get_rank(), - seed=args.seed) - else: - init_fn = None - - tsne_dataloader = DataLoader( - dataset=dataset, - sampler=sampler, - collate_fn=default_collate, - worker_init_fn=init_fn, - **tsne_dataloader_cfg) + dataloader_cfg.dataset = dataset + dataloader_cfg.setdefault('collate_fn', dict(type='default_collate')) + dataloader = Runner.build_dataloader(dataloader_cfg) results = dict() features = [] labels = [] - progress_bar = mmengine.ProgressBar(len(tsne_dataloader)) - for _, data in enumerate(tsne_dataloader): + for data in progress.track(dataloader, description='Calculating...'): with torch.no_grad(): # preprocess data data = model.data_preprocessor(data) batch_inputs, batch_data_samples = \ data['inputs'], data['data_samples'] + batch_labels = torch.cat([i.gt_label for i in batch_data_samples]) # extract backbone features - batch_features = model.extract_feat( - batch_inputs, stage=args.vis_stage) + extract_args = {} + if args.vis_stage: + extract_args['stage'] = args.vis_stage + batch_features = model.extract_feat(batch_inputs, **extract_args) # post process - if args.vis_stage == 'backbone': - if getattr(model.backbone, 'output_cls_token', False) is False: - batch_features = [ - F.adaptive_avg_pool2d(inputs, 1).squeeze() - for inputs in batch_features - ] - else: - # output_cls_token is True, here t-SNE uses cls_token - batch_features = [feat[-1] for feat in batch_features] - - batch_labels = torch.cat([i.gt_label for i in batch_data_samples]) + if batch_features[0].ndim == 4: + # For (N, C, H, W) feature + batch_features = [ + F.adaptive_avg_pool2d(inputs, 1).squeeze() + for inputs in batch_features + ] # save batch features features.append(batch_features) labels.extend(batch_labels.cpu().numpy()) - progress_bar.update() for i in range(len(features[0])): key = 'feat_' + str(model.backbone.out_indices[i]) @@ -238,15 +243,20 @@ def main(): result = tsne_model.fit_transform(val) res_min, res_max = result.min(0), result.max(0) res_norm = (result - res_min) / (res_max - res_min) - plt.figure(figsize=(10, 10)) - plt.scatter( + _, ax = plt.subplots(figsize=(10, 10)) + scatter = ax.scatter( res_norm[:, 0], res_norm[:, 1], alpha=1.0, s=15, c=labels, cmap='tab20') + if args.legend: + legend = ax.legend(scatter.legend_elements()[0], classes) + ax.add_artist(legend) plt.savefig(f'{tsne_work_dir}{key}.png') + if args.show: + plt.show() logger.info(f'Save features and results to {tsne_work_dir}')