From 3f530f085e85be29fb0d12d5a44f521e01aeca7b Mon Sep 17 00:00:00 2001
From: RenQin <45731309+soonera@users.noreply.github.com>
Date: Wed, 27 Jul 2022 16:03:57 +0800
Subject: [PATCH] [Tools]: MAE Reconstructed Image Visualization (#376)

* [Tools]: MAE Reconstructed Image Visualization

* [Fix]: fix docstring and type hint

* [Fix]: fix docstring in MAE class

* [Fix]: fix docstring in MAE class

* [Fix]: fix type hint

* [Fix]: fix type hint and docstring

* [refactor]: refactor super init
---
 demo/mae_visualization.ipynb                  | 268 ++++++++++++++++++
 docs/en/get_started.md                        |  21 ++
 docs/zh_cn/get_started.md                     |  21 ++
 mmselfsup/apis/__init__.py                    |   6 +-
 mmselfsup/apis/inference.py                   |  85 ++++++
 mmselfsup/models/algorithms/mae.py            |  66 +++--
 mmselfsup/models/heads/mae_head.py            |  36 ++-
 tests/test_apis/test_inference.py             |  59 ++++
 tests/test_models/test_algorithms/test_mae.py |   3 +
 tests/test_models/test_heads.py               |   4 +
 tools/misc/mae_visualization.py               |  93 ++++++
 11 files changed, 638 insertions(+), 24 deletions(-)
 create mode 100644 demo/mae_visualization.ipynb
 create mode 100644 mmselfsup/apis/inference.py
 create mode 100644 tests/test_apis/test_inference.py
 create mode 100644 tools/misc/mae_visualization.py

diff --git a/demo/mae_visualization.ipynb b/demo/mae_visualization.ipynb
new file mode 100644
index 00000000..1d24ee3d
--- /dev/null
+++ b/demo/mae_visualization.ipynb
@@ -0,0 +1,268 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Copyright (c) OpenMMLab. All rights reserved.\n",
+    "\n",
+    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
+    "\n",
+    "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
+    "\n",
+    "## Masked Autoencoders: Visualization Demo\n",
+    "\n",
+    "This is a visualization demo using our pre-trained MAE models. No GPU is needed."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Prepare\n",
+    "Check the environment. Install packages if running in Colab."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import requests\n", + "\n", + "import torch\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "\n", + "from mmselfsup.models import build_algorithm\n", + "\n", + "# check whether run in Colab\n", + "if 'google.colab' in sys.modules:\n", + " print('Running in Colab.')\n", + " !pip install openmim\n", + " !mim install mmcv-full\n", + " !git clone https://github.com/open-mmlab/mmselfsup.git\n", + " %cd mmselfsup/\n", + " !pip install -e .\n", + " sys.path.append('./mmselfsup')\n", + " %cd demo\n", + "else:\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define utils" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# define the utils\n", + "\n", + "imagenet_mean = np.array([0.485, 0.456, 0.406])\n", + "imagenet_std = np.array([0.229, 0.224, 0.225])\n", + "\n", + "def show_image(image, title=''):\n", + " # image is [H, W, 3]\n", + " assert image.shape[2] == 3\n", + " image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n", + " plt.imshow(image)\n", + " plt.title(title, fontsize=16)\n", + " plt.axis('off')\n", + " return\n", + "\n", + "\n", + "def show_images(x, im_masked, y, im_paste):\n", + " # make the plt figure larger\n", + " plt.rcParams['figure.figsize'] = [24, 6]\n", + "\n", + " plt.subplot(1, 4, 1)\n", + " show_image(x, \"original\")\n", + "\n", + " plt.subplot(1, 4, 2)\n", + " show_image(im_masked, \"masked\")\n", + "\n", + " plt.subplot(1, 4, 3)\n", + " show_image(y, \"reconstruction\")\n", + "\n", + " plt.subplot(1, 4, 4)\n", + " show_image(im_paste, \"reconstruction + visible\")\n", + "\n", + " plt.show()\n", + "\n", + "\n", + "def post_process(x, y, mask):\n", + " x = torch.einsum('nchw->nhwc', x.cpu())\n", + " # masked image\n", + " im_masked = x * (1 - mask)\n", + " # MAE reconstruction pasted with visible patches\n", + " im_paste = x * (1 - mask) + y * mask\n", + " return x[0], im_masked[0], y[0], im_paste[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load an image" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# load an image\n", + "img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n", + "img_pil = Image.open(requests.get(img_url, stream=True).raw)\n", + "img = img_pil.resize((224, 224))\n", + "img = np.array(img) / 255.\n", + "\n", + "assert img.shape == (224, 224, 3)\n", + "\n", + "# normalize by ImageNet mean and std\n", + "img = img - imagenet_mean\n", + "img = img / imagenet_std\n", + "\n", + "plt.rcParams['figure.figsize'] = [5, 5]\n", + "show_image(torch.tensor(img))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load a pre-trained MAE model" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile ../configs/selfsup/mae/mae_visualization.py\n", + "model = dict(\n", + " type='MAE',\n", + " backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n", + " neck=dict(\n", + " type='MAEPretrainDecoder',\n", + " patch_size=16,\n", + " in_chans=3,\n", + " embed_dim=1024,\n", + " decoder_embed_dim=512,\n", + " decoder_depth=8,\n", + " decoder_num_heads=16,\n", + " mlp_ratio=4.,\n", + " ),\n", + " 
head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))\n",
+    "\n",
+    "img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
+    "# dataset summary\n",
+    "data = dict(\n",
+    "    test=dict(\n",
+    "        pipeline=[\n",
+    "            dict(type='Resize', size=(224, 224)),\n",
+    "            dict(type='ToTensor'),\n",
+    "            dict(type='Normalize', **img_norm_cfg),\n",
+    "        ]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
+    "\n",
+    "# download the checkpoint if it does not exist\n",
+    "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
+    "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from mmselfsup.apis import init_model\n",
+    "ckpt_path = 'mae_visualize_vit_large.pth'\n",
+    "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
+    "print('Model loaded.')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Run MAE on the image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# make random mask reproducible (comment out to make it change)\n",
+    "torch.manual_seed(2)\n",
+    "print('MAE with pixel reconstruction:')\n",
+    "\n",
+    "from mmselfsup.apis import inference_model\n",
+    "\n",
+    "img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
+    "img = Image.open(requests.get(img_url, stream=True).raw)\n",
+    "img, (mask, pred) = inference_model(model, img)\n",
+    "x, im_masked, y, im_paste = post_process(img, pred, mask)\n",
+    "show_images(x, im_masked, y, im_paste)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "3e4aeeccd14e965f43d0896afbaf8d71604e66b8605affbaa33ec76aa4083757"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/en/get_started.md b/docs/en/get_started.md
index 66f62f0b..efc20451 100644
--- a/docs/en/get_started.md
+++ b/docs/en/get_started.md
@@ -166,6 +166,27 @@ Arguments:
 - `WORK_DIR`: the directory to save the results of visualization.
 - `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
 
+### MAE Visualization
+
+We provide a tool to visualize the mask and the reconstructed image of an MAE model.
+
+```shell
+python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
+```
+
+Arguments:
+
+- `IMG`: an image path used for visualization.
+- `CONFIG_FILE`: config file for the pre-trained model.
+- `CKPT_PATH`: the path of the model's checkpoint.
+- `DEVICE`: device used for inference.
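+
+The same functionality is exposed through the Python API. A minimal sketch, assuming a config that defines `data.test.pipeline` (such as the `mae_visualization.py` config written in the demo notebook) and a downloaded checkpoint; the paths are illustrative:
+
+```python
+import torch
+from PIL import Image
+
+from mmselfsup.apis import inference_model, init_model
+
+# build the model from the config and checkpoint; no GPU is required
+model = init_model('configs/selfsup/mae/mae_visualization.py',
+                   'mae_visualize_vit_large.pth', device='cpu')
+torch.manual_seed(2)  # make the random mask reproducible
+img = Image.open('tests/data/color.jpg')
+# returns the preprocessed image tensor and the (mask, prediction) pair
+img, (mask, pred) = inference_model(model, img)
+```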
+
+A command-line example:
+
+```shell
+python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
+```
+
 ### Reproducibility
 
 If you want to make your performance exactly reproducible, please switch on `--deterministic` to train the final model to be published. Note that this flag will switch off `torch.backends.cudnn.benchmark` and slow down the training speed.
diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md
index 3c36646a..51525f6a 100644
--- a/docs/zh_cn/get_started.md
+++ b/docs/zh_cn/get_started.md
@@ -164,6 +164,27 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT
 - `WORK_DIR`: 保存可视化结果的路径.
 - `[optional arguments]`: 可选参数,具体可以参考 [visualize_tsne.py](../../tools/analysis_tools/visualize_tsne.py)
 
+### MAE 可视化
+
+我们提供了一个对 MAE 掩码效果和重建效果可视化的方法:
+
+```shell
+python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
+```
+
+参数:
+
+- `IMG`: 用于可视化的图片
+- `CONFIG_FILE`: 预训练模型的参数配置文件
+- `CKPT_PATH`: 预训练模型的路径
+- `DEVICE`: 用于推理的设备
+
+示例:
+
+```shell
+python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
+```
+
 ### 可复现性
 
 如果您想确保模型精度的可复现性,您可以设置 `--deterministic` 参数。但是,开启 `--deterministic` 意味着关闭 `torch.backends.cudnn.benchmark`, 所以会使模型的训练速度变慢。
diff --git a/mmselfsup/apis/__init__.py b/mmselfsup/apis/__init__.py
index 51d0cd25..412f1a05 100644
--- a/mmselfsup/apis/__init__.py
+++ b/mmselfsup/apis/__init__.py
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .inference import inference_model, init_model
 from .train import init_random_seed, set_random_seed, train_model
 
-__all__ = ['init_random_seed', 'set_random_seed', 'train_model']
+__all__ = [
+    'inference_model', 'init_model', 'init_random_seed', 'set_random_seed',
+    'train_model'
+]
diff --git a/mmselfsup/apis/inference.py b/mmselfsup/apis/inference.py
new file mode 100644
index 00000000..dbc8571c
--- /dev/null
+++ b/mmselfsup/apis/inference.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import mmcv
+import torch
+from mmcv.parallel import collate, scatter
+from mmcv.runner import load_checkpoint
+from mmcv.utils import build_from_cfg
+from PIL import Image
+from torch import nn
+from torchvision.transforms import Compose
+
+from mmselfsup.datasets import PIPELINES
+from mmselfsup.models import build_algorithm
+
+
+def init_model(config: Union[str, mmcv.Config],
+               checkpoint: Optional[str] = None,
+               device: str = 'cuda:0',
+               options: Optional[dict] = None) -> nn.Module:
+    """Initialize a model from a config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the
+            model will not load any weights. Defaults to None.
+        device (str): The device where the model will be put on.
+            Defaults to 'cuda:0'.
+        options (dict, optional): Options to override some settings in the
+            used config. Defaults to None.
+
+    Returns:
+        nn.Module: The initialized model.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    if options is not None:
+        config.merge_from_dict(options)
+    model = build_algorithm(config.model)
+    if checkpoint is not None:
+        # Mapping the weights to GPU may cause an unexpected video memory
+        # leak; see https://github.com/open-mmlab/mmdetection/pull/6405
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def inference_model(
+    model: nn.Module,
+    data: Image.Image
+) -> Tuple[torch.Tensor, Union[torch.Tensor, dict]]:
+    """Run inference on an image with the model.
+
+    Args:
+        model (nn.Module): The loaded model.
+        data (PIL.Image.Image): The loaded image.
+
+    Returns:
+        Tuple[torch.Tensor, Union[torch.Tensor, dict]]: Output of model
+            inference.
+            - data (torch.Tensor): The loaded image fed to the model.
+            - output (torch.Tensor or dict[str, torch.Tensor]): The output
+              of the model in test mode.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = [
+        build_from_cfg(p, PIPELINES) for p in cfg.data.test.pipeline
+    ]
+    test_pipeline = Compose(test_pipeline)
+
+    data = test_pipeline(data)
+    data = collate([data], samples_per_gpu=1)
+
+    if next(model.parameters()).is_cuda:
+        # scatter to the specified GPU
+        data = scatter(data, [device])[0]
+
+    # forward the model
+    with torch.no_grad():
+        output = model(data, mode='test')
+    return data, output
diff --git a/mmselfsup/models/algorithms/mae.py b/mmselfsup/models/algorithms/mae.py
index a0f003fe..15b8ec80 100644
--- a/mmselfsup/models/algorithms/mae.py
+++ b/mmselfsup/models/algorithms/mae.py
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple
+
+import torch
+
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
 
@@ -8,17 +12,22 @@ class MAE(BaseModel):
     """MAE.
 
     Implementation of `Masked Autoencoders Are Scalable Vision Learners
-    <https://arxiv.org/abs/2111.06377>`_.
+    <https://arxiv.org/abs/2111.06377>`_.
+
     Args:
-        backbone (dict): Config dict for encoder. Defaults to None.
-        neck (dict): Config dict for encoder. Defaults to None.
-        head (dict): Config dict for loss functions. Defaults to None.
-        init_cfg (dict): Config dict for weight initialization.
+        backbone (dict, optional): Config dict for encoder. Defaults to None.
+        neck (dict, optional): Config dict for decoder. Defaults to None.
+        head (dict, optional): Config dict for loss functions.
+            Defaults to None.
+        init_cfg (dict, optional): Config dict for weight initialization.
             Defaults to None.
     """
 
-    def __init__(self, backbone=None, neck=None, head=None, init_cfg=None):
-        super(MAE, self).__init__(init_cfg)
+    def __init__(self,
+                 backbone: Optional[dict] = None,
+                 neck: Optional[dict] = None,
+                 head: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None) -> None:
+        super().__init__(init_cfg)
         assert backbone is not None
         self.backbone = build_backbone(backbone)
         assert neck is not None
@@ -28,31 +37,56 @@ class MAE(BaseModel):
         self.head = build_head(head)
 
     def init_weights(self):
-        super(MAE, self).init_weights()
+        super().init_weights()
 
-    def extract_feat(self, img):
+    def extract_feat(self, img: torch.Tensor) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.
 
         Args:
-            img (Tensor): Input images of shape (N, C, H, W).
-
+            img (torch.Tensor): Input images of shape (N, C, H, W).
         Returns:
-            tuple[Tensor]: backbone outputs.
+            Tuple[torch.Tensor]: backbone outputs.
         """
         return self.backbone(img)
 
-    def forward_train(self, img, **kwargs):
+    def forward_train(self, img: torch.Tensor,
+                      **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.
 
         Args:
-            img (Tensor): Input images of shape (N, C, H, W).
+            img (torch.Tensor): Input images of shape (N, C, H, W).
             kwargs: Any keyword arguments to be used to forward.
-
         Returns:
-            dict[str, Tensor]: A dictionary of loss components.
+            Dict[str, torch.Tensor]: A dictionary of loss components.
         """
         latent, mask, ids_restore = self.backbone(img)
         pred = self.neck(latent, ids_restore)
         losses = self.head(img, pred, mask)
 
         return losses
+
+    def forward_test(self, img: torch.Tensor,
+                     **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward computation during testing.
+
+        Args:
+            img (torch.Tensor): Input images of shape (N, C, H, W).
+            kwargs: Any keyword arguments to be used to forward.
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Output of model test.
+                - mask: Mask used to mask the image.
+                - pred: The output of the neck.
+        """
+        latent, mask, ids_restore = self.backbone(img)
+        pred = self.neck(latent, ids_restore)
+
+        pred = self.head.unpatchify(pred)
+        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
+
+        mask = mask.detach()
+        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
+                                         3)  # (N, H*W, p*p*3)
+        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
+        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
+
+        return mask, pred
diff --git a/mmselfsup/models/heads/mae_head.py b/mmselfsup/models/heads/mae_head.py
index cfaf2161..bfb86d3d 100644
--- a/mmselfsup/models/heads/mae_head.py
+++ b/mmselfsup/models/heads/mae_head.py
@@ -18,13 +18,18 @@ class MAEPretrainHead(BaseModule):
         patch_size (int): Patch size. Defaults to 16.
""" - def __init__(self, norm_pix=False, patch_size=16): - super(MAEPretrainHead, self).__init__() + def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None: + super().__init__() self.norm_pix = norm_pix self.patch_size = patch_size - def patchify(self, imgs): - + def patchify(self, imgs: torch.Tensor) -> torch.Tensor: + """ + Args: + imgs (torch.Tensor): The shape is (N, 3, H, W) + Returns: + x (torch.Tensor): The shape is (N, L, patch_size**2 *3) + """ p = self.patch_size assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 @@ -34,7 +39,24 @@ class MAEPretrainHead(BaseModule): x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) return x - def forward(self, x, pred, mask): + def unpatchify(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): The shape is (N, L, patch_size**2 *3) + Returns: + imgs (torch.Tensor): The shape is (N, 3, H, W) + """ + p = self.patch_size + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def forward(self, x: torch.Tensor, pred: torch.Tensor, + mask: torch.Tensor) -> dict: losses = dict() target = self.patchify(x) if self.norm_pix: @@ -60,7 +82,7 @@ class MAEFinetuneHead(BaseModule): """ def __init__(self, embed_dim, num_classes=1000, label_smooth_val=0.1): - super(MAEFinetuneHead, self).__init__() + super().__init__() self.head = nn.Linear(embed_dim, num_classes) self.criterion = LabelSmoothLoss(label_smooth_val, num_classes) @@ -92,7 +114,7 @@ class MAELinprobeHead(BaseModule): """ def __init__(self, embed_dim, num_classes=1000): - super(MAELinprobeHead, self).__init__() + super().__init__() self.head = nn.Linear(embed_dim, num_classes) self.bn = nn.BatchNorm1d(embed_dim, affine=False, eps=1e-6) self.criterion = nn.CrossEntropyLoss() diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py new file mode 100644 index 00000000..41ccb728 --- /dev/null +++ b/tests/test_apis/test_inference.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +import platform + +import pytest +import torch +import torch.nn as nn +from mmcv import Config +from PIL import Image + +from mmselfsup.apis import inference_model +from mmselfsup.models import BaseModel + + +class ExampleModel(BaseModel): + + def __init__(self): + super(ExampleModel, self).__init__() + self.test_cfg = None + self.layer = nn.Linear(1, 1) + self.neck = nn.Identity() + + def extract_feat(self, imgs): + pass + + def forward_train(self, imgs, **kwargs): + pass + + def forward_test(self, img, **kwargs): + out = self.layer(img) + return out + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='') +def test_inference_model(): + # Specify the data settings + cfg = Config.fromfile( + 'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py' # noqa: E501 + ) + + # Build the algorithm + model = ExampleModel() + model.cfg = cfg + + img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + model.cfg.data = dict( + test=dict(pipeline=[ + dict(type='Resize', size=(1, 1)), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + ])) + + data = Image.open( + osp.join(osp.dirname(__file__), '..', 'data', 'color.jpg')) + + # inference model + data, output = inference_model(model, data) + assert data.size() == torch.Size([1, 3, 1, 1]) + assert output.size() == torch.Size([1, 3, 1, 1]) diff --git a/tests/test_models/test_algorithms/test_mae.py b/tests/test_models/test_algorithms/test_mae.py index 87ad1034..a7d88c39 100644 --- a/tests/test_models/test_algorithms/test_mae.py +++ b/tests/test_models/test_algorithms/test_mae.py @@ -33,5 +33,8 @@ def test_mae(): fake_input = torch.randn((2, 3, 224, 224)) fake_loss = alg.forward_train(fake_input) fake_feature = alg.extract_feat(fake_input) + mask, pred = alg.forward_test(fake_input) assert isinstance(fake_loss['loss'].item(), float) assert list(fake_feature[0].shape) == [2, 50, 768] + assert list(mask.shape) == [2, 224, 224, 3] + assert list(pred.shape) == [2, 224, 224, 3] diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 269557c4..a019a51d 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -103,6 +103,10 @@ def test_mae_pretrain_head(): assert loss_norm_pixel['loss'].item() > 0 + x = torch.rand((1, 4, 16**2 * 3)) + imgs = head_norm_pixel.unpatchify(x) + assert imgs.size() == torch.Size((1, 3, 32, 32)) + def test_mae_finetune_head(): diff --git a/tools/misc/mae_visualization.py b/tools/misc/mae_visualization.py new file mode 100644 index 00000000..ff38d2d5 --- /dev/null +++ b/tools/misc/mae_visualization.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# Modified from https://colab.research.google.com/github/facebookresearch/mae +# /blob/main/demo/mae_visualize.ipynb +from argparse import ArgumentParser +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image + +from mmselfsup.apis import inference_model, init_model + +imagenet_mean = np.array([0.485, 0.456, 0.406]) +imagenet_std = np.array([0.229, 0.224, 0.225]) + + +def show_image(image: torch.Tensor, title: str = '') -> None: + # image is [H, W, 3] + assert image.shape[2] == 3 + image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, + 255).int() + plt.imshow(image) + plt.title(title, fontsize=16) + plt.axis('off') + return + + +def show_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor, + im_paste: torch.Tensor) -> None: + # make the plt figure larger + plt.rcParams['figure.figsize'] = [24, 6] + + plt.subplot(1, 4, 1) + show_image(x, 'original') + + plt.subplot(1, 4, 2) + show_image(im_masked, 'masked') + + plt.subplot(1, 4, 3) + show_image(y, 'reconstruction') + + plt.subplot(1, 4, 4) + show_image(im_paste, 'reconstruction + visible') + + plt.show() + + +def post_process( + x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = torch.einsum('nchw->nhwc', x.cpu()) + # masked image + im_masked = x * (1 - mask) + # MAE reconstruction pasted with visible patches + im_paste = x * (1 - mask) + y * mask + return x[0], im_masked[0], y[0], im_paste[0] + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='MAE Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + model = init_model(args.config, args.checkpoint, device=args.device) + print('Model loaded.') + + # make random mask reproducible (comment out to make it change) + torch.manual_seed(2) + print('MAE with pixel reconstruction:') + + img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + model.cfg.data = dict( + test=dict(pipeline=[ + dict(type='Resize', size=(224, 224)), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + ])) + + img = Image.open(args.img) + img, (mask, pred) = inference_model(model, img) + x, im_masked, y, im_paste = post_process(img, pred, mask) + show_images(x, im_masked, y, im_paste) + + +if __name__ == '__main__': + main()
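+
+# A usage sketch (the paths are illustrative and must point to a matching
+# config/checkpoint pair, e.g. the files referenced in docs/en/get_started.md):
+#   python tools/misc/mae_visualization.py tests/data/color.jpg \
+#       configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py \
+#       mae_epoch_400.pth --device 'cuda:0'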