Mirror of https://github.com/open-mmlab/mmselfsup.git, synced 2025-06-03 14:59:38 +08:00
* [Refactor]: refactor MAE visualization
* [Fix]: fix lint
* [Refactor]: refactor MAE visualization
* [Feature]: add mae_visualization.py
* [UT]: add unit test
* [Refactor]: move mae_visualization.py to tools/analysis_tools
* [Docs]: Add the purpose of the function unpatchify()
* [Fix]: fix lint
109 lines
3.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
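#
# Example invocation (a sketch only; the ${...} placeholders stand for the
# positional arguments parsed in main() below and are not real paths):
#
#   python tools/analysis_tools/mae_visualization.py ${IMG_PATH} \
#       ${CONFIG_FILE} ${CHECKPOINT} ${OUT_FILE} --device cuda:0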
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image: torch.Tensor, title: str = '') -> None:
    # image is [H, W, 3]
    assert image.shape[2] == 3
    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
                       255).int()
    plt.imshow(image)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
                im_paste: torch.Tensor, out_file: str) -> None:
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(x, 'original')

    plt.subplot(1, 4, 2)
    show_image(im_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(y, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(im_paste, 'reconstruction + visible')

    plt.savefig(out_file)


def post_process(
    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x = torch.einsum('nchw->nhwc', x.cpu())
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask
    return x[0], im_masked[0], y[0], im_paste[0]
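# Note on post_process() above: only `x` is converted from NCHW to NHWC, so
# `y` (the reconstruction) and `mask` must already be NHWC tensors that
# broadcast against `x` for the element-wise arithmetic to work.
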
def main():
    parser = ArgumentParser()
    parser.add_argument('img_path', help='Image file path')
    parser.add_argument('config', help='MAE Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('out_file', help='The output image file path')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)
    print('MAE with pixel reconstruction:')

    model.cfg.test_dataloader = dict(
        dataset=dict(pipeline=[
            dict(
                type='LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='Resize', scale=(224, 224)),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))
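
    # The override above swaps in a minimal test pipeline (load + 224x224
    # resize + pack), presumably so the image fed to the model matches the
    # fixed input size that the MAE backbone and the visualization expect.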
    results = inference_model(model, args.img_path)

    cfg = model.cfg
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
    data = dict(img_path=args.img_path)
    data = test_pipeline(data)
    data = default_collate([data])
    img, _ = data_preprocessor(data, False)
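    # `img` above is the batched, normalized model input, rebuilt through the
    # same test pipeline and data_preprocessor the model uses; post_process()
    # converts it to NHWC and show_image() denormalizes it for the 'original'
    # panel.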
    x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
                                             results.mask.value)
    save_images(x, im_masked, y, im_paste, args.out_file)


if __name__ == '__main__':
    main()