# mmselfsup/tools/analysis_tools/visualize_reconstruction.py
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
import random
from argparse import ArgumentParser
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate
from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules
# Per-channel ImageNet RGB statistics used to (de-)normalize images.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def show_image(img: torch.Tensor, title: str = '') -> None:
    """Draw one ``[H, W, 3]`` image on the current matplotlib axes.

    Args:
        img (torch.Tensor): Image tensor of shape (H, W, 3), channels last.
        title (str): Title rendered above the image. Defaults to ''.
    """
    # Only 3-channel (channels-last) images are supported.
    assert img.shape[2] == 3
    plt.imshow(img)
    plt.title(title, fontsize=16)
    plt.axis('off')
def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,
                pred_img: torch.Tensor, img_paste: torch.Tensor,
                out_file: str) -> None:
    """Render the four reconstruction panels side by side and save to disk.

    Args:
        original_img (torch.Tensor): The unmodified input image.
        img_masked (torch.Tensor): Input with masked patches zeroed out.
        pred_img (torch.Tensor): The model's reconstruction.
        img_paste (torch.Tensor): Reconstruction with visible patches pasted
            back from the original.
        out_file (str): Path of the figure written via ``plt.savefig``.
    """
    # Widen the default figure so four panels fit comfortably.
    plt.rcParams['figure.figsize'] = [24, 6]
    panels = (
        (original_img, 'original'),
        (img_masked, 'masked'),
        (pred_img, 'reconstruction'),
        (img_paste, 'reconstruction + visible'),
    )
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(image, caption)
    plt.savefig(out_file)
    print(f'Images are saved to {out_file}')
def recover_norm(img: torch.Tensor,
                 mean: np.ndarray = imagenet_mean,
                 std: np.ndarray = imagenet_std):
    """Undo per-channel normalization and map values to integer pixels.

    Args:
        img (torch.Tensor): Normalized channels-last image.
        mean (np.ndarray): Per-channel mean used during normalization.
        std (np.ndarray): Per-channel std used during normalization.

    Returns:
        torch.Tensor: Integer image clamped to [0, 255], or ``img``
        unchanged when either statistic is ``None``.
    """
    # Without both statistics there is nothing to recover.
    if mean is None or std is None:
        return img
    # De-normalize, rescale to [0, 255], clamp, and cast to int.
    return torch.clip((img * std + mean) * 255, 0, 255).int()
def post_process(
    original_img: torch.Tensor,
    pred_img: torch.Tensor,
    mask: torch.Tensor,
    mean: np.ndarray = imagenet_mean,
    std: np.ndarray = imagenet_std
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the four visualization panels from model outputs.

    Args:
        original_img (torch.Tensor): Input batch in NCHW layout.
        pred_img (torch.Tensor): Reconstructed batch, channels last.
        mask (torch.Tensor): Mask with 1 at masked (reconstructed) positions.
        mean (np.ndarray): Normalization mean. Defaults to ImageNet stats.
        std (np.ndarray): Normalization std. Defaults to ImageNet stats.

    Returns:
        Tuple of (original, masked, prediction, paste) single images.
    """
    # NCHW -> NHWC so channels come last for plotting.
    original_img = original_img.cpu().permute(0, 2, 3, 1)
    # Keep only the visible patches (mask == 1 marks masked positions).
    img_masked = original_img * (1 - mask)
    # Prediction in the masked regions, original content everywhere else.
    img_paste = original_img * (1 - mask) + pred_img * mask
    # Drop the batch dimension and undo normalization for every panel.
    # NOTE(review): `mean`/`std` are not forwarded to recover_norm here, so
    # the ImageNet defaults are always used — confirm this is intended.
    original_img, img_masked, pred_img, img_paste = (
        recover_norm(panel[0])
        for panel in (original_img, img_masked, pred_img, img_paste))
    return original_img, img_masked, pred_img, img_paste
def main():
    """Visualize masked-image reconstruction for a single input image.

    Loads a pre-trained self-supervised model, runs inference on one image,
    and saves a 4-panel figure (original / masked / reconstruction /
    reconstruction + visible) to the requested output file.
    """
    parser = ArgumentParser()
    parser.add_argument('config', help='Model config file')
    parser.add_argument('--checkpoint', help='Checkpoint file')
    parser.add_argument('--img-path', help='Image file path')
    parser.add_argument('--out-file', help='The output image file path')
    parser.add_argument(
        '--use-vis-pipeline',
        action='store_true',
        help='Use vis_pipeline defined in config. For some algorithms, such '
        'as SimMIM and MaskFeat, they generate mask in data pipeline, thus '
        'the visualization process applies vis_pipeline in config to obtain '
        'the mask.')
    parser.add_argument(
        '--norm-pix',
        action='store_true',
        help='MAE uses `norm_pix_loss` for optimization in pre-training, thus '
        'the visualization process also need to compute mean and std of each '
        'patch embedding while reconstructing the original images.')
    parser.add_argument(
        '--target-generator',
        action='store_true',
        help='Some algorithms use target_generator for optimization in '
        'pre-training, such as MaskFeat, thus the visualization process could '
        'turn this on to visualize the target instead of RGB image.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='The random seed for visualization')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print('Reconstruction visualization.')

    if args.use_vis_pipeline:
        # Use the mask-generating pipeline declared in the config
        # (needed by SimMIM / MaskFeat, per the CLI help above).
        model.cfg.test_dataloader = dict(
            dataset=dict(pipeline=model.cfg.vis_pipeline))
    else:
        # Default pipeline: load from disk, resize to 224x224, pack inputs.
        model.cfg.test_dataloader = dict(
            dataset=dict(pipeline=[
                dict(
                    type='LoadImageFromFile',
                    file_client_args=dict(backend='disk')),
                dict(type='Resize', scale=(224, 224), backend='pillow'),
                dict(type='PackSelfSupInputs', meta_keys=['img_path'])
            ]))

    # get original image
    vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
    data = dict(img_path=args.img_path)
    data = vis_pipeline(data)
    data = default_collate([data])
    img, _ = model.data_preprocessor(data, False)

    if args.norm_pix:
        # for MAE reconstruction
        img_embedding = model.head.patchify(img[0])
        # normalize the target image: per-patch mean/std over the last dim
        mean = img_embedding.mean(dim=-1, keepdim=True)
        std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5
    else:
        mean = imagenet_mean
        std = imagenet_std

    # get reconstruction image
    features = inference_model(model, args.img_path)
    results = model.reconstruct(features, mean=mean, std=std)

    # NOTE(review): assumes `model.target` is populated whenever a
    # target_generator is configured (e.g. MaskFeat) — confirm per algorithm.
    original_target = model.target if args.target_generator else img[0]
    original_img, img_masked, pred_img, img_paste = post_process(
        original_target,
        results.pred.value,
        results.mask.value,
        mean=mean,
        std=std)

    save_images(original_img, img_masked, pred_img, img_paste, args.out_file)


if __name__ == '__main__':
    main()