[Refactor]: refactor MAE visualization (#471)

* [Refactor]: refactor MAE visualization

* [Fix]: fix lint

* [Refactor]: refactor MAE visualization

* [Feature]: add mae_visualization.py

* [UT]: add unit test

* [Refactor]: move mae_visualization.py to tools/analysis_tools

* [Docs]: Add the purpose of the function unpatchify()

* [Fix]: fix lint
RenQin 2022-10-28 15:37:08 +08:00 committed by Yixiao Fang
parent d142699265
commit 41747f73c7
8 changed files with 719 additions and 1 deletion

View File

@@ -0,0 +1,403 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) OpenMMLab. All rights reserved.\n",
"\n",
"Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"\n",
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
"\n",
"## Masked Autoencoders: Visualization Demo\n",
"\n",
"This is a visualization demo using our pre-trained MAE models. No GPU is needed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare\n",
"Check environment. Install packages if in Colab."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# check whether run in Colab\n",
"if 'google.colab' in sys.modules:\n",
" print('Running in Colab.')\n",
" !pip3 install openmim\n",
" !pip install -U openmim\n",
" !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
"\n",
" !git clone https://github.com/open-mmlab/mmselfsup.git\n",
" %cd mmselfsup/\n",
" !git checkout dev-1.x\n",
" !pip install -e .\n",
"\n",
" sys.path.append('./mmselfsup')\n",
" %cd demo\n",
"else:\n",
" sys.path.append('..')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from mmengine.dataset import Compose, default_collate\n",
"\n",
"from mmselfsup.apis import inference_model\n",
"from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
"from mmselfsup.registry import MODELS\n",
"from mmselfsup.utils import register_all_modules"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define utils"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# define the utils\n",
"\n",
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"def show_image(image, title=''):\n",
" # image is [H, W, 3]\n",
" assert image.shape[2] == 3\n",
" image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
" plt.imshow(image)\n",
" plt.title(title, fontsize=16)\n",
" plt.axis('off')\n",
" return\n",
"\n",
"\n",
"def show_images(x, im_masked, y, im_paste):\n",
" # make the plt figure larger\n",
" plt.rcParams['figure.figsize'] = [24, 6]\n",
"\n",
" plt.subplot(1, 4, 1)\n",
" show_image(x, \"original\")\n",
"\n",
" plt.subplot(1, 4, 2)\n",
" show_image(im_masked, \"masked\")\n",
"\n",
" plt.subplot(1, 4, 3)\n",
" show_image(y, \"reconstruction\")\n",
"\n",
" plt.subplot(1, 4, 4)\n",
" show_image(im_paste, \"reconstruction + visible\")\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"def post_process(x, y, mask):\n",
" x = torch.einsum('nchw->nhwc', x.cpu())\n",
" # masked image\n",
" im_masked = x * (1 - mask)\n",
" # MAE reconstruction pasted with visible patches\n",
" im_paste = x * (1 - mask) + y * mask\n",
" return x[0], im_masked[0], y[0], im_paste[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare config file"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
]
}
],
"source": [
"%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
"model = dict(\n",
" type='MAE',\n",
" data_preprocessor=dict(\n",
" mean=[123.675, 116.28, 103.53],\n",
" std=[58.395, 57.12, 57.375],\n",
" bgr_to_rgb=True),\n",
" backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
" neck=dict(\n",
" type='MAEPretrainDecoder',\n",
" patch_size=16,\n",
" in_chans=3,\n",
" embed_dim=1024,\n",
" decoder_embed_dim=512,\n",
" decoder_depth=8,\n",
" decoder_num_heads=16,\n",
" mlp_ratio=4.,\n",
" ),\n",
" head=dict(\n",
" type='MAEPretrainHead',\n",
" norm_pix=True,\n",
" patch_size=16,\n",
" loss=dict(type='MAEReconstructionLoss')),\n",
" init_cfg=[\n",
" dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
" dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
" ])\n",
"\n",
"file_client_args = dict(backend='disk')\n",
"\n",
"# dataset summary\n",
"test_dataloader = dict(\n",
" dataset=dict(pipeline=[\n",
" dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
" dict(type='Resize', scale=(224, 224)),\n",
" dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
" ]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load a pre-trained MAE model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-09-03 00:34:55-- https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
"正在连接 download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... 已连接。\n",
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
"长度: 1318299501 (1.2G) [application/octet-stream]\n",
"正在保存至: “mae_visualize_vit_large.pth”\n",
"\n",
"mae_visualize_vit_l 100%[===================>] 1.23G 3.22MB/s 用时 6m 4s \n",
"\n",
"2022-09-03 00:40:59 (3.46 MB/s) - 已保存 “mae_visualize_vit_large.pth” [1318299501/1318299501])\n",
"\n"
]
}
],
"source": [
"# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
"\n",
"# download checkpoint if not exist\n",
"# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
"!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"local loads checkpoint from path: mae_visualize_vit_large.pth\n",
"Model loaded.\n"
]
}
],
"source": [
"from mmselfsup.apis import init_model\n",
"ckpt_path = \"mae_visualize_vit_large.pth\"\n",
"model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
"print('Model loaded.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load an image"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f5029d19950>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# make random mask reproducible (comment out to make it change)\n",
"register_all_modules()\n",
"torch.manual_seed(2)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-09-03 00:41:01-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
"正在解析主机 download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
"正在连接 download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... 已连接。\n",
"已发出 HTTP 请求,正在等待回应... 200 OK\n",
"长度: 60133 (59K) [image/jpeg]\n",
"正在保存至: “fox.jpg”\n",
"\n",
"fox.jpg 100%[===================>] 58.72K --.-KB/s 用时 0.06s \n",
"\n",
"2022-09-03 00:41:01 (962 KB/s) - 已保存 “fox.jpg” [60133/60133])\n",
"\n"
]
}
],
"source": [
"!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"img_path = 'fox.jpg'"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"cfg = model.cfg\n",
"test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
"data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"data = dict(img_path=img_path)\n",
"data = test_pipeline(data)\n",
"data = default_collate([data])\n",
"img, _ = data_preprocessor(data, False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.rcParams['figure.figsize'] = [5, 5]\n",
"show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run MAE on the image"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"results = inference_model(model, img_path)\n",
"x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('MAE with pixel reconstruction:')\n",
"show_images(x, im_masked, y, im_paste)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
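
The compositing in `post_process` above relies on the mask convention returned by `MAE.predict` (after `unpatchify`, 1 marks a removed patch and 0 a kept one): `x * (1 - mask)` blanks the masked patches, and `x * (1 - mask) + y * mask` pastes the reconstruction into exactly those patches. A minimal self-contained sketch of that arithmetic on dummy tensors (shapes and values are illustrative only):

import torch

# dummy original (x), reconstruction (y) and binary mask, all (N, H, W, C)
x = torch.ones(1, 224, 224, 3)        # pretend original image, already nhwc
y = torch.zeros(1, 224, 224, 3)       # pretend decoder reconstruction
mask = torch.zeros(1, 224, 224, 3)
mask[:, :112] = 1                     # pretend the top half was masked out

im_masked = x * (1 - mask)            # visible pixels only; masked area zeroed
im_paste = x * (1 - mask) + y * mask  # visible pixels + reconstructed pixels

assert torch.all(im_masked[:, :112] == 0)  # masked region blanked
assert torch.all(im_paste[:, 112:] == 1)   # visible region kept from x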

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model

__all__ = ['init_model', 'inference_model']
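
For reference, a minimal end-to-end sketch of the exported API, mirroring the notebook above (the config and checkpoint paths are the ones the demo uses, and this assumes the visualization config has already been written):

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

register_all_modules()
model = init_model('configs/selfsup/mae/mae_visualization.py',
                   'mae_visualize_vit_large.pth', device='cpu')
results = inference_model(model, 'fox.jpg')  # returns a SelfSupDataSample
print(results.pred.value.shape)              # expected: (1, 224, 224, 3)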

View File

@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmengine.runner import load_checkpoint
from torch import nn

from mmselfsup.models import build_algorithm
from mmselfsup.structures import SelfSupDataSample


def init_model(config: Union[str, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               options: Optional[dict] = None) -> nn.Module:
    """Initialize a model from config file.

    Args:
        config (str or :obj:`mmengine.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str): The device where the model will be put on.
            Defaults to 'cuda:0'.
        options (dict, optional): Options to override some settings in the
            used config.

    Returns:
        nn.Module: The initialized model.
    """
    if isinstance(config, str):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if options is not None:
        config.merge_from_dict(options)
    config.model.pretrained = None
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))
    model = build_algorithm(config.model)
    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak,
        # see https://github.com/open-mmlab/mmdetection/pull/6405 for details
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(model: nn.Module,
                    img: Union[str, np.ndarray]) -> SelfSupDataSample:
    """Inference an image with the mmselfsup model.

    Args:
        model (nn.Module): The loaded model.
        img (Union[str, ndarray]): Either an image path or a loaded image.

    Returns:
        SelfSupDataSample: Output of the model inference.
    """
    cfg = model.cfg
    # build the data pipeline
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    if isinstance(img, str):
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # the image is already decoded, so the loading step is skipped
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)
    return results
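
Since `inference_model` adapts the pipeline to the input type, an already-decoded image can be passed directly; a short sketch (assuming `mmcv` is available to read the image):

import mmcv

img = mmcv.imread('fox.jpg')  # ndarray of shape (H, W, 3)
# with an ndarray input, LoadImageFromFile is popped from the pipeline
results = inference_model(model, img)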

View File

@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from mmengine.structures import BaseDataElement
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
@@ -30,6 +31,40 @@ class MAE(BaseModel):
        pred = self.neck(latent, ids_restore)
        return pred

    def predict(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[SelfSupDataSample]] = None,
                **kwargs) -> SelfSupDataSample:
        """The forward function in testing. It is mainly for image
        reconstruction.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            SelfSupDataSample: The prediction from the model.
        """
        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)
        pred = self.head.unpatchify(pred)
        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()

        mask = mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
                                         3)  # (N, H*W, p*p*3)
        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(**dict(value=mask))
        results.pred = BaseDataElement(**dict(value=pred))
        return results

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
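
With the ViT-large settings from the config above (224x224 input, patch_size=16, mask_ratio=0.75), the tensors inside `predict` take the shapes sketched below; the numbers are derived from those settings rather than quoted from the source, so treat this as a plausibility check:

# inputs[0]:             (N, 3, 224, 224)
# latent:                (N, 50, 1024)    # 49 visible patches (25% of 196) + cls token
# mask:                  (N, 196)         # one entry per patch, 1 = removed
# pred from the neck:    (N, 196, 768)    # 16*16*3 = 768 values per patch
# pred after unpatchify: (N, 3, 224, 224) -> einsum -> (N, 224, 224, 3)
# mask after repeat:     (N, 196, 768)    -> unpatchify -> (N, 224, 224, 3)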

View File

@@ -43,6 +43,23 @@ class MAEPretrainHead(BaseModule):
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Combine non-overlapped patches into images.

        Args:
            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3).

        Returns:
            imgs (torch.Tensor): The shape is (N, 3, H, W).
        """
        p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

View File

@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
from typing import List, Optional

import pytest
import torch
import torch.nn as nn
from mmengine.config import Config

from mmselfsup.apis import inference_model
from mmselfsup.models import BaseModel
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules

backbone = dict(
    type='ResNet',
    depth=18,
    in_channels=2,
    out_indices=[4],  # 0: conv-1, x: stage-x
    norm_cfg=dict(type='BN'))


class ExampleModel(BaseModel):

    def __init__(self, backbone=backbone):
        super(ExampleModel, self).__init__(backbone=backbone)
        self.layer = nn.Linear(1, 1)

    def predict(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[SelfSupDataSample]] = None,
                **kwargs) -> SelfSupDataSample:
        out = self.layer(inputs[0])
        return out


@pytest.mark.skipif(platform.system() == 'Windows', reason='')
def test_inference_model():
    register_all_modules()

    # Specify the data settings
    cfg = Config.fromfile(
        'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py'  # noqa: E501
    )

    # Build the algorithm
    model = ExampleModel()
    model.cfg = cfg
    model.cfg.test_dataloader = dict(
        dataset=dict(pipeline=[
            dict(
                type='LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='Resize', scale=(1, 1)),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))

    img_path = osp.join(osp.dirname(__file__), '..', 'data', 'color.jpg')

    # inference model
    out = inference_model(model, img_path)
    assert out.size() == torch.Size([1, 3, 1, 1])

View File

@@ -50,3 +50,7 @@ def test_mae():
    fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
    assert list(fake_feats.shape) == [2, 196, 768]

    results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
    assert list(results.mask.value.shape) == [2, 224, 224, 3]
    assert list(results.pred.value.shape) == [2, 224, 224, 3]

View File

@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image: torch.Tensor, title: str = '') -> None:
    # image is [H, W, 3]
    assert image.shape[2] == 3
    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
                       255).int()
    plt.imshow(image)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
                im_paste: torch.Tensor, out_file: str) -> None:
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(x, 'original')

    plt.subplot(1, 4, 2)
    show_image(im_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(y, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(im_paste, 'reconstruction + visible')

    plt.savefig(out_file)


def post_process(
    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x = torch.einsum('nchw->nhwc', x.cpu())
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask
    return x[0], im_masked[0], y[0], im_paste[0]


def main():
    parser = ArgumentParser()
    parser.add_argument('img_path', help='Image file path')
    parser.add_argument('config', help='MAE Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('out_file', help='The output image file path')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    register_all_modules()
    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)

    print('MAE with pixel reconstruction:')
    model.cfg.test_dataloader = dict(
        dataset=dict(pipeline=[
            dict(
                type='LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='Resize', scale=(224, 224)),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))
    results = inference_model(model, args.img_path)

    cfg = model.cfg
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
    data = dict(img_path=args.img_path)
    data = test_pipeline(data)
    data = default_collate([data])
    img, _ = data_preprocessor(data, False)

    x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
                                             results.mask.value)
    save_images(x, im_masked, y, im_paste, args.out_file)


if __name__ == '__main__':
    main()
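
A typical invocation, matching the positional arguments parsed above (the script path follows this commit's move to tools/analysis_tools; the input and checkpoint names follow the demo, and the output file name is illustrative):

python tools/analysis_tools/mae_visualization.py fox.jpg \
    configs/selfsup/mae/mae_visualization.py \
    mae_visualize_vit_large.pth fox_reconstruction.jpg --device cpu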