From 73cd764b5f72891994e440429380eed78fb87841 Mon Sep 17 00:00:00 2001 From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com> Date: Tue, 6 Dec 2022 19:45:01 +0800 Subject: [PATCH] [Feature] Support pixel reconstruction visualization (#570) * refactor reconstruction visualization * support simmim visualization * fix reconstruction bug of MAE * support visualization of MaskFeat * refaction mae visualization demo * add unit test * fix lint and ut * update * add docs * set random seed * update * update docstring * add torch version check * update * rename * update version * update * fix lint * add docstring * update docs --- .../_base_/datasets/imagenet_maskfeat.py | 16 + .../_base_/datasets/imagenet_simmim.py | 16 + demo/mae_visualization.ipynb | 290 +++++++++--------- docs/en/user_guides/analysis_tools.md | 52 ---- docs/en/user_guides/visualization.md | 118 ++++++- mmselfsup/apis/inference.py | 5 +- mmselfsup/models/algorithms/base.py | 2 +- mmselfsup/models/algorithms/mae.py | 30 +- mmselfsup/models/algorithms/maskfeat.py | 85 ++++- mmselfsup/models/algorithms/simmim.py | 38 ++- .../models/target_generators/hog_generator.py | 57 +++- tests/test_apis/test_inference.py | 8 +- tests/test_models/test_algorithms/test_mae.py | 7 +- .../test_algorithms/test_maskfeat.py | 16 +- .../test_algorithms/test_simmim.py | 9 +- .../test_hog_generator.py | 8 + tools/analysis_tools/mae_visualization.py | 108 ------- .../visualize_reconstruction.py | 178 +++++++++++ 18 files changed, 693 insertions(+), 350 deletions(-) delete mode 100644 tools/analysis_tools/mae_visualization.py create mode 100644 tools/analysis_tools/visualize_reconstruction.py diff --git a/configs/selfsup/_base_/datasets/imagenet_maskfeat.py b/configs/selfsup/_base_/datasets/imagenet_maskfeat.py index 0566b201..a7c995ce 100644 --- a/configs/selfsup/_base_/datasets/imagenet_maskfeat.py +++ b/configs/selfsup/_base_/datasets/imagenet_maskfeat.py @@ -37,3 +37,19 @@ train_dataloader = dict( ann_file='meta/train.txt', data_prefix=dict(img_path='train/'), pipeline=train_pipeline)) + +# for visualization +vis_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(224, 224), backend='pillow'), + dict( + type='BEiTMaskGenerator', + input_size=14, + num_masking_patches=78, + min_num_patches=15, + ), + dict( + type='PackSelfSupInputs', + algorithm_keys=['mask'], + meta_keys=['img_path']) +] diff --git a/configs/selfsup/_base_/datasets/imagenet_simmim.py b/configs/selfsup/_base_/datasets/imagenet_simmim.py index d1e2e9e2..677b4a76 100644 --- a/configs/selfsup/_base_/datasets/imagenet_simmim.py +++ b/configs/selfsup/_base_/datasets/imagenet_simmim.py @@ -34,3 +34,19 @@ train_dataloader = dict( ann_file='meta/train.txt', data_prefix=dict(img_path='train/'), pipeline=train_pipeline)) + +# for visualization +vis_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(192, 192), backend='pillow'), + dict( + type='SimMIMMaskGenerator', + input_size=192, + mask_patch_size=32, + model_patch_size=4, + mask_ratio=0.6), + dict( + type='PackSelfSupInputs', + algorithm_keys=['mask'], + meta_keys=['img_path']) +] diff --git a/demo/mae_visualization.ipynb b/demo/mae_visualization.ipynb index 850c24c6..13c9fd5a 100644 --- a/demo/mae_visualization.ipynb +++ b/demo/mae_visualization.ipynb @@ -6,8 +6,6 @@ "source": [ "Copyright (c) OpenMMLab. All rights reserved.\n", "\n", - "Copyright (c) Meta Platforms, Inc. 
and affiliates.\n", - "\n", "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n", "\n", "## Masked Autoencoders: Visualization Demo\n", @@ -36,7 +34,8 @@ " print('Running in Colab.')\n", " !pip3 install openmim\n", " !pip install -U openmim\n", - " !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n", + " !mim install mmengine\n", + " !mim install 'mmcv>=2.0.0rc1'\n", "\n", " !git clone https://github.com/open-mmlab/mmselfsup.git\n", " %cd mmselfsup/\n", @@ -51,18 +50,19 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ + "from argparse import ArgumentParser\n", + "from typing import Tuple, Optional\n", + "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from mmengine.dataset import Compose, default_collate\n", "\n", - "from mmselfsup.apis import inference_model\n", - "from mmselfsup.models.utils import SelfSupDataPreprocessor\n", - "from mmselfsup.registry import MODELS\n", + "from mmselfsup.apis import inference_model, init_model\n", "from mmselfsup.utils import register_all_modules" ] }, @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -84,49 +84,81 @@ "imagenet_mean = np.array([0.485, 0.456, 0.406])\n", "imagenet_std = np.array([0.229, 0.224, 0.225])\n", "\n", - "def show_image(image, title=''):\n", + "\n", + "def show_image(img: torch.Tensor, title: str = '') -> None:\n", " # image is [H, W, 3]\n", - " assert image.shape[2] == 3\n", - " image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n", - " plt.imshow(image)\n", + " assert img.shape[2] == 3\n", + "\n", + " plt.imshow(img)\n", " plt.title(title, fontsize=16)\n", " plt.axis('off')\n", " return\n", "\n", "\n", - "def show_images(x, im_masked, y, im_paste):\n", + "def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n", + " pred_img: torch.Tensor, img_paste: torch.Tensor,\n", + " out_file: Optional[str] =None) -> None:\n", " # make the plt figure larger\n", " plt.rcParams['figure.figsize'] = [24, 6]\n", "\n", " plt.subplot(1, 4, 1)\n", - " show_image(x, \"original\")\n", + " show_image(original_img, 'original')\n", "\n", " plt.subplot(1, 4, 2)\n", - " show_image(im_masked, \"masked\")\n", + " show_image(img_masked, 'masked')\n", "\n", " plt.subplot(1, 4, 3)\n", - " show_image(y, \"reconstruction\")\n", + " show_image(pred_img, 'reconstruction')\n", "\n", " plt.subplot(1, 4, 4)\n", - " show_image(im_paste, \"reconstruction + visible\")\n", + " show_image(img_paste, 'reconstruction + visible')\n", "\n", - " plt.show()\n", + " if out_file is None:\n", + " plt.show()\n", + " else:\n", + " plt.savefig(out_file)\n", + " print(f'Images are saved to {out_file}')\n", "\n", "\n", - "def post_process(x, y, mask):\n", - " x = torch.einsum('nchw->nhwc', x.cpu())\n", + "def recover_norm(img: torch.Tensor,\n", + " mean: np.ndarray = imagenet_mean,\n", + " std: np.ndarray = imagenet_std):\n", + " if mean is not None and std is not None:\n", + " img = torch.clip((img * std + mean) * 255, 0, 255).int()\n", + " return img\n", + "\n", + "\n", + "def post_process(\n", + " original_img: torch.Tensor,\n", + " pred_img: torch.Tensor,\n", + " mask: torch.Tensor,\n", + " mean: np.ndarray = imagenet_mean,\n", + " std: np.ndarray = imagenet_std\n", + ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n", + " # channel conversion\n", + " 
original_img = torch.einsum('nchw->nhwc', original_img.cpu())\n", " # masked image\n", - " im_masked = x * (1 - mask)\n", - " # MAE reconstruction pasted with visible patches\n", - " im_paste = x * (1 - mask) + y * mask\n", - " return x[0], im_masked[0], y[0], im_paste[0]" + " img_masked = original_img * (1 - mask)\n", + " # reconstructed image pasted with visible patches\n", + " img_paste = original_img * (1 - mask) + pred_img * mask\n", + "\n", + " # muptiply std and add mean to each image\n", + " original_img = recover_norm(original_img[0])\n", + " img_masked = recover_norm(img_masked[0])\n", + "\n", + " pred_img = recover_norm(pred_img[0])\n", + " img_paste = recover_norm(img_paste[0])\n", + "\n", + " return original_img, img_masked, pred_img, img_paste\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Prepare config file" + "### Load a pre-trained MAE model\n", + "\n", + "This is an MAE model trained with config 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k.py'.\n" ] }, { @@ -138,55 +170,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "Overwriting ../configs/selfsup/mae/mae_visualization.py\n" + "--2022-11-08 11:00:50-- https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n", + "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n", + "正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n", + "已发出 HTTP 请求,正在等待回应... 200 OK\n", + "长度: 1355429265 (1.3G) [application/octet-stream]\n", + "正在保存至: “mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth”\n", + "\n", + "e_in1k_20220825-cc7 99%[==================> ] 1.26G 913KB/s 剩余 0s s " ] } ], "source": [ - "%%writefile ../configs/selfsup/mae/mae_visualization.py\n", - "model = dict(\n", - " type='MAE',\n", - " data_preprocessor=dict(\n", - " mean=[123.675, 116.28, 103.53],\n", - " std=[58.395, 57.12, 57.375],\n", - " bgr_to_rgb=True),\n", - " backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n", - " neck=dict(\n", - " type='MAEPretrainDecoder',\n", - " patch_size=16,\n", - " in_chans=3,\n", - " embed_dim=1024,\n", - " decoder_embed_dim=512,\n", - " decoder_depth=8,\n", - " decoder_num_heads=16,\n", - " mlp_ratio=4.,\n", - " ),\n", - " head=dict(\n", - " type='MAEPretrainHead',\n", - " norm_pix=True,\n", - " patch_size=16,\n", - " loss=dict(type='MAEReconstructionLoss')),\n", - " init_cfg=[\n", - " dict(type='Xavier', distribution='uniform', layer='Linear'),\n", - " dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n", - " ])\n", - "\n", - "file_client_args = dict(backend='disk')\n", - "\n", - "# dataset summary\n", - "test_dataloader = dict(\n", - " dataset=dict(pipeline=[\n", - " dict(type='LoadImageFromFile', file_client_args=file_client_args),\n", - " dict(type='Resize', scale=(224, 224)),\n", - " dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n", - " ]))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load a pre-trained MAE model" + "# download checkpoint if not exist\n", + "!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth" ] }, { @@ -198,46 +195,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2022-09-03 00:34:55-- https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n", - "正在解析主机 
download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n", - "正在连接 download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... 已连接。\n", - "已发出 HTTP 请求,正在等待回应... 200 OK\n", - "长度: 1318299501 (1.2G) [application/octet-stream]\n", - "正在保存至: “mae_visualize_vit_large.pth”\n", + "local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n", + "The model and loaded state dict do not match exactly\n", "\n", - "mae_visualize_vit_l 100%[===================>] 1.23G 3.22MB/s 用时 6m 4s \n", + "unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n", "\n", - "2022-09-03 00:40:59 (3.46 MB/s) - 已保存 “mae_visualize_vit_large.pth” [1318299501/1318299501])\n", - "\n" - ] - } - ], - "source": [ - "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n", - "\n", - "# download checkpoint if not exist\n", - "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n", - "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "local loads checkpoint from path: mae_visualize_vit_large.pth\n", "Model loaded.\n" ] } ], "source": [ - "from mmselfsup.apis import init_model\n", - "ckpt_path = \"mae_visualize_vit_large.pth\"\n", - "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n", + "ckpt_path = \"mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\"\n", + "model = init_model(\n", + " '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',\n", + " ckpt_path,\n", + " device='cpu')\n", "print('Model loaded.')" ] }, @@ -250,16 +222,16 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -272,23 +244,23 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "--2022-09-03 00:41:01-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n", - "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n", - "正在连接 download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... 已连接。\n", + "--2022-11-08 11:21:14-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n", + "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n", + "正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n", "已发出 HTTP 请求,正在等待回应... 
200 OK\n", "长度: 60133 (59K) [image/jpeg]\n", "正在保存至: “fox.jpg”\n", "\n", - "fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.06s \n", + "fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.05s \n", "\n", - "2022-09-03 00:41:01 (962 KB/s) - 已保存 “fox.jpg” [60133/60133])\n", + "2022-11-08 11:21:15 (1.08 MB/s) - 已保存 “fox.jpg” [60133/60133])\n", "\n" ] } @@ -299,22 +271,34 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "img_path = 'fox.jpg'" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Build Pipeline" + ] + }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "cfg = model.cfg\n", - "test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n", - "data_preprocessor = MODELS.build(cfg.model.data_preprocessor)" + "model.cfg.test_dataloader = dict(\n", + " dataset=dict(pipeline=[\n", + " dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),\n", + " dict(type='Resize', scale=(224, 224), backend='pillow'),\n", + " dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n", + " ]))\n", + "\n", + "vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)" ] }, { @@ -324,36 +308,62 @@ "outputs": [], "source": [ "data = dict(img_path=img_path)\n", - "data = test_pipeline(data)\n", + "data = vis_pipeline(data)\n", "data = default_collate([data])\n", - "img, _ = data_preprocessor(data, False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.rcParams['figure.figsize'] = [5, 5]\n", - "show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])" + "img, _ = model.data_preprocessor(data, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Run MAE on the image" + "### Reconstruction pipeline" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "results = inference_model(model, img_path)\n", - "x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)" + "# for MAE reconstruction\n", + "img_embedding = model.head.patchify(img[0])\n", + "# normalize the target image\n", + "mean = img_embedding.mean(dim=-1, keepdim=True)\n", + "std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# get reconstruction image\n", + "features = inference_model(model, img_path)\n", + "results = model.reconstruct(features, mean=mean, std=std)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "original_target = img[0]\n", + "original_img, img_masked, pred_img, img_paste = post_process(\n", + " original_target,\n", + " results.pred.value,\n", + " results.mask.value,\n", + " mean=mean,\n", + " std=std)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Show the image" ] }, { @@ -362,21 +372,13 @@ "metadata": {}, "outputs": [], "source": [ - "print('MAE with pixel reconstruction:')\n", - "show_images(x, im_masked, y, im_paste)" + "save_images(original_img, img_masked, pred_img, img_paste)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.7.0 ('openmmlab')", "language": "python", "name": 
"python3" }, @@ -390,11 +392,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.13" + "version": "3.7.0" }, "vscode": { "interpreter": { - "hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436" + "hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361" } } }, diff --git a/docs/en/user_guides/analysis_tools.md b/docs/en/user_guides/analysis_tools.md index 33069192..7a4b1cde 100644 --- a/docs/en/user_guides/analysis_tools.md +++ b/docs/en/user_guides/analysis_tools.md @@ -7,8 +7,6 @@ - [Publish a model](#publish-a-model) - [Reproducibility](#reproducibility) - [Log Analysis](#log-analysis) - - [Visualize Datasets](#visualize-datasets) - - [Use t-SNE](#use-t-sne) ## Count number of parameters @@ -92,53 +90,3 @@ Examples: time std over epochs is 0.0028 average iter time: 1.1959 s/iter ``` - -## Visualize Datasets - -`tools/misc/browse_dataset.py` helps the user to browse a mmselfsup dataset (transformed images) visually, or save the image to a designated directory. - -```shell -python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}] -``` - -An example: - -```shell -python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py -``` - -An example of visualization: - -
- -
- -- The left two pictures are images from contrastive learning data pipeline. -- The right one is a masked image. - -## Use t-SNE - -We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE. - -```shell -python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments] -``` - -Arguments: - -- `CONFIG_FILE`: config file for the pre-trained model. -- `CKPT_PATH`: the path of model's checkpoint. -- `WORK_DIR`: the directory to save the results of visualization. -- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py) - -An example: - -```shell -python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k -``` - -An example of visualization: - -
- -
diff --git a/docs/en/user_guides/visualization.md b/docs/en/user_guides/visualization.md index 1c3ccb87..6f167370 100644 --- a/docs/en/user_guides/visualization.md +++ b/docs/en/user_guides/visualization.md @@ -7,8 +7,11 @@ Visualization can give an intuitive interpretation of the performance of the mod - [Visualization](#visualization) - [How visualization is implemented](#how-visualization-is-implemented) - [What Visualization do in MMSelfsup](#what-visualization-do-in-mmselfsup) - - [Use different storage backends](#use-different-storage-backends) + - [Use Different Storage Backends](#use-different-storage-backends) - [Customize Visualization](#customize-visualization) + - [Visualize Datasets](#visualize-datasets) + - [Visualize t-SNE](#visualize-t-sne) + - [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction) @@ -43,7 +46,7 @@ def after_train_iter(...): The function [`add_datasample()`](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/mmselfsup/visualization/selfsup_visualizer.py#L151) is impleted in [`SelfSupVisualizer`](mmselfsup.visualization.SelfSupVisualizer), and it is mainly used in [browse_dataset.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/browse_dataset.py) for browsing dataset. More tutorial is in [analysis_tools.md](analysis_tools.md) -## Use different storage backends +## Use Different Storage Backends If you want to use a different backend (Wandb, Tensorboard, or a custom backend with a remote window), just change the `vis_backends` in the config, as follows: @@ -86,3 +89,114 @@ E.g. ## Customize Visualization The customization of the visualization is similar to other components. If you want to customize `Visualizer`, `VisBackend` or `VisualizationHook`, you can refer to [Visualization Doc](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/visualization.md) in MMEngine. + +## Visualize Datasets + +`tools/misc/browse_dataset.py` helps the user to browse a mmselfsup dataset (transformed images) visually, or save the image to a designated directory. + +```shell +python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}] +``` + +An example: + +```shell +python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py +``` + +An example of visualization: + +
+ +
+ +- The left two pictures are images from contrastive learning data pipeline. +- The right one is a masked image. + +## Visualize t-SNE + +We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE. + +```shell +python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments] +``` + +Arguments: + +- `CONFIG_FILE`: config file for the pre-trained model. +- `CKPT_PATH`: the path of model's checkpoint. +- `WORK_DIR`: the directory to save the results of visualization. +- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py) + +An example: + +```shell +python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k +``` + +An example of visualization: + +
+ +
+ +## Visualize Low-level Feature Reconstruction + +We provide several reconstruction visualization for listed algorithms: + +- MAE +- SimMIM +- MaskFeat + +Users can run command below to visualize the reconstruction. + +```shell +python tools/analysis_tools/visualize_reconstruction.py ${CONFIG_FILE} \ + --checkpoint ${CKPT_PATH} \ + --img-path ${IMAGE_PATH} \ + --out-file ${OUTPUT_PATH} +``` + +Arguments: + +- `CONFIG_FILE`: config file for the pre-trained model. +- `CKPT_PATH`: the path of model's checkpoint. +- `IMAGE_PATH`: the input image path. +- `OUTPUT_PATH`: the output image path, including 4 sub-images. +- `[optional arguments]`: for optional arguments, you can refer to [visualize_reconstruction.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_reconstruction.py) + +An example: + +```shell +python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/mae/mae_vit-huge-p16_8xb512-amp-coslr-1600e_in1k.py \ + --checkpoint https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth \ + --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \ + --out-file test_mae.jpg \ + --norm-pix + + +# As for SimMIM, it generates the mask in data pipeline, thus we use '--use-vis-pipeline' to apply 'vis_pipeline' defined in config instead of the pipeline defined in script. +python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192.py \ + --checkpoint https://download.openmmlab.com/mmselfsup/1.x/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192_20220916-4ad216d3.pth \ + --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \ + --out-file test_simmim.jpg \ + --use-vis-pipeline +``` + +Results of MAE: + +
+ +
+ +Results of SimMIM: + +
+ +
+ +Results of MaskFeat: + +
+ +
diff --git a/mmselfsup/apis/inference.py b/mmselfsup/apis/inference.py index f448b8f8..fbc02897 100644 --- a/mmselfsup/apis/inference.py +++ b/mmselfsup/apis/inference.py @@ -81,5 +81,6 @@ def inference_model(model: nn.Module, # forward the model with torch.no_grad(): - results = model.test_step(data) - return results + inputs, data_samples = model.data_preprocessor(data, False) + features = model(inputs, data_samples, mode='tensor') + return features diff --git a/mmselfsup/models/algorithms/base.py b/mmselfsup/models/algorithms/base.py index fa43fde9..fad0689d 100644 --- a/mmselfsup/models/algorithms/base.py +++ b/mmselfsup/models/algorithms/base.py @@ -119,7 +119,7 @@ class BaseModel(_BaseModel): or ``dict of tensor for custom use. """ if mode == 'tensor': - feats = self.extract_feat(inputs) + feats = self.extract_feat(inputs, data_samples=data_samples) return feats elif mode == 'loss': return self.loss(inputs, data_samples) diff --git a/mmselfsup/models/algorithms/mae.py b/mmselfsup/models/algorithms/mae.py index 1e89ca4b..b1e8daca 100644 --- a/mmselfsup/models/algorithms/mae.py +++ b/mmselfsup/models/algorithms/mae.py @@ -17,7 +17,9 @@ class MAE(BaseModel): `_. """ - def extract_feat(self, inputs: List[torch.Tensor], + def extract_feat(self, + inputs: List[torch.Tensor], + data_samples: Optional[List[SelfSupDataSample]] = None, **kwarg) -> Tuple[torch.Tensor]: """The forward function to extract features from neck. @@ -27,33 +29,33 @@ class MAE(BaseModel): Returns: Tuple[torch.Tensor]: Neck outputs. """ - latent, _, ids_restore = self.backbone(inputs[0]) + latent, mask, ids_restore = self.backbone(inputs[0]) pred = self.neck(latent, ids_restore) + self.mask = mask return pred - def predict(self, - inputs: List[torch.Tensor], - data_samples: Optional[List[SelfSupDataSample]] = None, - **kwargs) -> SelfSupDataSample: - """The forward function in testing. It is mainly for image - reconstruction. + def reconstruct(self, + features: torch.Tensor, + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwargs) -> SelfSupDataSample: + """The function is for image reconstruction. Args: - inputs (List[torch.Tensor]): The input images. + features (torch.Tensor): The input images. data_samples (List[SelfSupDataSample]): All elements required during the forward function. Returns: SelfSupDataSample: The prediction from model. """ + mean = kwargs['mean'] + std = kwargs['std'] + features = features * std + mean - latent, mask, ids_restore = self.backbone(inputs[0]) - pred = self.neck(latent, ids_restore) - - pred = self.head.unpatchify(pred) + pred = self.head.unpatchify(features) pred = torch.einsum('nchw->nhwc', pred).detach().cpu() - mask = mask.detach() + mask = self.mask.detach() mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 * 3) # (N, H*W, p*p*3) mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping diff --git a/mmselfsup/models/algorithms/maskfeat.py b/mmselfsup/models/algorithms/maskfeat.py index 4ffcf3b3..0a2b2f15 100644 --- a/mmselfsup/models/algorithms/maskfeat.py +++ b/mmselfsup/models/algorithms/maskfeat.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import torch +from mmengine.structures import BaseDataElement from mmselfsup.registry import MODELS from mmselfsup.structures import SelfSupDataSample @@ -16,8 +17,10 @@ class MaskFeat(BaseModel): Pre-Training `_. 
""" - def extract_feat(self, inputs: List[torch.Tensor], + def extract_feat(self, + inputs: List[torch.Tensor], data_samples: List[SelfSupDataSample], + compute_hog: bool = True, **kwarg) -> Tuple[torch.Tensor]: """The forward function to extract features from neck. @@ -25,15 +28,30 @@ class MaskFeat(BaseModel): inputs (List[torch.Tensor]): The input images and mask. data_samples (List[SelfSupDataSample]): All elements required during the forward function. + compute_hog (bool): Whether to compute hog during extraction. If + True, the batch size of inputs need to be 1. Defaults to True. Returns: Tuple[torch.Tensor]: Neck outputs. """ img = inputs[0] - mask = torch.stack( + self.mask = torch.stack( [data_sample.mask.value for data_sample in data_samples]) - latent = self.backbone(img, mask) - return latent + latent = self.backbone(img, self.mask) + B, L, C = latent.shape + pred = self.neck([latent.view(B * L, C)]) + pred = pred[0].view(B, L, -1) + + # compute hog + if compute_hog: + assert img.size(0) == 1, 'Currently only support batch size 1.' + _ = self.target_generator(img) + hog_image = torch.from_numpy( + self.target_generator.generate_hog_image( + self.target_generator.out)).unsqueeze(0).unsqueeze(0) + self.target = hog_image.expand(-1, 3, -1, -1) + + return pred[:, 1:, :] # remove cls token def loss(self, inputs: List[torch.Tensor], data_samples: List[SelfSupDataSample], @@ -62,3 +80,60 @@ class MaskFeat(BaseModel): loss = self.head(pred, hog, mask) losses = dict(loss=loss) return losses + + def reconstruct(self, + features: List[torch.Tensor], + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwargs) -> SelfSupDataSample: + """The function is for image reconstruction. + + Args: + features (List[torch.Tensor]): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. + + Returns: + SelfSupDataSample: The prediction from model. 
+ """ + + # recover to HOG description from feature embeddings + unfold_size = self.target_generator.unfold_size + tmp4 = features.unflatten(2, + (features.shape[2] // unfold_size**2, + unfold_size, unfold_size)) # 1,196,27,2,2 + tmp3 = tmp4.unflatten(1, self.backbone.patch_resolution) + + b, p1, p2, c_nbins, _, _ = tmp3.shape # 1,14,14,27,2,2 + tmp2 = tmp3.permute(0, 1, 2, 5, 3, 4).reshape( + (b, p1, p2 * unfold_size, c_nbins, unfold_size)) + tmp1 = tmp2.permute(0, 1, 4, 2, 3).reshape( + (b, p1 * unfold_size, p2 * unfold_size, c_nbins)) + tmp0 = tmp1.permute(0, 3, 1, 2) # 1,27,28,28 + hog_out = tmp0.unflatten(1, + (int(c_nbins // self.target_generator.nbins), + self.target_generator.nbins)) # 1,3,9,28,28 + + # generate predction of HOG + hog_image = torch.from_numpy( + self.target_generator.generate_hog_image(hog_out)) + hog_image = hog_image.unsqueeze(0).unsqueeze(0) + pred = torch.einsum('nchw->nhwc', hog_image).expand(-1, -1, -1, + 3).detach().cpu() + + # transform patch mask to pixel mask + mask = self.mask + patch_dim_1 = int(self.backbone.patch_embed.init_input_size[0] // + self.backbone.patch_resolution[0]) + patch_dim_2 = int(self.backbone.patch_embed.init_input_size[1] // + self.backbone.patch_resolution[1]) + mask = mask.repeat_interleave( + patch_dim_1, dim=1).repeat_interleave( + patch_dim_2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3) + # 1 is removing, 0 is keeping + mask = mask.detach().cpu() + + results = SelfSupDataSample() + results.mask = BaseDataElement(**dict(value=mask)) + results.pred = BaseDataElement(**dict(value=pred)) + + return results diff --git a/mmselfsup/models/algorithms/simmim.py b/mmselfsup/models/algorithms/simmim.py index 9a306a8e..5337f5d9 100644 --- a/mmselfsup/models/algorithms/simmim.py +++ b/mmselfsup/models/algorithms/simmim.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List +from typing import Dict, List, Optional import torch +from mmengine.structures import BaseDataElement from mmselfsup.registry import MODELS from mmselfsup.structures import SelfSupDataSample @@ -33,6 +34,7 @@ class SimMIM(BaseModel): [data_sample.mask.value for data_sample in data_samples]) img_latent = self.backbone(inputs[0], mask) feat = self.neck(img_latent[0]) + self.mask = mask return feat def loss(self, inputs: List[torch.Tensor], @@ -58,3 +60,37 @@ class SimMIM(BaseModel): losses = dict(loss=loss) return losses + + def reconstruct(self, + features: torch.Tensor, + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwargs) -> SelfSupDataSample: + """The function is for image reconstruction. + + Args: + features (torch.Tensor): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. + + Returns: + SelfSupDataSample: The prediction from model. 
+ """ + pred = torch.einsum('nchw->nhwc', features).detach().cpu() + + # transform patch mask to pixel mask + mask = self.mask.detach() + p1 = int(self.backbone.patch_embed.init_input_size[0] // + self.backbone.patch_resolution[0]) + p2 = int(self.backbone.patch_embed.init_input_size[1] // + self.backbone.patch_resolution[1]) + mask = mask.repeat_interleave( + p1, dim=1).repeat_interleave( + p2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3) # (N, H, W, 3) + # 1 is removing, 0 is keeping + mask = mask.detach().cpu() + + results = SelfSupDataSample() + results.mask = BaseDataElement(**dict(value=mask)) + results.pred = BaseDataElement(**dict(value=pred)) + + return results diff --git a/mmselfsup/models/target_generators/hog_generator.py b/mmselfsup/models/target_generators/hog_generator.py index 2da06966..53d8515f 100644 --- a/mmselfsup/models/target_generators/hog_generator.py +++ b/mmselfsup/models/target_generators/hog_generator.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import math +import cv2 +import numpy as np import torch import torch.nn.functional as F from mmengine.model import BaseModule @@ -13,10 +15,10 @@ class HOGGenerator(BaseModule): """Generate HOG feature for images. This module is used in MaskFeat to generate HOG feature. The code is - modified from this `file + modified from file `slowfast/models/operators.py `_. - Here is the link `HOG wikipedia - `_. + Here is the link of `HOG wikipedia + `_. Args: nbins (int): Number of bin. Defaults to 9. @@ -61,12 +63,12 @@ class HOGGenerator(BaseModule): def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: """Reshape HOG Features for output.""" hog_feat = hog_feat.flatten(1, 2) - unfold_size = hog_feat.shape[-1] // 14 - hog_feat = ( - hog_feat.permute(0, 2, 3, - 1).unfold(1, unfold_size, unfold_size).unfold( - 2, unfold_size, - unfold_size).flatten(1, 2).flatten(2)) + self.unfold_size = hog_feat.shape[-1] // 14 + hog_feat = hog_feat.permute(0, 2, 3, 1) + hog_feat = hog_feat.unfold(1, self.unfold_size, + self.unfold_size).unfold( + 2, self.unfold_size, self.unfold_size) + hog_feat = hog_feat.flatten(1, 2).flatten(2) return hog_feat @torch.no_grad() @@ -80,6 +82,7 @@ class HOGGenerator(BaseModule): torch.Tensor: Hog features. """ # input is RGB image with shape [B 3 H W] + self.h, self.w = x.size(-2), x.size(-1) x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect') gx_rgb = F.conv2d( x, self.weight_x, bias=None, stride=1, padding=0, groups=3) @@ -112,6 +115,38 @@ class HOGGenerator(BaseModule): out = out.unfold(4, self.pool, self.pool) out = out.sum(dim=[-1, -2]) - out = F.normalize(out, p=2, dim=2) + self.out = F.normalize(out, p=2, dim=2) - return self._reshape(out) + return self._reshape(self.out) + + def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray: + """Generate HOG image according to HOG features.""" + assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \ + 'Check the input batch size and the channcel number, only support'\ + '"batch_size = 1".' 
+ hog_image = np.zeros([self.h, self.w]) + cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu()) + cell_width = self.pool / 2 + max_mag = np.array(cell_gradient).max() + angle_gap = 360 / self.nbins + + for x in range(cell_gradient.shape[1]): + for y in range(cell_gradient.shape[2]): + cell_grad = cell_gradient[:, x, y] + cell_grad /= max_mag + angle = 0 + for magnitude in cell_grad: + angle_radian = math.radians(angle) + x1 = int(x * self.pool + + magnitude * cell_width * math.cos(angle_radian)) + y1 = int(y * self.pool + + magnitude * cell_width * math.sin(angle_radian)) + x2 = int(x * self.pool - + magnitude * cell_width * math.cos(angle_radian)) + y2 = int(y * self.pool - + magnitude * cell_width * math.sin(angle_radian)) + magnitude = 0 if magnitude < 0 else magnitude + cv2.line(hog_image, (y1, x1), (y2, x2), + int(255 * math.sqrt(magnitude))) + angle += angle_gap + return hog_image diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py index aa37cc0b..6f3a1c4f 100644 --- a/tests/test_apis/test_inference.py +++ b/tests/test_apis/test_inference.py @@ -27,10 +27,10 @@ class ExampleModel(BaseModel): super(ExampleModel, self).__init__(backbone=backbone) self.layer = nn.Linear(1, 1) - def predict(self, - inputs: List[torch.Tensor], - data_samples: Optional[List[SelfSupDataSample]] = None, - **kwargs) -> SelfSupDataSample: + def extract_feat(self, + inputs: List[torch.Tensor], + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwargs) -> SelfSupDataSample: out = self.layer(inputs[0]) return out diff --git a/tests/test_models/test_algorithms/test_mae.py b/tests/test_models/test_algorithms/test_mae.py index 51546e00..f1c08758 100644 --- a/tests/test_models/test_algorithms/test_mae.py +++ b/tests/test_models/test_algorithms/test_mae.py @@ -48,9 +48,14 @@ def test_mae(): fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss') assert isinstance(fake_outputs['loss'].item(), float) + # test extraction fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor') assert list(fake_feats.shape) == [2, 196, 768] - results = alg(fake_batch_inputs, fake_data_samples, mode='predict') + # test reconstruct + mean = fake_feats.mean(dim=-1, keepdim=True) + std = (fake_feats.var(dim=-1, keepdim=True) + 1.e-6)**.5 + results = alg.reconstruct( + fake_feats, fake_data_samples, mean=mean, std=std) assert list(results.mask.value.shape) == [2, 224, 224, 3] assert list(results.pred.value.shape) == [2, 224, 224, 3] diff --git a/tests/test_models/test_algorithms/test_maskfeat.py b/tests/test_models/test_algorithms/test_maskfeat.py index fa9f3e01..72e4a3eb 100644 --- a/tests/test_models/test_algorithms/test_maskfeat.py +++ b/tests/test_models/test_algorithms/test_maskfeat.py @@ -5,6 +5,7 @@ import platform import pytest import torch from mmengine.structures import InstanceData +from mmengine.utils import digit_version from mmselfsup.models.algorithms.maskfeat import MaskFeat from mmselfsup.structures import SelfSupDataSample @@ -22,6 +23,9 @@ target_generator = dict( type='HOGGenerator', nbins=9, pool=8, gaussian_window=16) +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.7.0'), + reason='torch version') @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') def test_maskfeat(): data_preprocessor = { @@ -42,13 +46,19 @@ def test_maskfeat(): fake_mask = InstanceData(value=torch.rand((14, 14)).bool()) fake_data_sample.mask = fake_mask fake_data = { - 'inputs': [torch.randn((2, 3, 224, 224))], 
- 'data_sample': [fake_data_sample for _ in range(2)] + 'inputs': [torch.randn((1, 3, 224, 224))], + 'data_sample': [fake_data_sample for _ in range(1)] } fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data) fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss') assert isinstance(fake_outputs['loss'].item(), float) + # test extraction fake_feats = alg.extract_feat(fake_batch_inputs, fake_data_samples) - assert list(fake_feats.shape) == [2, 197, 768] + assert list(fake_feats.shape) == [1, 196, 108] + + # test reconstruction + results = alg.reconstruct(fake_feats, fake_data_samples) + assert list(results.mask.value.shape) == [1, 224, 224, 3] + assert list(results.pred.value.shape) == [1, 224, 224, 3] diff --git a/tests/test_models/test_algorithms/test_simmim.py b/tests/test_models/test_algorithms/test_simmim.py index da54ffce..f89e2bcc 100644 --- a/tests/test_models/test_algorithms/test_simmim.py +++ b/tests/test_models/test_algorithms/test_simmim.py @@ -50,5 +50,10 @@ def test_simmim(): # test extract_feat fake_inputs, fake_data_samples = model.data_preprocessor(fake_data) - fake_feat = model.extract_feat(fake_inputs, fake_data_samples) - assert list(fake_feat.shape) == [2, 3, 192, 192] + fake_feats = model.extract_feat(fake_inputs, fake_data_samples) + assert list(fake_feats.shape) == [2, 3, 192, 192] + + # test reconstruct + results = model.reconstruct(fake_feats, fake_data_samples) + assert list(results.mask.value.shape) == [2, 192, 192, 3] + assert list(results.pred.value.shape) == [2, 192, 192, 3] diff --git a/tests/test_models/test_target_generators/test_hog_generator.py b/tests/test_models/test_target_generators/test_hog_generator.py index d7ad3779..99a8dc7b 100644 --- a/tests/test_models/test_target_generators/test_hog_generator.py +++ b/tests/test_models/test_target_generators/test_hog_generator.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch from mmselfsup.models.target_generators import HOGGenerator @@ -10,3 +11,10 @@ def test_hog_generator(): fake_input = torch.randn((2, 3, 224, 224)) fake_output = hog_generator(fake_input) assert list(fake_output.shape) == [2, 196, 108] + + fake_hog_out = hog_generator.out[0].unsqueeze(0) + fake_hog_img = hog_generator.generate_hog_image(fake_hog_out) + assert fake_hog_img.shape == (224, 224) + + with pytest.raises(AssertionError): + fake_hog_img = hog_generator.generate_hog_image(hog_generator.out[0]) diff --git a/tools/analysis_tools/mae_visualization.py b/tools/analysis_tools/mae_visualization.py deleted file mode 100644 index 2f6902ce..00000000 --- a/tools/analysis_tools/mae_visualization.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# Modified from https://colab.research.google.com/github/facebookresearch/mae -# /blob/main/demo/mae_visualize.ipynb -from argparse import ArgumentParser -from typing import Tuple - -import matplotlib.pyplot as plt -import numpy as np -import torch -from mmengine.dataset import Compose, default_collate - -from mmselfsup.apis import inference_model, init_model -from mmselfsup.registry import MODELS -from mmselfsup.utils import register_all_modules - -imagenet_mean = np.array([0.485, 0.456, 0.406]) -imagenet_std = np.array([0.229, 0.224, 0.225]) - - -def show_image(image: torch.Tensor, title: str = '') -> None: - # image is [H, W, 3] - assert image.shape[2] == 3 - image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, - 255).int() - plt.imshow(image) - plt.title(title, fontsize=16) - plt.axis('off') - return - - -def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor, - im_paste: torch.Tensor, out_file: str) -> None: - # make the plt figure larger - plt.rcParams['figure.figsize'] = [24, 6] - - plt.subplot(1, 4, 1) - show_image(x, 'original') - - plt.subplot(1, 4, 2) - show_image(im_masked, 'masked') - - plt.subplot(1, 4, 3) - show_image(y, 'reconstruction') - - plt.subplot(1, 4, 4) - show_image(im_paste, 'reconstruction + visible') - - plt.savefig(out_file) - - -def post_process( - x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x = torch.einsum('nchw->nhwc', x.cpu()) - # masked image - im_masked = x * (1 - mask) - # MAE reconstruction pasted with visible patches - im_paste = x * (1 - mask) + y * mask - return x[0], im_masked[0], y[0], im_paste[0] - - -def main(): - parser = ArgumentParser() - parser.add_argument('img_path', help='Image file path') - parser.add_argument('config', help='MAE Config file') - parser.add_argument('checkpoint', help='Checkpoint file') - parser.add_argument('out_file', help='The output image file path') - parser.add_argument( - '--device', default='cuda:0', help='Device used for inference') - args = parser.parse_args() - - register_all_modules() - - # build the model from a config file and a checkpoint file - model = init_model(args.config, args.checkpoint, device=args.device) - print('Model loaded.') - - # make random mask reproducible (comment out to make it change) - torch.manual_seed(2) - print('MAE with pixel reconstruction:') - - model.cfg.test_dataloader = dict( - dataset=dict(pipeline=[ - dict( - type='LoadImageFromFile', - file_client_args=dict(backend='disk')), - dict(type='Resize', scale=(224, 224)), - dict(type='PackSelfSupInputs', meta_keys=['img_path']) - ])) - - results = inference_model(model, args.img_path) - - cfg = model.cfg - test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline) - data_preprocessor = MODELS.build(cfg.model.data_preprocessor) - data = dict(img_path=args.img_path) - data = test_pipeline(data) - data = default_collate([data]) - img, _ = data_preprocessor(data, False) - - x, im_masked, y, im_paste = post_process(img[0], results.pred.value, - results.mask.value) - save_images(x, im_masked, y, im_paste, args.out_file) - - -if __name__ == '__main__': - main() diff --git a/tools/analysis_tools/visualize_reconstruction.py b/tools/analysis_tools/visualize_reconstruction.py new file mode 100644 index 00000000..c4acd212 --- /dev/null +++ b/tools/analysis_tools/visualize_reconstruction.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Modified from https://colab.research.google.com/github/facebookresearch/mae +# /blob/main/demo/mae_visualize.ipynb +import random +from argparse import ArgumentParser +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +from mmengine.dataset import Compose, default_collate + +from mmselfsup.apis import inference_model, init_model +from mmselfsup.utils import register_all_modules + +imagenet_mean = np.array([0.485, 0.456, 0.406]) +imagenet_std = np.array([0.229, 0.224, 0.225]) + + +def show_image(img: torch.Tensor, title: str = '') -> None: + # image is [H, W, 3] + assert img.shape[2] == 3 + + plt.imshow(img) + plt.title(title, fontsize=16) + plt.axis('off') + return + + +def save_images(original_img: torch.Tensor, img_masked: torch.Tensor, + pred_img: torch.Tensor, img_paste: torch.Tensor, + out_file: str) -> None: + # make the plt figure larger + plt.rcParams['figure.figsize'] = [24, 6] + + plt.subplot(1, 4, 1) + show_image(original_img, 'original') + + plt.subplot(1, 4, 2) + show_image(img_masked, 'masked') + + plt.subplot(1, 4, 3) + show_image(pred_img, 'reconstruction') + + plt.subplot(1, 4, 4) + show_image(img_paste, 'reconstruction + visible') + + plt.savefig(out_file) + print(f'Images are saved to {out_file}') + + +def recover_norm(img: torch.Tensor, + mean: np.ndarray = imagenet_mean, + std: np.ndarray = imagenet_std): + if mean is not None and std is not None: + img = torch.clip((img * std + mean) * 255, 0, 255).int() + return img + + +def post_process( + original_img: torch.Tensor, + pred_img: torch.Tensor, + mask: torch.Tensor, + mean: np.ndarray = imagenet_mean, + std: np.ndarray = imagenet_std +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # channel conversion + original_img = torch.einsum('nchw->nhwc', original_img.cpu()) + # masked image + img_masked = original_img * (1 - mask) + # reconstructed image pasted with visible patches + img_paste = original_img * (1 - mask) + pred_img * mask + + # muptiply std and add mean to each image + original_img = recover_norm(original_img[0]) + img_masked = recover_norm(img_masked[0]) + + pred_img = recover_norm(pred_img[0]) + img_paste = recover_norm(img_paste[0]) + + return original_img, img_masked, pred_img, img_paste + + +def main(): + parser = ArgumentParser() + parser.add_argument('config', help='Model config file') + parser.add_argument('--checkpoint', help='Checkpoint file') + parser.add_argument('--img-path', help='Image file path') + parser.add_argument('--out-file', help='The output image file path') + parser.add_argument( + '--use-vis-pipeline', + action='store_true', + help='Use vis_pipeline defined in config. 
For some algorithms, such ' + 'as SimMIM and MaskFeat, they generate mask in data pipeline, thus ' + 'the visualization process applies vis_pipeline in config to obtain ' + 'the mask.') + parser.add_argument( + '--norm-pix', + action='store_true', + help='MAE uses `norm_pix_loss` for optimization in pre-training, thus ' + 'the visualization process also need to compute mean and std of each ' + 'patch embedding while reconstructing the original images.') + parser.add_argument( + '--target-generator', + action='store_true', + help='Some algorithms use target_generator for optimization in ' + 'pre-training, such as MaskFeat, thus the visualization process could ' + 'turn this on to visualize the target instead of RGB image.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--seed', + type=int, + default=0, + help='The random seed for visualization') + args = parser.parse_args() + + register_all_modules() + + # build the model from a config file and a checkpoint file + model = init_model(args.config, args.checkpoint, device=args.device) + print('Model loaded.') + + # make random mask reproducible (comment out to make it change) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + print('Reconstruction visualization.') + + if args.use_vis_pipeline: + model.cfg.test_dataloader = dict( + dataset=dict(pipeline=model.cfg.vis_pipeline)) + else: + model.cfg.test_dataloader = dict( + dataset=dict(pipeline=[ + dict( + type='LoadImageFromFile', + file_client_args=dict(backend='disk')), + dict(type='Resize', scale=(224, 224), backend='pillow'), + dict(type='PackSelfSupInputs', meta_keys=['img_path']) + ])) + + # get original image + vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline) + data = dict(img_path=args.img_path) + data = vis_pipeline(data) + data = default_collate([data]) + img, _ = model.data_preprocessor(data, False) + + if args.norm_pix: + # for MAE reconstruction + img_embedding = model.head.patchify(img[0]) + # normalize the target image + mean = img_embedding.mean(dim=-1, keepdim=True) + std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5 + else: + mean = imagenet_mean + std = imagenet_std + + # get reconstruction image + features = inference_model(model, args.img_path) + results = model.reconstruct(features, mean=mean, std=std) + + original_target = model.target if args.target_generator else img[0] + + original_img, img_masked, pred_img, img_paste = post_process( + original_target, + results.pred.value, + results.mask.value, + mean=mean, + std=std) + + save_images(original_img, img_masked, pred_img, img_paste, args.out_file) + + +if __name__ == '__main__': + main()
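
For reference, the reconstruction flow added by this patch can also be driven directly from Python instead of going through `tools/analysis_tools/visualize_reconstruction.py`. The sketch below condenses the MAE steps from `demo/mae_visualization.ipynb` above: build the test pipeline, compute the per-patch mean/std that `norm_pix_loss` requires, then call `inference_model` followed by `model.reconstruct`. It is a minimal sketch, not part of the patch itself; the config, checkpoint and image names are the ones used in the demo notebook and are purely illustrative.

```python
# Minimal sketch of the programmatic reconstruction API introduced in this patch.
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

register_all_modules()

# Illustrative paths taken from the demo notebook; substitute your own.
config = 'configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py'
checkpoint = ('mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_'
              '20220825-cc7e98c9.pth')
img_path = 'fox.jpg'

model = init_model(config, checkpoint, device='cpu')

# Same test pipeline the notebook and the CLI tool set up for MAE.
model.cfg.test_dataloader = dict(
    dataset=dict(pipeline=[
        dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
        dict(type='Resize', scale=(224, 224), backend='pillow'),
        dict(type='PackSelfSupInputs', meta_keys=['img_path'])
    ]))
pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
data = default_collate([pipeline(dict(img_path=img_path))])
img, _ = model.data_preprocessor(data, False)

# MAE is pre-trained with norm_pix_loss, so the per-patch mean/std of the
# input image are needed to undo the target normalization (cf. --norm-pix).
img_embedding = model.head.patchify(img[0])
mean = img_embedding.mean(dim=-1, keepdim=True)
std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5

# inference_model now returns raw features (mode='tensor'); reconstruct maps
# them back to an (N, H, W, 3) image together with the pixel-level mask.
features = inference_model(model, img_path)
results = model.reconstruct(features, mean=mean, std=std)
print(results.pred.value.shape, results.mask.value.shape)
```

Splitting `inference_model` from `reconstruct` is what lets one tool serve MAE, SimMIM and MaskFeat alike: the forward pass always returns plain features, and each algorithm's `reconstruct` is responsible for turning them into a displayable image and mask.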