[Feature] Support pixel reconstruction visualization (#570)

* refactor reconstruction visualization

* support simmim visualization

* fix reconstruction bug of MAE

* support visualization of MaskFeat

* refactor mae visualization demo

* add unit test

* fix lint and ut

* update

* add docs

* set random seed

* update

* update docstring

* add torch version check

* update

* rename

* update version

* update

* fix lint

* add docstring

* update docs
Yixiao Fang 2022-12-06 19:45:01 +08:00 committed by GitHub
parent d73c953804
commit 73cd764b5f
18 changed files with 693 additions and 350 deletions

View File

@ -37,3 +37,19 @@ train_dataloader = dict(
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
# for visualization
vis_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(224, 224), backend='pillow'),
dict(
type='BEiTMaskGenerator',
input_size=14,
num_masking_patches=78,
min_num_patches=15,
),
dict(
type='PackSelfSupInputs',
algorithm_keys=['mask'],
meta_keys=['img_path'])
]
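This `vis_pipeline` is consumed by the reconstruction visualization tool added in this commit when `--use-vis-pipeline` is passed; a hedged usage sketch (config, checkpoint, and image paths are placeholders). For algorithms with a `target_generator` (e.g. MaskFeat's HOG), `--target-generator` can also be passed to visualize the generated target instead of the RGB image.

```shell
python tools/analysis_tools/visualize_reconstruction.py ${CONFIG_FILE} \
    --checkpoint ${CKPT_PATH} \
    --img-path ${IMAGE_PATH} \
    --out-file ${OUTPUT_PATH} \
    --use-vis-pipeline
```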

View File

@ -34,3 +34,19 @@ train_dataloader = dict(
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
# for visualization
vis_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(192, 192), backend='pillow'),
dict(
type='SimMIMMaskGenerator',
input_size=192,
mask_patch_size=32,
model_patch_size=4,
mask_ratio=0.6),
dict(
type='PackSelfSupInputs',
algorithm_keys=['mask'],
meta_keys=['img_path'])
]
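Since SimMIM generates its mask inside the data pipeline, the reconstruction visualization tool should be run with `--use-vis-pipeline` so that this `vis_pipeline` is applied; a hedged sketch (checkpoint and image paths are placeholders):

```shell
python tools/analysis_tools/visualize_reconstruction.py ${SIMMIM_CONFIG_FILE} \
    --checkpoint ${CKPT_PATH} \
    --img-path ${IMAGE_PATH} \
    --out-file test_simmim.jpg \
    --use-vis-pipeline
```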

View File

@ -6,8 +6,6 @@
"source": [
"Copyright (c) OpenMMLab. All rights reserved.\n",
"\n",
"Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"\n",
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
"\n",
"## Masked Autoencoders: Visualization Demo\n",
@ -36,7 +34,8 @@
" print('Running in Colab.')\n",
" !pip3 install openmim\n",
" !pip install -U openmim\n",
" !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
" !mim install mmengine\n",
" !mim install 'mmcv>=2.0.0rc1'\n",
"\n",
" !git clone https://github.com/open-mmlab/mmselfsup.git\n",
" %cd mmselfsup/\n",
@ -51,18 +50,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from argparse import ArgumentParser\n",
"from typing import Tuple, Optional\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from mmengine.dataset import Compose, default_collate\n",
"\n",
"from mmselfsup.apis import inference_model\n",
"from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
"from mmselfsup.registry import MODELS\n",
"from mmselfsup.apis import inference_model, init_model\n",
"from mmselfsup.utils import register_all_modules"
]
},
@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@ -84,49 +84,81 @@
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"def show_image(image, title=''):\n",
"\n",
"def show_image(img: torch.Tensor, title: str = '') -> None:\n",
" # image is [H, W, 3]\n",
" assert image.shape[2] == 3\n",
" image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
" plt.imshow(image)\n",
" assert img.shape[2] == 3\n",
"\n",
" plt.imshow(img)\n",
" plt.title(title, fontsize=16)\n",
" plt.axis('off')\n",
" return\n",
"\n",
"\n",
"def show_images(x, im_masked, y, im_paste):\n",
"def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n",
" pred_img: torch.Tensor, img_paste: torch.Tensor,\n",
" out_file: Optional[str] =None) -> None:\n",
" # make the plt figure larger\n",
" plt.rcParams['figure.figsize'] = [24, 6]\n",
"\n",
" plt.subplot(1, 4, 1)\n",
" show_image(x, \"original\")\n",
" show_image(original_img, 'original')\n",
"\n",
" plt.subplot(1, 4, 2)\n",
" show_image(im_masked, \"masked\")\n",
" show_image(img_masked, 'masked')\n",
"\n",
" plt.subplot(1, 4, 3)\n",
" show_image(y, \"reconstruction\")\n",
" show_image(pred_img, 'reconstruction')\n",
"\n",
" plt.subplot(1, 4, 4)\n",
" show_image(im_paste, \"reconstruction + visible\")\n",
" show_image(img_paste, 'reconstruction + visible')\n",
"\n",
" plt.show()\n",
" if out_file is None:\n",
" plt.show()\n",
" else:\n",
" plt.savefig(out_file)\n",
" print(f'Images are saved to {out_file}')\n",
"\n",
"\n",
"def post_process(x, y, mask):\n",
" x = torch.einsum('nchw->nhwc', x.cpu())\n",
"def recover_norm(img: torch.Tensor,\n",
" mean: np.ndarray = imagenet_mean,\n",
" std: np.ndarray = imagenet_std):\n",
" if mean is not None and std is not None:\n",
" img = torch.clip((img * std + mean) * 255, 0, 255).int()\n",
" return img\n",
"\n",
"\n",
"def post_process(\n",
" original_img: torch.Tensor,\n",
" pred_img: torch.Tensor,\n",
" mask: torch.Tensor,\n",
" mean: np.ndarray = imagenet_mean,\n",
" std: np.ndarray = imagenet_std\n",
") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
" # channel conversion\n",
" original_img = torch.einsum('nchw->nhwc', original_img.cpu())\n",
" # masked image\n",
" im_masked = x * (1 - mask)\n",
" # MAE reconstruction pasted with visible patches\n",
" im_paste = x * (1 - mask) + y * mask\n",
" return x[0], im_masked[0], y[0], im_paste[0]"
" img_masked = original_img * (1 - mask)\n",
" # reconstructed image pasted with visible patches\n",
" img_paste = original_img * (1 - mask) + pred_img * mask\n",
"\n",
" # muptiply std and add mean to each image\n",
" original_img = recover_norm(original_img[0])\n",
" img_masked = recover_norm(img_masked[0])\n",
"\n",
" pred_img = recover_norm(pred_img[0])\n",
" img_paste = recover_norm(img_paste[0])\n",
"\n",
" return original_img, img_masked, pred_img, img_paste\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare config file"
"### Load a pre-trained MAE model\n",
"\n",
"This is an MAE model trained with config 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k.py'.\n"
]
},
{
@ -138,55 +170,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
"--2022-11-08 11:00:50-- https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n",
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
"长度: 1355429265 (1.3G) [application/octet-stream]\n",
"正在保存至: “mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth”\n",
"\n",
"e_in1k_20220825-cc7 99%[==================> ] 1.26G 913KB/s 剩余 0s s "
]
}
],
"source": [
"%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
"model = dict(\n",
" type='MAE',\n",
" data_preprocessor=dict(\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" bgr_to_rgb=True),\n",
" backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
" neck=dict(\n",
" type='MAEPretrainDecoder',\n",
" patch_size=16,\n",
" in_chans=3,\n",
" embed_dim=1024,\n",
" decoder_embed_dim=512,\n",
" decoder_depth=8,\n",
" decoder_num_heads=16,\n",
" mlp_ratio=4.,\n",
" ),\n",
" head=dict(\n",
" type='MAEPretrainHead',\n",
" norm_pix=True,\n",
" patch_size=16,\n",
" loss=dict(type='MAEReconstructionLoss')),\n",
" init_cfg=[\n",
" dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
" dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
" ])\n",
"\n",
"file_client_args = dict(backend='disk')\n",
"\n",
"# dataset summary\n",
"test_dataloader = dict(\n",
" dataset=dict(pipeline=[\n",
" dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
" dict(type='Resize', scale=(224, 224)),\n",
" dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
" ]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load a pre-trained MAE model"
"# download checkpoint if not exist\n",
"!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth"
]
},
{
@ -198,46 +195,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-09-03 00:34:55-- https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... 已连接。\n",
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
"长度: 1318299501 (1.2G) [application/octet-stream]\n",
"正在保存至: “mae_visualize_vit_large.pth”\n",
"local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
"The model and loaded state dict do not match exactly\n",
"\n",
"mae_visualize_vit_l 100%[===================>] 1.23G 3.22MB/s 用时 6m 4s \n",
"unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n",
"\n",
"2022-09-03 00:40:59 (3.46 MB/s) - 已保存 “mae_visualize_vit_large.pth” [1318299501/1318299501])\n",
"\n"
]
}
],
"source": [
"# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
"\n",
"# download checkpoint if not exist\n",
"# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
"!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"local loads checkpoint from path: mae_visualize_vit_large.pth\n",
"Model loaded.\n"
]
}
],
"source": [
"from mmselfsup.apis import init_model\n",
"ckpt_path = \"mae_visualize_vit_large.pth\"\n",
"model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
"ckpt_path = \"mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\"\n",
"model = init_model(\n",
" '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',\n",
" ckpt_path,\n",
" device='cpu')\n",
"print('Model loaded.')"
]
},
@ -250,16 +222,16 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f5029d19950>"
"<torch._C.Generator at 0x7fb2ccfbac90>"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -272,23 +244,23 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-09-03 00:41:01-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
"正在连接 download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... 已连接。\n",
"--2022-11-08 11:21:14-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n",
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
"长度: 60133 (59K) [image/jpeg]\n",
"正在保存至: “fox.jpg”\n",
"\n",
"fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.06s \n",
"fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.05s \n",
"\n",
"2022-09-03 00:41:01 (962 KB/s) - 已保存 “fox.jpg” [60133/60133])\n",
"2022-11-08 11:21:15 (1.08 MB/s) - 已保存 “fox.jpg” [60133/60133])\n",
"\n"
]
}
@ -299,22 +271,34 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"img_path = 'fox.jpg'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"cfg = model.cfg\n",
"test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
"data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
"model.cfg.test_dataloader = dict(\n",
" dataset=dict(pipeline=[\n",
" dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),\n",
" dict(type='Resize', scale=(224, 224), backend='pillow'),\n",
" dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
" ]))\n",
"\n",
"vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)"
]
},
{
@ -324,36 +308,62 @@
"outputs": [],
"source": [
"data = dict(img_path=img_path)\n",
"data = test_pipeline(data)\n",
"data = vis_pipeline(data)\n",
"data = default_collate([data])\n",
"img, _ = data_preprocessor(data, False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.rcParams['figure.figsize'] = [5, 5]\n",
"show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
"img, _ = model.data_preprocessor(data, False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run MAE on the image"
"### Reconstruction pipeline"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"results = inference_model(model, img_path)\n",
"x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
"# for MAE reconstruction\n",
"img_embedding = model.head.patchify(img[0])\n",
"# normalize the target image\n",
"mean = img_embedding.mean(dim=-1, keepdim=True)\n",
"std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# get reconstruction image\n",
"features = inference_model(model, img_path)\n",
"results = model.reconstruct(features, mean=mean, std=std)\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"original_target = img[0]\n",
"original_img, img_masked, pred_img, img_paste = post_process(\n",
" original_target,\n",
" results.pred.value,\n",
" results.mask.value,\n",
" mean=mean,\n",
" std=std)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show the image"
]
},
{
@ -362,21 +372,13 @@
"metadata": {},
"outputs": [],
"source": [
"print('MAE with pixel reconstruction:')\n",
"show_images(x, im_masked, y, im_paste)"
"save_images(original_img, img_masked, pred_img, img_paste)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.7.0 ('openmmlab')",
"language": "python",
"name": "python3"
},
@ -390,11 +392,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
"version": "3.7.0"
},
"vscode": {
"interpreter": {
"hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
"hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361"
}
}
},

View File

@ -7,8 +7,6 @@
- [Publish a model](#publish-a-model)
- [Reproducibility](#reproducibility)
- [Log Analysis](#log-analysis)
- [Visualize Datasets](#visualize-datasets)
- [Use t-SNE](#use-t-sne)
## Count number of parameters
@ -92,53 +90,3 @@ Examples:
time std over epochs is 0.0028
average iter time: 1.1959 s/iter
```
## Visualize Datasets
`tools/misc/browse_dataset.py` helps the user to browse a mmselfsup dataset (transformed images) visually, or save the image to a designated directory.
```shell
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
```
An example:
```shell
python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py
```
An example of visualization:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199387454-219e6f6c-fbb7-43bb-b319-61d3e6266abc.png" width="600" />
</div>
- The left two pictures are images from contrastive learning data pipeline.
- The right one is a masked image.
## Use t-SNE
We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE.
```shell
python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments]
```
Arguments:
- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of model's checkpoint.
- `WORK_DIR`: the directory to save the results of visualization.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
An example:
```shell
python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
```
An example of visualization:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
</div>

View File

@ -7,8 +7,11 @@ Visualization can give an intuitive interpretation of the performance of the mod
- [Visualization](#visualization)
- [How visualization is implemented](#how-visualization-is-implemented)
- [What Visualization do in MMSelfsup](#what-visualization-do-in-mmselfsup)
- [Use different storage backends](#use-different-storage-backends)
- [Use Different Storage Backends](#use-different-storage-backends)
- [Customize Visualization](#customize-visualization)
- [Visualize Datasets](#visualize-datasets)
- [Visualize t-SNE](#visualize-t-sne)
- [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction)
<!-- /TOC -->
@ -43,7 +46,7 @@ def after_train_iter(...):
The function [`add_datasample()`](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/mmselfsup/visualization/selfsup_visualizer.py#L151) is implemented in [`SelfSupVisualizer`](mmselfsup.visualization.SelfSupVisualizer), and it is mainly used in [browse_dataset.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/browse_dataset.py) for browsing datasets. A more detailed tutorial is in [analysis_tools.md](analysis_tools.md).
## Use different storage backends
## Use Different Storage Backends
If you want to use a different backend (Wandb, Tensorboard, or a custom backend with a remote window), just change the `vis_backends` in the config, as follows:
@ -86,3 +89,114 @@ E.g.
## Customize Visualization
The customization of the visualization is similar to other components. If you want to customize `Visualizer`, `VisBackend` or `VisualizationHook`, you can refer to [Visualization Doc](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/visualization.md) in MMEngine.
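For instance, switching the visualization backend only requires overriding `vis_backends` in the config; a minimal sketch, assuming MMEngine's standard backend classes (`LocalVisBackend`, `TensorboardVisBackend`, `WandbVisBackend`):

```python
# a hedged config sketch; backend class names follow MMEngine conventions
vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend'),  # or dict(type='WandbVisBackend')
]
visualizer = dict(
    type='SelfSupVisualizer', vis_backends=vis_backends, name='visualizer')
```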
## Visualize Datasets
`tools/misc/browse_dataset.py` helps users browse an MMSelfSup dataset (transformed images) visually, or save the images to a designated directory.
```shell
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
```
An example:
```shell
python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py
```
An example of visualization:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199387454-219e6f6c-fbb7-43bb-b319-61d3e6266abc.png" width="600" />
</div>
- The left two pictures are images from the contrastive learning data pipeline.
- The right one is a masked image.
## Visualize t-SNE
We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE.
```shell
python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments]
```
Arguments:
- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of model's checkpoint.
- `WORK_DIR`: the directory to save the results of visualization.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
An example:
```shell
python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
```
An example of visualization:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
</div>
## Visualize Low-level Feature Reconstruction
We provide reconstruction visualization for the following algorithms:
- MAE
- SimMIM
- MaskFeat
Users can run the command below to visualize the reconstruction.
```shell
python tools/analysis_tools/visualize_reconstruction.py ${CONFIG_FILE} \
--checkpoint ${CKPT_PATH} \
--img-path ${IMAGE_PATH} \
--out-file ${OUTPUT_PATH}
```
Arguments:
- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of model's checkpoint.
- `IMAGE_PATH`: the input image path.
- `OUTPUT_PATH`: the output image path; the saved figure contains 4 sub-images.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_reconstruction.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_reconstruction.py)
An example:
```shell
python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/mae/mae_vit-huge-p16_8xb512-amp-coslr-1600e_in1k.py \
--checkpoint https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth \
--img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
--out-file test_mae.jpg \
--norm-pix
# As for SimMIM, it generates the mask in the data pipeline, thus we use '--use-vis-pipeline' to apply the 'vis_pipeline' defined in the config instead of the pipeline defined in the script.
python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192.py \
--checkpoint https://download.openmmlab.com/mmselfsup/1.x/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192_20220916-4ad216d3.pth \
--img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
--out-file test_simmim.jpg \
--use-vis-pipeline
```
Results of MAE:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200465826-83f316ed-5a46-46a9-b665-784b5332d348.jpg" width="800" />
</div>
Results of SimMIM:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200466133-b77bc9af-224b-4810-863c-eed81ddd1afa.jpg" width="800" />
</div>
Results of MaskFeat:
<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200465876-7e7dcb6f-5e8d-4d80-b300-9e1847cb975f.jpg" width="800" />
</div>

View File

@ -81,5 +81,6 @@ def inference_model(model: nn.Module,
# forward the model
with torch.no_grad():
results = model.test_step(data)
return results
inputs, data_samples = model.data_preprocessor(data, False)
features = model(inputs, data_samples, mode='tensor')
return features

View File

@ -119,7 +119,7 @@ class BaseModel(_BaseModel):
or ``dict of tensor for custom use.
"""
if mode == 'tensor':
feats = self.extract_feat(inputs)
feats = self.extract_feat(inputs, data_samples=data_samples)
return feats
elif mode == 'loss':
return self.loss(inputs, data_samples)

View File

@ -17,7 +17,9 @@ class MAE(BaseModel):
<https://arxiv.org/abs/2111.06377>`_.
"""
def extract_feat(self, inputs: List[torch.Tensor],
def extract_feat(self,
inputs: List[torch.Tensor],
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwarg) -> Tuple[torch.Tensor]:
"""The forward function to extract features from neck.
@ -27,33 +29,33 @@ class MAE(BaseModel):
Returns:
Tuple[torch.Tensor]: Neck outputs.
"""
latent, _, ids_restore = self.backbone(inputs[0])
latent, mask, ids_restore = self.backbone(inputs[0])
pred = self.neck(latent, ids_restore)
self.mask = mask
return pred
def predict(self,
inputs: List[torch.Tensor],
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
"""The forward function in testing. It is mainly for image
reconstruction.
def reconstruct(self,
features: torch.Tensor,
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
"""The function is for image reconstruction.
Args:
inputs (List[torch.Tensor]): The input images.
features (torch.Tensor): The predicted features, i.e. the output of ``extract_feat``.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
SelfSupDataSample: The prediction from model.
"""
mean = kwargs['mean']
std = kwargs['std']
features = features * std + mean
latent, mask, ids_restore = self.backbone(inputs[0])
pred = self.neck(latent, ids_restore)
pred = self.head.unpatchify(pred)
pred = self.head.unpatchify(features)
pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
mask = mask.detach()
mask = self.mask.detach()
mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
3) # (N, H*W, p*p*3)
mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from mmengine.structures import BaseDataElement
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
@ -16,8 +17,10 @@ class MaskFeat(BaseModel):
Pre-Training <https://arxiv.org/abs/2112.09133>`_.
"""
def extract_feat(self, inputs: List[torch.Tensor],
def extract_feat(self,
inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
compute_hog: bool = True,
**kwarg) -> Tuple[torch.Tensor]:
"""The forward function to extract features from neck.
@ -25,15 +28,30 @@ class MaskFeat(BaseModel):
inputs (List[torch.Tensor]): The input images and mask.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
compute_hog (bool): Whether to compute hog during extraction. If
True, the batch size of inputs needs to be 1. Defaults to True.
Returns:
Tuple[torch.Tensor]: Neck outputs.
"""
img = inputs[0]
mask = torch.stack(
self.mask = torch.stack(
[data_sample.mask.value for data_sample in data_samples])
latent = self.backbone(img, mask)
return latent
latent = self.backbone(img, self.mask)
B, L, C = latent.shape
pred = self.neck([latent.view(B * L, C)])
pred = pred[0].view(B, L, -1)
# compute hog
if compute_hog:
assert img.size(0) == 1, 'Currently only support batch size 1.'
_ = self.target_generator(img)
hog_image = torch.from_numpy(
self.target_generator.generate_hog_image(
self.target_generator.out)).unsqueeze(0).unsqueeze(0)
self.target = hog_image.expand(-1, 3, -1, -1)
return pred[:, 1:, :] # remove cls token
def loss(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
@ -62,3 +80,60 @@ class MaskFeat(BaseModel):
loss = self.head(pred, hog, mask)
losses = dict(loss=loss)
return losses
def reconstruct(self,
features: List[torch.Tensor],
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
"""The function is for image reconstruction.
Args:
features (List[torch.Tensor]): The predicted features, i.e. the output of ``extract_feat``.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
SelfSupDataSample: The prediction from model.
"""
# recover to HOG description from feature embeddings
unfold_size = self.target_generator.unfold_size
tmp4 = features.unflatten(2,
(features.shape[2] // unfold_size**2,
unfold_size, unfold_size)) # 1,196,27,2,2
tmp3 = tmp4.unflatten(1, self.backbone.patch_resolution)
b, p1, p2, c_nbins, _, _ = tmp3.shape # 1,14,14,27,2,2
tmp2 = tmp3.permute(0, 1, 2, 5, 3, 4).reshape(
(b, p1, p2 * unfold_size, c_nbins, unfold_size))
tmp1 = tmp2.permute(0, 1, 4, 2, 3).reshape(
(b, p1 * unfold_size, p2 * unfold_size, c_nbins))
tmp0 = tmp1.permute(0, 3, 1, 2) # 1,27,28,28
hog_out = tmp0.unflatten(1,
(int(c_nbins // self.target_generator.nbins),
self.target_generator.nbins)) # 1,3,9,28,28
# generate prediction of HOG
hog_image = torch.from_numpy(
self.target_generator.generate_hog_image(hog_out))
hog_image = hog_image.unsqueeze(0).unsqueeze(0)
pred = torch.einsum('nchw->nhwc', hog_image).expand(-1, -1, -1,
3).detach().cpu()
# transform patch mask to pixel mask
mask = self.mask
patch_dim_1 = int(self.backbone.patch_embed.init_input_size[0] //
self.backbone.patch_resolution[0])
patch_dim_2 = int(self.backbone.patch_embed.init_input_size[1] //
self.backbone.patch_resolution[1])
mask = mask.repeat_interleave(
patch_dim_1, dim=1).repeat_interleave(
patch_dim_2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3)
# 1 is removing, 0 is keeping
mask = mask.detach().cpu()
results = SelfSupDataSample()
results.mask = BaseDataElement(**dict(value=mask))
results.pred = BaseDataElement(**dict(value=pred))
return results

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from typing import Dict, List, Optional
import torch
from mmengine.structures import BaseDataElement
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
@ -33,6 +34,7 @@ class SimMIM(BaseModel):
[data_sample.mask.value for data_sample in data_samples])
img_latent = self.backbone(inputs[0], mask)
feat = self.neck(img_latent[0])
self.mask = mask
return feat
def loss(self, inputs: List[torch.Tensor],
@ -58,3 +60,37 @@ class SimMIM(BaseModel):
losses = dict(loss=loss)
return losses
def reconstruct(self,
features: torch.Tensor,
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
"""The function is for image reconstruction.
Args:
features (torch.Tensor): The reconstructed features, i.e. the output of ``extract_feat``.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
SelfSupDataSample: The prediction from model.
"""
pred = torch.einsum('nchw->nhwc', features).detach().cpu()
# transform patch mask to pixel mask
mask = self.mask.detach()
p1 = int(self.backbone.patch_embed.init_input_size[0] //
self.backbone.patch_resolution[0])
p2 = int(self.backbone.patch_embed.init_input_size[1] //
self.backbone.patch_resolution[1])
mask = mask.repeat_interleave(
p1, dim=1).repeat_interleave(
p2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3) # (N, H, W, 3)
# 1 is removing, 0 is keeping
mask = mask.detach().cpu()
results = SelfSupDataSample()
results.mask = BaseDataElement(**dict(value=mask))
results.pred = BaseDataElement(**dict(value=pred))
return results

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModule
@ -13,10 +15,10 @@ class HOGGenerator(BaseModule):
"""Generate HOG feature for images.
This module is used in MaskFeat to generate HOG feature. The code is
modified from this `file
modified from file `slowfast/models/operators.py
<https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
Here is the link `HOG wikipedia
<https://en.m.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
Here is the link to `HOG wikipedia
<https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
Args:
nbins (int): Number of bin. Defaults to 9.
@ -61,12 +63,12 @@ class HOGGenerator(BaseModule):
def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
"""Reshape HOG Features for output."""
hog_feat = hog_feat.flatten(1, 2)
unfold_size = hog_feat.shape[-1] // 14
hog_feat = (
hog_feat.permute(0, 2, 3,
1).unfold(1, unfold_size, unfold_size).unfold(
2, unfold_size,
unfold_size).flatten(1, 2).flatten(2))
self.unfold_size = hog_feat.shape[-1] // 14
hog_feat = hog_feat.permute(0, 2, 3, 1)
hog_feat = hog_feat.unfold(1, self.unfold_size,
self.unfold_size).unfold(
2, self.unfold_size, self.unfold_size)
hog_feat = hog_feat.flatten(1, 2).flatten(2)
return hog_feat
@torch.no_grad()
@ -80,6 +82,7 @@ class HOGGenerator(BaseModule):
torch.Tensor: Hog features.
"""
# input is RGB image with shape [B 3 H W]
self.h, self.w = x.size(-2), x.size(-1)
x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
gx_rgb = F.conv2d(
x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
@ -112,6 +115,38 @@ class HOGGenerator(BaseModule):
out = out.unfold(4, self.pool, self.pool)
out = out.sum(dim=[-1, -2])
out = F.normalize(out, p=2, dim=2)
self.out = F.normalize(out, p=2, dim=2)
return self._reshape(out)
return self._reshape(self.out)
def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray:
"""Generate HOG image according to HOG features."""
assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \
'Check the input batch size and the channel number, only support '\
'"batch_size = 1".'
hog_image = np.zeros([self.h, self.w])
cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu())
cell_width = self.pool / 2
max_mag = np.array(cell_gradient).max()
angle_gap = 360 / self.nbins
for x in range(cell_gradient.shape[1]):
for y in range(cell_gradient.shape[2]):
cell_grad = cell_gradient[:, x, y]
cell_grad /= max_mag
angle = 0
for magnitude in cell_grad:
angle_radian = math.radians(angle)
x1 = int(x * self.pool +
magnitude * cell_width * math.cos(angle_radian))
y1 = int(y * self.pool +
magnitude * cell_width * math.sin(angle_radian))
x2 = int(x * self.pool -
magnitude * cell_width * math.cos(angle_radian))
y2 = int(y * self.pool -
magnitude * cell_width * math.sin(angle_radian))
magnitude = 0 if magnitude < 0 else magnitude
cv2.line(hog_image, (y1, x1), (y2, x2),
int(255 * math.sqrt(magnitude)))
angle += angle_gap
return hog_image
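A short usage sketch of the new `generate_hog_image` helper, following the pattern in the unit test added by this commit (random input, so the rendered HOG image is illustrative only):

```python
import torch

from mmselfsup.models.target_generators import HOGGenerator

hog_generator = HOGGenerator(nbins=9, pool=8, gaussian_window=16)
fake_input = torch.randn((1, 3, 224, 224))       # batch size must be 1 for rendering
hog_feat = hog_generator(fake_input)             # (1, 196, 108) HOG features
# hog_generator.out caches the un-flattened HOG maps used for rendering
hog_image = hog_generator.generate_hog_image(hog_generator.out)  # (224, 224) ndarray
```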

View File

@ -27,10 +27,10 @@ class ExampleModel(BaseModel):
super(ExampleModel, self).__init__(backbone=backbone)
self.layer = nn.Linear(1, 1)
def predict(self,
inputs: List[torch.Tensor],
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
def extract_feat(self,
inputs: List[torch.Tensor],
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> SelfSupDataSample:
out = self.layer(inputs[0])
return out

View File

@ -48,9 +48,14 @@ def test_mae():
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)
# test extraction
fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
assert list(fake_feats.shape) == [2, 196, 768]
results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
# test reconstruct
mean = fake_feats.mean(dim=-1, keepdim=True)
std = (fake_feats.var(dim=-1, keepdim=True) + 1.e-6)**.5
results = alg.reconstruct(
fake_feats, fake_data_samples, mean=mean, std=std)
assert list(results.mask.value.shape) == [2, 224, 224, 3]
assert list(results.pred.value.shape) == [2, 224, 224, 3]

View File

@ -5,6 +5,7 @@ import platform
import pytest
import torch
from mmengine.structures import InstanceData
from mmengine.utils import digit_version
from mmselfsup.models.algorithms.maskfeat import MaskFeat
from mmselfsup.structures import SelfSupDataSample
@ -22,6 +23,9 @@ target_generator = dict(
type='HOGGenerator', nbins=9, pool=8, gaussian_window=16)
@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.7.0'),
reason='torch version')
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat():
data_preprocessor = {
@ -42,13 +46,19 @@ def test_maskfeat():
fake_mask = InstanceData(value=torch.rand((14, 14)).bool())
fake_data_sample.mask = fake_mask
fake_data = {
'inputs': [torch.randn((2, 3, 224, 224))],
'data_sample': [fake_data_sample for _ in range(2)]
'inputs': [torch.randn((1, 3, 224, 224))],
'data_sample': [fake_data_sample for _ in range(1)]
}
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)
# test extraction
fake_feats = alg.extract_feat(fake_batch_inputs, fake_data_samples)
assert list(fake_feats.shape) == [2, 197, 768]
assert list(fake_feats.shape) == [1, 196, 108]
# test reconstruction
results = alg.reconstruct(fake_feats, fake_data_samples)
assert list(results.mask.value.shape) == [1, 224, 224, 3]
assert list(results.pred.value.shape) == [1, 224, 224, 3]

View File

@ -50,5 +50,10 @@ def test_simmim():
# test extract_feat
fake_inputs, fake_data_samples = model.data_preprocessor(fake_data)
fake_feat = model.extract_feat(fake_inputs, fake_data_samples)
assert list(fake_feat.shape) == [2, 3, 192, 192]
fake_feats = model.extract_feat(fake_inputs, fake_data_samples)
assert list(fake_feats.shape) == [2, 3, 192, 192]
# test reconstruct
results = model.reconstruct(fake_feats, fake_data_samples)
assert list(results.mask.value.shape) == [2, 192, 192, 3]
assert list(results.pred.value.shape) == [2, 192, 192, 3]

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmselfsup.models.target_generators import HOGGenerator
@ -10,3 +11,10 @@ def test_hog_generator():
fake_input = torch.randn((2, 3, 224, 224))
fake_output = hog_generator(fake_input)
assert list(fake_output.shape) == [2, 196, 108]
fake_hog_out = hog_generator.out[0].unsqueeze(0)
fake_hog_img = hog_generator.generate_hog_image(fake_hog_out)
assert fake_hog_img.shape == (224, 224)
with pytest.raises(AssertionError):
fake_hog_img = hog_generator.generate_hog_image(hog_generator.out[0])

View File

@ -1,108 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate
from mmselfsup.apis import inference_model, init_model
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def show_image(image: torch.Tensor, title: str = '') -> None:
# image is [H, W, 3]
assert image.shape[2] == 3
image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
255).int()
plt.imshow(image)
plt.title(title, fontsize=16)
plt.axis('off')
return
def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
im_paste: torch.Tensor, out_file: str) -> None:
# make the plt figure larger
plt.rcParams['figure.figsize'] = [24, 6]
plt.subplot(1, 4, 1)
show_image(x, 'original')
plt.subplot(1, 4, 2)
show_image(im_masked, 'masked')
plt.subplot(1, 4, 3)
show_image(y, 'reconstruction')
plt.subplot(1, 4, 4)
show_image(im_paste, 'reconstruction + visible')
plt.savefig(out_file)
def post_process(
x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x = torch.einsum('nchw->nhwc', x.cpu())
# masked image
im_masked = x * (1 - mask)
# MAE reconstruction pasted with visible patches
im_paste = x * (1 - mask) + y * mask
return x[0], im_masked[0], y[0], im_paste[0]
def main():
parser = ArgumentParser()
parser.add_argument('img_path', help='Image file path')
parser.add_argument('config', help='MAE Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('out_file', help='The output image file path')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
print('Model loaded.')
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
model.cfg.test_dataloader = dict(
dataset=dict(pipeline=[
dict(
type='LoadImageFromFile',
file_client_args=dict(backend='disk')),
dict(type='Resize', scale=(224, 224)),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]))
results = inference_model(model, args.img_path)
cfg = model.cfg
test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
data = dict(img_path=args.img_path)
data = test_pipeline(data)
data = default_collate([data])
img, _ = data_preprocessor(data, False)
x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
results.mask.value)
save_images(x, im_masked, y, im_paste, args.out_file)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,178 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
import random
from argparse import ArgumentParser
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate
from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def show_image(img: torch.Tensor, title: str = '') -> None:
# image is [H, W, 3]
assert img.shape[2] == 3
plt.imshow(img)
plt.title(title, fontsize=16)
plt.axis('off')
return
def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,
pred_img: torch.Tensor, img_paste: torch.Tensor,
out_file: str) -> None:
# make the plt figure larger
plt.rcParams['figure.figsize'] = [24, 6]
plt.subplot(1, 4, 1)
show_image(original_img, 'original')
plt.subplot(1, 4, 2)
show_image(img_masked, 'masked')
plt.subplot(1, 4, 3)
show_image(pred_img, 'reconstruction')
plt.subplot(1, 4, 4)
show_image(img_paste, 'reconstruction + visible')
plt.savefig(out_file)
print(f'Images are saved to {out_file}')
def recover_norm(img: torch.Tensor,
mean: np.ndarray = imagenet_mean,
std: np.ndarray = imagenet_std):
if mean is not None and std is not None:
img = torch.clip((img * std + mean) * 255, 0, 255).int()
return img
def post_process(
original_img: torch.Tensor,
pred_img: torch.Tensor,
mask: torch.Tensor,
mean: np.ndarray = imagenet_mean,
std: np.ndarray = imagenet_std
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# channel conversion
original_img = torch.einsum('nchw->nhwc', original_img.cpu())
# masked image
img_masked = original_img * (1 - mask)
# reconstructed image pasted with visible patches
img_paste = original_img * (1 - mask) + pred_img * mask
# multiply std and add mean to each image
original_img = recover_norm(original_img[0])
img_masked = recover_norm(img_masked[0])
pred_img = recover_norm(pred_img[0])
img_paste = recover_norm(img_paste[0])
return original_img, img_masked, pred_img, img_paste
def main():
parser = ArgumentParser()
parser.add_argument('config', help='Model config file')
parser.add_argument('--checkpoint', help='Checkpoint file')
parser.add_argument('--img-path', help='Image file path')
parser.add_argument('--out-file', help='The output image file path')
parser.add_argument(
'--use-vis-pipeline',
action='store_true',
help='Use vis_pipeline defined in config. Some algorithms, such as '
'SimMIM and MaskFeat, generate the mask in the data pipeline, thus '
'the visualization process applies the vis_pipeline in the config to '
'obtain the mask.')
parser.add_argument(
'--norm-pix',
action='store_true',
help='MAE uses `norm_pix_loss` for optimization in pre-training, thus '
'the visualization process also needs to compute the mean and std of '
'each patch embedding while reconstructing the original images.')
parser.add_argument(
'--target-generator',
action='store_true',
help='Some algorithms, such as MaskFeat, use target_generator for '
'optimization in pre-training; turn this on to visualize the '
'generated target instead of the RGB image.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--seed',
type=int,
default=0,
help='The random seed for visualization')
args = parser.parse_args()
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
print('Model loaded.')
# make random mask reproducible (comment out to make it change)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
print('Reconstruction visualization.')
if args.use_vis_pipeline:
model.cfg.test_dataloader = dict(
dataset=dict(pipeline=model.cfg.vis_pipeline))
else:
model.cfg.test_dataloader = dict(
dataset=dict(pipeline=[
dict(
type='LoadImageFromFile',
file_client_args=dict(backend='disk')),
dict(type='Resize', scale=(224, 224), backend='pillow'),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]))
# get original image
vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
data = dict(img_path=args.img_path)
data = vis_pipeline(data)
data = default_collate([data])
img, _ = model.data_preprocessor(data, False)
if args.norm_pix:
# for MAE reconstruction
img_embedding = model.head.patchify(img[0])
# normalize the target image
mean = img_embedding.mean(dim=-1, keepdim=True)
std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5
else:
mean = imagenet_mean
std = imagenet_std
# get reconstruction image
features = inference_model(model, args.img_path)
results = model.reconstruct(features, mean=mean, std=std)
original_target = model.target if args.target_generator else img[0]
original_img, img_masked, pred_img, img_paste = post_process(
original_target,
results.pred.value,
results.mask.value,
mean=mean,
std=std)
save_images(original_img, img_masked, pred_img, img_paste, args.out_file)
if __name__ == '__main__':
main()