Copyright (c) OpenMMLab. All rights reserved.

Copyright (c) Meta Platforms, Inc. and affiliates.

Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb

## Masked Autoencoders: Visualization Demo

This is a visualization demo using our pre-trained MAE models. No GPU is needed.

### Prepare
Check the environment and install the required packages when running in Colab.

In [1]:
import sys

# check whether run in Colab
if 'google.colab' in sys.modules:
 print('Running in Colab.')
 !pip3 install openmim
 !pip install -U openmim
 !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'

 !git clone https://github.com/open-mmlab/mmselfsup.git
 %cd mmselfsup/
 !git checkout dev-1.x
 !pip install -e .

 sys.path.append('./mmselfsup')
 %cd demo
else:
 sys.path.append('..')

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.models.utils import SelfSupDataPreprocessor
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

### Define utils

In [3]:
# Shared constants for the plotting utils.
# Per-channel mean / std on the 0..1 pixel scale; multiplied by 255 they equal
# the mean / std given to the data_preprocessor in the config cell, so
# show_image can use them to undo that normalization.
imagenet_mean = np.asarray((0.485, 0.456, 0.406), dtype=np.float64)
imagenet_std = np.asarray((0.229, 0.224, 0.225), dtype=np.float64)

def show_image(image, title=''):
    """Draw one normalized image on the current matplotlib axes.

    Args:
        image: Tensor laid out as [H, W, 3] that is still normalized with
            ``imagenet_mean`` / ``imagenet_std``.
        title (str): Caption placed above the image.
    """
    assert image.shape[2] == 3
    # Undo the normalization, rescale to 0..255, clamp, then cast to int so
    # that imshow receives integer pixel values.
    pixels = (image * imagenet_std + imagenet_mean) * 255
    pixels = torch.clip(pixels, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')


def show_images(x, im_masked, y, im_paste):
    """Show the four demo images side by side in a single row.

    Args:
        x: Original image.
        im_masked: Image with the masked regions blanked out.
        y: Reconstruction.
        im_paste: Reconstruction combined with the visible regions.

    Every argument is forwarded unchanged to ``show_image``.
    """
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    panels = (
        (x, "original"),
        (im_masked, "masked"),
        (y, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    )
    # One subplot per panel, left to right (subplot indices start at 1).
    for column, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 4, column)
        show_image(picture, caption)

    plt.show()


def post_process(x, y, mask):
    """Assemble the images shown by the demo for the first item of the batch.

    Args:
        x: Input batch in NCHW layout; it is moved to the CPU and permuted
            to NHWC.
        y: Reconstruction, combined element-wise with the NHWC input.
        mask: Weights of the reconstruction: where it is 1 the pasted image
            takes ``y``, where it is 0 it takes the input.

    Returns:
        tuple: ``(input, masked input, reconstruction, pasted image)``, each
        taken at batch index 0.
    """
    x_nhwc = torch.einsum('nchw->nhwc', x.cpu())
    visible = 1 - mask
    # masked image
    im_masked = x_nhwc * visible
    # MAE reconstruction pasted with visible patches
    im_paste = im_masked + y * mask
    return x_nhwc[0], im_masked[0], y[0], im_paste[0]

### Prepare config file

In [4]:
%%writefile ../configs/selfsup/mae/mae_visualization.py
# Config written to disk by this cell; the notebook later passes this path to
# init_model and reads test_dataloader / model.data_preprocessor back from it.
model = dict(
 type='MAE',
 # Normalization constants on the 0-255 pixel scale: they equal 255 times the
 # imagenet_mean / imagenet_std arrays defined in the utils cell.
 data_preprocessor=dict(
 mean=[123.675, 116.28, 103.53],
 std=[58.395, 57.12, 57.375],
 bgr_to_rgb=True),
 # arch='l' is presumably the ViT-large encoder named in the checkpoint cell;
 # mask_ratio matches the training mask ratio (0.75) stated there.
 backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),
 # Decoder settings; embed_dim presumably has to match the encoder width
 # -- confirm against the MAEViT definition.
 neck=dict(
 type='MAEPretrainDecoder',
 patch_size=16,
 in_chans=3,
 embed_dim=1024,
 decoder_embed_dim=512,
 decoder_depth=8,
 decoder_num_heads=16,
 mlp_ratio=4.,
 ),
 head=dict(
 type='MAEPretrainHead',
 norm_pix=True,
 patch_size=16,
 loss=dict(type='MAEReconstructionLoss')),
 # Initializers for Linear and LayerNorm layers.
 init_cfg=[
 dict(type='Xavier', distribution='uniform', layer='Linear'),
 dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
 ])

