404 lines
11 KiB
Plaintext
404 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Copyright (c) OpenMMLab. All rights reserved.\n",
|
|
"\n",
|
|
"Copyright (c) Meta Platforms, Inc. and affiliates.\n",
|
|
"\n",
|
|
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
|
|
"\n",
|
|
"## Masked Autoencoders: Visualization Demo\n",
|
|
"\n",
|
|
"This is a visualization demo using our pre-trained MAE models. No GPU is needed."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Prepare\n",
|
|
"Check environment. Install packages if in Colab."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"\n",
|
|
"# check whether running in Colab\n",
|
|
"if 'google.colab' in sys.modules:\n",
|
|
"    print('Running in Colab.')\n",
|
|
"    # one upgrade-install of openmim also covers the fresh-install case\n",
|
|
"    !pip install -U openmim\n",
|
|
"    !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
|
|
"\n",
|
|
"    !git clone https://github.com/open-mmlab/mmselfsup.git\n",
|
|
"    %cd mmselfsup/\n",
|
|
"    !git checkout dev-1.x\n",
|
|
"    !pip install -e .\n",
|
|
"\n",
|
|
"    sys.path.append('./mmselfsup')\n",
|
|
"    %cd demo\n",
|
|
"else:\n",
|
|
"    # outside Colab, make the parent directory importable\n",
|
|
"    sys.path.append('..')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"from mmengine.dataset import Compose, default_collate\n",
|
|
"\n",
|
|
"from mmselfsup.apis import inference_model\n",
|
|
"from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
|
|
"from mmselfsup.registry import MODELS\n",
|
|
"from mmselfsup.utils import register_all_modules"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Define utils"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# define the utils\n",
|
|
"\n",
|
|
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
|
|
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
|
|
"\n",
|
|
"def show_image(image, title=''):\n",
|
|
"    \"\"\"Draw one normalized [H, W, 3] image with matplotlib.\n",
|
|
"\n",
|
|
"    The normalization is undone with imagenet_std / imagenet_mean, then the\n",
|
|
"    values are scaled to 0-255, clipped and cast to int before plotting.\n",
|
|
"    \"\"\"\n",
|
|
"    assert image.shape[2] == 3\n",
|
|
"    denormalized = image * imagenet_std + imagenet_mean\n",
|
|
"    pixels = torch.clip(denormalized * 255, 0, 255).int()\n",
|
|
"    plt.imshow(pixels)\n",
|
|
"    plt.title(title, fontsize=16)\n",
|
|
"    plt.axis('off')\n",
|
|
"\n",
|
|
"\n",
|
|
"def show_images(x, im_masked, y, im_paste):\n",
|
|
"    \"\"\"Plot the four given images side by side in a single row.\"\"\"\n",
|
|
"    # make the plt figure larger\n",
|
|
"    plt.rcParams['figure.figsize'] = [24, 6]\n",
|
|
"\n",
|
|
"    panels = [\n",
|
|
"        (x, \"original\"),\n",
|
|
"        (im_masked, \"masked\"),\n",
|
|
"        (y, \"reconstruction\"),\n",
|
|
"        (im_paste, \"reconstruction + visible\"),\n",
|
|
"    ]\n",
|
|
"    for position, (picture, caption) in enumerate(panels, start=1):\n",
|
|
"        plt.subplot(1, 4, position)\n",
|
|
"        show_image(picture, caption)\n",
|
|
"\n",
|
|
"    plt.show()\n",
|
|
"\n",
|
|
"\n",
|
|
"def post_process(x, y, mask):\n",
|
|
"    \"\"\"Build the images to display; only the first batch item is returned.\n",
|
|
"\n",
|
|
"    x is moved to the CPU and permuted from NCHW to NHWC. Returns, in order:\n",
|
|
"    x, x with the masked part zeroed, y, and y pasted into x where mask is 1.\n",
|
|
"    \"\"\"\n",
|
|
"    x = torch.einsum('nchw->nhwc', x.cpu())\n",
|
|
"    keep = 1 - mask\n",
|
|
"    # masked image\n",
|
|
"    visible = x * keep\n",
|
|
"    # MAE reconstruction pasted with visible patches\n",
|
|
"    im_paste = visible + y * mask\n",
|
|
"    return x[0], visible[0], y[0], im_paste[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Prepare config file"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
|
|
"model = dict(\n",
|
|
" type='MAE',\n",
|
|
" data_preprocessor=dict(\n",
|
|
" mean=[123.675, 116.28, 103.53],\n",
|
|
" std=[58.395, 57.12, 57.375],\n",
|
|
" bgr_to_rgb=True),\n",
|
|
" backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
|
|
" neck=dict(\n",
|
|
" type='MAEPretrainDecoder',\n",
|
|
" patch_size=16,\n",
|
|
" in_chans=3,\n",
|
|
" embed_dim=1024,\n",
|
|
" decoder_embed_dim=512,\n",
|
|
" decoder_depth=8,\n",
|
|
" decoder_num_heads=16,\n",
|
|
" mlp_ratio=4.,\n",
|
|
" ),\n",
|
|
" head=dict(\n",
|
|
" type='MAEPretrainHead',\n",
|
|
" norm_pix=True,\n",
|
|
" patch_size=16,\n",
|
|
" loss=dict(type='MAEReconstructionLoss')),\n",
|
|
" init_cfg=[\n",
|
|
" dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
|
|
" dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
|
|
" ])\n",
|
|
"\n",
|
|
"file_client_args = dict(backend='disk')\n",
|
|
"\n",
|
|
"# test-time pipeline: load the image, resize it to 224x224 and pack the inputs\n",
|
|
"test_dataloader = dict(\n",
|
|
" dataset=dict(pipeline=[\n",
|
|
" dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
|
|
" dict(type='Resize', scale=(224, 224)),\n",
|
|
" dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
|
|
" ]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load a pre-trained MAE model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--2022-09-03 00:34:55-- https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
|
|
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
|
|
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... 已连接。\n",
|
|
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
|
|
"长度: 1318299501 (1.2G) [application/octet-stream]\n",
|
|
"正在保存至: “mae_visualize_vit_large.pth”\n",
|
|
"\n",
|
|
"mae_visualize_vit_l 100%[===================>] 1.23G 3.22MB/s 用时 6m 4s \n",
|
|
"\n",
|
|
"2022-09-03 00:40:59 (3.46 MB/s) - 已保存 “mae_visualize_vit_large.pth” [1318299501/1318299501])\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
|
|
"\n",
|
|
"# download the checkpoint if it does not already exist\n",
|
|
"# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
|
|
"!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"local loads checkpoint from path: mae_visualize_vit_large.pth\n",
|
|
"Model loaded.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from mmselfsup.apis import init_model\n",
|
|
"ckpt_path = \"mae_visualize_vit_large.pth\"\n",
|
|
"model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
|
|
"print('Model loaded.')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Load an image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<torch._C.Generator at 0x7f5029d19950>"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# register the mmselfsup modules, then seed torch so the random mask is reproducible (comment out the seed line to make it change)\n",
|
|
"register_all_modules()\n",
|
|
"torch.manual_seed(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--2022-09-03 00:41:01-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
|
|
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
|
|
"正在连接 download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... 已连接。\n",
|
|
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
|
|
"长度: 60133 (59K) [image/jpeg]\n",
|
|
"正在保存至: “fox.jpg”\n",
|
|
"\n",
|
|
"fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.06s \n",
|
|
"\n",
|
|
"2022-09-03 00:41:01 (962 KB/s) - 已保存 “fox.jpg” [60133/60133])\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"img_path = 'fox.jpg'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cfg = model.cfg\n",
|
|
"test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
|
|
"data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"data = dict(img_path=img_path)\n",
|
|
"data = test_pipeline(data)\n",
|
|
"data = default_collate([data])\n",
|
|
"img, _ = data_preprocessor(data, False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.rcParams['figure.figsize'] = [5, 5]\n",
|
|
"show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Run MAE on the image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"results = inference_model(model, img_path)\n",
|
|
"x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print('MAE with pixel reconstruction:')\n",
|
|
"show_images(x, im_masked, y, im_paste)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.13"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|