mmselfsup/tools/analysis_tools/mae_visualization.py
RenQin 41747f73c7 [Refactor]: refactor MAE visualization (#471)
* [Refactor]: refactor MAE visualization

* [Fix]: fix lint

* [Refactor]: refactor MAE visualization

* [Feature]: add mae_visualization.py

* [UT]: add unit test

* [Refactor]: move mae_visualization.py to tools/analysis_tools

* [Docs]: Add the purpose of the function unpatchify()

* [Fix]: fix lint
2022-11-03 16:09:36 +08:00

109 lines
3.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate
from mmselfsup.apis import inference_model, init_model
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules
# Per-channel mean / std (3 values each) that ``show_image`` applies to undo
# the input normalization before an image is displayed.
imagenet_mean = np.array((0.485, 0.456, 0.406))
imagenet_std = np.array((0.229, 0.224, 0.225))
def show_image(image: torch.Tensor, title: str = '') -> None:
    """Draw a normalized image on the current matplotlib axes.

    The image is de-normalized with ``imagenet_std`` and ``imagenet_mean``,
    scaled by 255, clipped to [0, 255] and cast to integers before it is
    handed to ``plt.imshow``. The axes get ``title`` and their ticks/frame
    are switched off.

    Args:
        image (torch.Tensor): Channels-last image, i.e. [H, W, 3].
        title (str): Title shown above the image. Defaults to ''.
    """
    # the third dimension must hold exactly 3 channels
    assert image.shape[2] == 3

    denormalized = image * imagenet_std + imagenet_mean
    pixels = torch.clip(denormalized * 255, 0, 255).int()

    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')
def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
                im_paste: torch.Tensor, out_file: str) -> None:
    """Plot four images side by side and save the figure to ``out_file``.

    The panels are drawn left to right in a 1 x 4 grid via ``show_image``:
    'original', 'masked', 'reconstruction', 'reconstruction + visible'.

    Args:
        x (torch.Tensor): Image drawn in the 'original' panel.
        im_masked (torch.Tensor): Image drawn in the 'masked' panel.
        y (torch.Tensor): Image drawn in the 'reconstruction' panel.
        im_paste (torch.Tensor): Image drawn in the
            'reconstruction + visible' panel.
        out_file (str): Path passed to ``plt.savefig``.
    """
    # make the plt figure larger (this updates the global rcParams)
    plt.rcParams['figure.figsize'] = [24, 6]

    panels = (
        (x, 'original'),
        (im_masked, 'masked'),
        (y, 'reconstruction'),
        (im_paste, 'reconstruction + visible'),
    )
    # subplot indices are 1-based
    for position, (panel_image, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(panel_image, panel_title)

    plt.savefig(out_file)
def post_process(
    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the images to visualize from the input, prediction and mask.

    ``x`` is moved to CPU and rearranged from (N, C, H, W) to (N, H, W, C).
    ``y`` and ``mask`` are combined with it element-wise, so they must be
    broadcastable against that channels-last layout.

    Args:
        x (torch.Tensor): Input images laid out as (N, C, H, W).
        y (torch.Tensor): Reconstruction, used wherever ``mask`` is 1.
        mask (torch.Tensor): Weights of the reconstruction; for a binary
            mask, 1 marks removed locations and 0 marks visible ones.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The
        first sample of, in order: the channels-last input, the masked
        input, the reconstruction ``y``, and the reconstruction pasted
        together with the visible part of the input.
    """
    channels_last = torch.einsum('nchw->nhwc', x.cpu())

    # weight of the original pixels (1 where visible for a binary mask)
    visible = 1 - mask

    # masked image
    im_masked = channels_last * visible

    # MAE reconstruction pasted with visible patches
    im_paste = im_masked + y * mask

    return channels_last[0], im_masked[0], y[0], im_paste[0]
def main():
    """Run MAE on one image and save a visualization of the result.

    Command-line arguments: ``img_path``, ``config``, ``checkpoint``,
    ``out_file`` and the optional ``--device`` (default 'cuda:0').

    Steps: register all modules, load the model, seed torch, override the
    test pipeline on the model config, run inference, re-run the same
    pipeline and the data preprocessor to obtain the input image tensor,
    then post-process and save the figure to ``out_file``.
    """
    arg_parser = ArgumentParser()
    positional_args = (
        ('img_path', 'Image file path'),
        ('config', 'MAE Config file'),
        ('checkpoint', 'Checkpoint file'),
        ('out_file', 'The output image file path'),
    )
    for arg_name, arg_help in positional_args:
        arg_parser.add_argument(arg_name, help=arg_help)
    arg_parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli = arg_parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(cli.config, cli.checkpoint, device=cli.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)
    print('MAE with pixel reconstruction:')

    # test-time pipeline: load from disk, resize to 224 x 224, pack inputs
    load_step = {
        'type': 'LoadImageFromFile',
        'file_client_args': {'backend': 'disk'},
    }
    resize_step = {'type': 'Resize', 'scale': (224, 224)}
    pack_step = {'type': 'PackSelfSupInputs', 'meta_keys': ['img_path']}
    model.cfg.test_dataloader = {
        'dataset': {'pipeline': [load_step, resize_step, pack_step]},
    }

    results = inference_model(model, cli.img_path)

    # run the same pipeline and preprocessor again to get the input image
    cfg = model.cfg
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
    sample = test_pipeline({'img_path': cli.img_path})
    batch = default_collate([sample])
    img, _ = data_preprocessor(batch, False)

    pred = results.pred.value
    mask = results.mask.value
    x, im_masked, y, im_paste = post_process(img[0], pred, mask)
    save_images(x, im_masked, y, im_paste, cli.out_file)


if __name__ == '__main__':
    main()