# Passed to LoadImageFromFile below; 'disk' presumably selects plain local
# file access -- confirm.
file_client_args = dict(backend='disk')

# Test-time pipeline applied to a single image (going by the transform names):
# load it from its path, resize it to 224x224, then pack it while keeping
# 'img_path' as meta information.
test_dataloader = dict(
 dataset=dict(pipeline=[
 dict(type='LoadImageFromFile', file_client_args=file_client_args),
 dict(type='Resize', scale=(224, 224)),
 dict(type='PackSelfSupInputs', meta_keys=['img_path'])
 ]))

Overwriting ../configs/selfsup/mae/mae_visualization.py


### Load a pre-trained MAE model

In [5]:
# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75).

# Download the checkpoint if it does not exist yet: `-nc` makes wget skip the
# download when the file is already present (about 1.2 GB per the recorded output).
# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth

--2022-09-03 00:34:55-- https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth
正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.107.10.247
正在连接 download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... 已连接。
已发出 HTTP 请求,正在等待回应... 200 OK
长度: 1318299501 (1.2G) [application/octet-stream]
正在保存至: “mae_visualize_vit_large.pth”


2022-09-03 00:40:59 (3.46 MB/s) - 已保存 “mae_visualize_vit_large.pth” [1318299501/1318299501])



In [6]:
# Build the MAE model from the config file written earlier and load the
# downloaded checkpoint into it. The model is placed on the CPU, so no GPU is
# required. `init_model` is imported with the other mmselfsup APIs in the
# import cell at the top of the notebook.
ckpt_path = "mae_visualize_vit_large.pth"
model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')
print('Model loaded.')

local loads checkpoint from path: mae_visualize_vit_large.pth
Model loaded.


### Load an image

In [7]:
# Presumably registers the mmselfsup components with the registries used by
# the following cells -- confirm against mmselfsup.utils.
register_all_modules()
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)



In [8]:
# Fetch the sample image; `-nc` skips the download when fox.jpg already exists.
!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'

--2022-09-03 00:41:01-- https://download.openmmlab.com/mmselfsup/mae/fox.jpg
正在解析主机 download.openmmlab.com (download.openmmlab.com)... 101.133.111.186
正在连接 download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... 已连接。
已发出 HTTP 请求,正在等待回应... 200 OK
长度: 60133 (59K) [image/jpeg]
正在保存至: “fox.jpg”


2022-09-03 00:41:01 (962 KB/s) - 已保存 “fox.jpg” [60133/60133])



In [9]:
# Local file name of the sample image fetched by the wget cell above.
img_path = 'fox.jpg'

In [10]:
# Reuse the config object attached to the loaded model.
cfg = model.cfg
# Compose the transforms listed under test_dataloader.dataset.pipeline.
test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
# Build a standalone data preprocessor from the model's data_preprocessor
# settings, so the input can be normalized outside the model for plotting.
data_preprocessor = MODELS.build(cfg.model.data_preprocessor)

In [11]:
# Run the test pipeline on the single image, then collate the result into a
# batch of one.
sample = test_pipeline(dict(img_path=img_path))
data = default_collate([sample])
# The second positional argument is presumably the `training` flag -- confirm
# against the data preprocessor's signature.
img, _ = data_preprocessor(data, False)

In [None]:
# Small square figure for the single preview image.
plt.rcParams['figure.figsize'] = [5, 5]
# img[0] is the NCHW batch; permute it to NHWC and preview its first image.
preview = torch.einsum('nchw->nhwc', img[0].cpu())[0]
show_image(preview)

### Run MAE on the image

In [13]:
# Run the model on the image file, then combine its prediction and mask with
# the preprocessed input batch (img[0], NCHW) to get the four images to plot.
results = inference_model(model, img_path)
x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)

In [None]:
# Plot original / masked / reconstruction / reconstruction + visible in one row.
print('MAE with pixel reconstruction:')
show_images(x, im_masked, y, im_paste)