404 lines
11 KiB
Plaintext
404 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Copyright (c) OpenMMLab. All rights reserved.\n",
|
|
"\n",
|
|
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
|
|
"\n",
|
|
"## Masked Autoencoders: Visualization Demo\n",
|
|
"\n",
|
|
"This is a visualization demo using our pre-trained MAE models. No GPU is needed."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Prepare\n",
|
|
"Check the environment and install the required packages if running in Colab."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"\n",
|
|
"# check whether the notebook is running in Colab\n",
|
|
"if 'google.colab' in sys.modules:\n",
|
|
" print('Running in Colab.')\n",
|
|
" !pip3 install openmim\n",
|
|
" !pip install -U openmim\n",
|
|
" !mim install mmengine\n",
|
|
" !mim install 'mmcv>=2.0.0rc1'\n",
|
|
"\n",
|
|
" !git clone https://github.com/open-mmlab/mmselfsup.git\n",
|
|
" %cd mmselfsup/\n",
|
|
" !git checkout dev-1.x\n",
|
|
" !pip install -e .\n",
|
|
"\n",
|
|
" sys.path.append('./mmselfsup')\n",
|
|
" %cd demo\n",
|
|
"else:\n",
|
|
" sys.path.append('..')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from argparse import ArgumentParser\n",
|
|
"from typing import Tuple, Optional\n",
|
|
"\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"from mmengine.dataset import Compose, default_collate\n",
|
|
"\n",
|
|
"from mmselfsup.apis import inference_model, init_model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Define utils"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# define the utils\n",
|
|
"\n",
|
|
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
|
|
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
|
|
"\n",
|
|
"\n",
|
|
"def show_image(img: torch.Tensor, title: str = '') -> None:\n",
|
|
" # image is [H, W, 3]\n",
|
|
" assert img.shape[2] == 3\n",
|
|
"\n",
|
|
" plt.imshow(img)\n",
|
|
" plt.title(title, fontsize=16)\n",
|
|
" plt.axis('off')\n",
|
|
" return\n",
|
|
"\n",
|
|
"\n",
|
|
"def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n",
|
|
" pred_img: torch.Tensor, img_paste: torch.Tensor,\n",
|
|
" out_file: Optional[str] =None) -> None:\n",
|
|
" # make the plt figure larger\n",
|
|
" plt.rcParams['figure.figsize'] = [24, 6]\n",
|
|
"\n",
|
|
" plt.subplot(1, 4, 1)\n",
|
|
" show_image(original_img, 'original')\n",
|
|
"\n",
|
|
" plt.subplot(1, 4, 2)\n",
|
|
" show_image(img_masked, 'masked')\n",
|
|
"\n",
|
|
" plt.subplot(1, 4, 3)\n",
|
|
" show_image(pred_img, 'reconstruction')\n",
|
|
"\n",
|
|
" plt.subplot(1, 4, 4)\n",
|
|
" show_image(img_paste, 'reconstruction + visible')\n",
|
|
"\n",
|
|
" if out_file is None:\n",
|
|
" plt.show()\n",
|
|
" else:\n",
|
|
" plt.savefig(out_file)\n",
|
|
" print(f'Images are saved to {out_file}')\n",
|
|
"\n",
|
|
"\n",
|
|
"def recover_norm(img: torch.Tensor,\n",
|
|
" mean: np.ndarray = imagenet_mean,\n",
|
|
" std: np.ndarray = imagenet_std):\n",
|
|
" if mean is not None and std is not None:\n",
|
|
" img = torch.clip((img * std + mean) * 255, 0, 255).int()\n",
|
|
" return img\n",
|
|
"\n",
|
|
"\n",
|
|
"def post_process(\n",
|
|
" original_img: torch.Tensor,\n",
|
|
" pred_img: torch.Tensor,\n",
|
|
" mask: torch.Tensor,\n",
|
|
" mean: np.ndarray = imagenet_mean,\n",
|
|
" std: np.ndarray = imagenet_std\n",
|
|
") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
|
|
" # channel conversion\n",
|
|
" original_img = torch.einsum('nchw->nhwc', original_img.cpu())\n",
|
|
" # masked image\n",
|
|
" img_masked = original_img * (1 - mask)\n",
|
|
" # reconstructed image pasted with visible patches\n",
|
|
" img_paste = original_img * (1 - mask) + pred_img * mask\n",
|
|
"\n",
|
|
" # multiply by std and add mean to each image\n",
|
|
" original_img = recover_norm(original_img[0])\n",
|
|
" img_masked = recover_norm(img_masked[0])\n",
|
|
"\n",
|
|
" pred_img = recover_norm(pred_img[0])\n",
|
|
" img_paste = recover_norm(img_paste[0])\n",
|
|
"\n",
|
|
" return original_img, img_masked, pred_img, img_paste\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load a pre-trained MAE model\n",
|
|
"\n",
|
|
"This is an MAE model trained with config 'mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py' (the checkpoint file name uses 'fp16' where the config name uses 'amp').\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--2022-11-08 11:00:50-- https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
|
|
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
|
|
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n",
|
|
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
|
|
"长度: 1355429265 (1.3G) [application/octet-stream]\n",
|
|
"正在保存至: “mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth”\n",
|
|
"\n",
|
|
"e_in1k_20220825-cc7 99%[==================> ] 1.26G 913KB/s 剩余 0s s "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# download the checkpoint if it does not already exist\n",
|
|
"!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
|
|
"The model and loaded state dict do not match exactly\n",
|
|
"\n",
|
|
"unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n",
|
|
"\n",
|
|
"Model loaded.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"ckpt_path = \"mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\"\n",
|
|
"model = init_model(\n",
|
|
" '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',\n",
|
|
" ckpt_path,\n",
|
|
" device='cpu')\n",
|
|
"print('Model loaded.')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load an image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<torch._C.Generator at 0x7fb2ccfbac90>"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# make random mask reproducible (comment out to make it change)\n",
|
|
"torch.manual_seed(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--2022-11-08 11:21:14-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
|
|
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
|
|
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n",
|
|
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
|
|
"长度: 60133 (59K) [image/jpeg]\n",
|
|
"正在保存至: “fox.jpg”\n",
|
|
"\n",
|
|
"fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.05s \n",
|
|
"\n",
|
|
"2022-11-08 11:21:15 (1.08 MB/s) - 已保存 “fox.jpg” [60133/60133])\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"img_path = 'fox.jpg'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Build Pipeline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.cfg.test_dataloader = dict(\n",
|
|
" dataset=dict(pipeline=[\n",
|
|
" dict(type='LoadImageFromFile'),\n",
|
|
" dict(type='Resize', scale=(224, 224), backend='pillow'),\n",
|
|
" dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
|
|
" ]))\n",
|
|
"\n",
|
|
"vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"data = dict(img_path=img_path)\n",
|
|
"data = vis_pipeline(data)\n",
|
|
"data = default_collate([data])\n",
|
|
"img, _ = model.data_preprocessor(data, False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Reconstruction pipeline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# for MAE reconstruction\n",
|
|
"img_embedding = model.head.patchify(img[0])\n",
|
|
"# normalize the target image\n",
|
|
"mean = img_embedding.mean(dim=-1, keepdim=True)\n",
|
|
"std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# get the reconstructed image\n",
|
|
"features = inference_model(model, img_path)\n",
|
|
"results = model.reconstruct(features, mean=mean, std=std)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"original_target = img[0]\n",
|
|
"original_img, img_masked, pred_img, img_paste = post_process(\n",
|
|
" original_target,\n",
|
|
" results.pred.value,\n",
|
|
" results.mask.value,\n",
|
|
" mean=mean,\n",
|
|
" std=std)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Show the image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"save_images(original_img, img_masked, pred_img, img_paste)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.7.0 ('openmmlab')",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.0 (default, Oct 9 2018, 10:31:47) \n[GCC 7.3.0]"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|