{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Copyright (c) OpenMMLab. All rights reserved.\n", "\n", "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n", "\n", "## Masked Autoencoders: Visualization Demo\n", "\n", "This is a visualization demo using our pre-trained MAE models. No GPU is needed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare\n", "Check environment. Install packages if in Colab." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "# check whether run in Colab\n", "if 'google.colab' in sys.modules:\n", " print('Running in Colab.')\n", " !pip3 install openmim\n", " !pip install -U openmim\n", " !mim install mmengine\n", " !mim install 'mmcv>=2.0.0rc1'\n", "\n", " !git clone https://github.com/open-mmlab/mmselfsup.git\n", " %cd mmselfsup/\n", " !git checkout dev-1.x\n", " !pip install -e .\n", "\n", " sys.path.append('./mmselfsup')\n", " %cd demo\n", "else:\n", " sys.path.append('..')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from argparse import ArgumentParser\n", "from typing import Tuple, Optional\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from mmengine.dataset import Compose, default_collate\n", "\n", "from mmselfsup.apis import inference_model, init_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define utils" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# define the utils\n", "\n", "imagenet_mean = np.array([0.485, 0.456, 0.406])\n", "imagenet_std = np.array([0.229, 0.224, 0.225])\n", "\n", "\n", "def show_image(img: torch.Tensor, title: str = '') -> None:\n", " # image is [H, W, 3]\n", " assert img.shape[2] == 3\n", "\n", " plt.imshow(img)\n", " plt.title(title, fontsize=16)\n", " plt.axis('off')\n", " return\n", "\n", "\n", "def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n", " pred_img: torch.Tensor, img_paste: torch.Tensor,\n", " out_file: Optional[str] =None) -> None:\n", " # make the plt figure larger\n", " plt.rcParams['figure.figsize'] = [24, 6]\n", "\n", " plt.subplot(1, 4, 1)\n", " show_image(original_img, 'original')\n", "\n", " plt.subplot(1, 4, 2)\n", " show_image(img_masked, 'masked')\n", "\n", " plt.subplot(1, 4, 3)\n", " show_image(pred_img, 'reconstruction')\n", "\n", " plt.subplot(1, 4, 4)\n", " show_image(img_paste, 'reconstruction + visible')\n", "\n", " if out_file is None:\n", " plt.show()\n", " else:\n", " plt.savefig(out_file)\n", " print(f'Images are saved to {out_file}')\n", "\n", "\n", "def recover_norm(img: torch.Tensor,\n", " mean: np.ndarray = imagenet_mean,\n", " std: np.ndarray = imagenet_std):\n", " if mean is not None and std is not None:\n", " img = torch.clip((img * std + mean) * 255, 0, 255).int()\n", " return img\n", "\n", "\n", "def post_process(\n", " original_img: torch.Tensor,\n", " pred_img: torch.Tensor,\n", " mask: torch.Tensor,\n", " mean: np.ndarray = imagenet_mean,\n", " std: np.ndarray = imagenet_std\n", ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n", " # channel conversion\n", " original_img = torch.einsum('nchw->nhwc', original_img.cpu())\n", " # masked image\n", " img_masked = original_img * (1 - mask)\n", " # reconstructed image pasted with visible patches\n", " img_paste = original_img * (1 - mask) + pred_img * mask\n", "\n", 
" # muptiply std and add mean to each image\n", " original_img = recover_norm(original_img[0])\n", " img_masked = recover_norm(img_masked[0])\n", "\n", " pred_img = recover_norm(pred_img[0])\n", " img_paste = recover_norm(img_paste[0])\n", "\n", " return original_img, img_masked, pred_img, img_paste\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load a pre-trained MAE model\n", "\n", "This is an MAE model trained with config 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k.py'.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-11-08 11:00:50-- https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n", "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n", "正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n", "已发出 HTTP 请求,正在等待回应... 200 OK\n", "长度: 1355429265 (1.3G) [application/octet-stream]\n", "正在保存至: “mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth”\n", "\n", "e_in1k_20220825-cc7 99%[==================> ] 1.26G 913KB/s 剩余 0s s " ] } ], "source": [ "# download checkpoint if not exist\n", "!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n", "The model and loaded state dict do not match exactly\n", "\n", "unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n", "\n", "Model loaded.\n" ] } ], "source": [ "ckpt_path = \"mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\"\n", "model = init_model(\n", " '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',\n", " ckpt_path,\n", " device='cpu')\n", "print('Model loaded.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load an image" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# make random mask reproducible (comment out to make it change)\n", "torch.manual_seed(2)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-11-08 11:21:14-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n", "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n", "正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n", "已发出 HTTP 请求,正在等待回应... 
200 OK\n", "长度: 60133 (59K) [image/jpeg]\n", "正在保存至: “fox.jpg”\n", "\n", "fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.05s \n", "\n", "2022-11-08 11:21:15 (1.08 MB/s) - 已保存 “fox.jpg” [60133/60133])\n", "\n" ] } ], "source": [ "!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "img_path = 'fox.jpg'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build Pipeline" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model.cfg.test_dataloader = dict(\n", " dataset=dict(pipeline=[\n", " dict(type='LoadImageFromFile'),\n", " dict(type='Resize', scale=(224, 224), backend='pillow'),\n", " dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n", " ]))\n", "\n", "vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "data = dict(img_path=img_path)\n", "data = vis_pipeline(data)\n", "data = default_collate([data])\n", "img, _ = model.data_preprocessor(data, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reconstruction pipeline" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# for MAE reconstruction\n", "img_embedding = model.head.patchify(img[0])\n", "# normalize the target image\n", "mean = img_embedding.mean(dim=-1, keepdim=True)\n", "std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# get reconstruction image\n", "features = inference_model(model, img_path)\n", "results = model.reconstruct(features, mean=mean, std=std)\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "original_target = img[0]\n", "original_img, img_masked, pred_img, img_paste = post_process(\n", " original_target,\n", " results.pred.value,\n", " results.mask.value,\n", " mean=mean,\n", " std=std)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Show the image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_images(original_img, img_masked, pred_img, img_paste)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.0 ('openmmlab')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0 (default, Oct 9 2018, 10:31:47) \n[GCC 7.3.0]" }, "vscode": { "interpreter": { "hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361" } } }, "nbformat": 4, "nbformat_minor": 2 }