179 lines
5.8 KiB
Python
179 lines
5.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Modified from https://colab.research.google.com/github/facebookresearch/mae
|
|
# /blob/main/demo/mae_visualize.ipynb
|
|
import random
|
|
from argparse import ArgumentParser
|
|
from typing import Tuple
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from mmengine.dataset import Compose, default_collate
|
|
|
|
from mmselfsup.apis import inference_model, init_model
|
|
from mmselfsup.utils import register_all_modules
|
|
|
|
# Per-channel mean and standard deviation; they are the default statistics
# that `recover_norm` uses to map a normalized image back to pixel values.
imagenet_mean = np.array((0.485, 0.456, 0.406))
imagenet_std = np.array((0.229, 0.224, 0.225))
|
|
|
|
|
|
def show_image(img: torch.Tensor, title: str = '') -> None:
    """Draw an image with a title and hidden axes via ``matplotlib.pyplot``.

    Args:
        img (torch.Tensor): Image to draw. Its third dimension must have
            size 3, i.e. the layout is [H, W, 3].
        title (str): Title displayed above the image. Defaults to ''.
    """
    # image is [H, W, 3]
    num_channels = img.shape[2]
    assert num_channels == 3

    plt.imshow(img)
    plt.title(title, fontsize=16)
    plt.axis('off')
|
|
|
|
|
|
def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,
                pred_img: torch.Tensor, img_paste: torch.Tensor,
                out_file: str) -> None:
    """Plot four images in a single row and save the figure to a file.

    Args:
        original_img (torch.Tensor): Image shown in the 'original' panel.
        img_masked (torch.Tensor): Image shown in the 'masked' panel.
        pred_img (torch.Tensor): Image shown in the 'reconstruction' panel.
        img_paste (torch.Tensor): Image shown in the
            'reconstruction + visible' panel.
        out_file (str): Path passed to ``plt.savefig``.
    """
    # make the plt figure larger (note: this changes the global rcParams)
    plt.rcParams['figure.figsize'] = [24, 6]

    # panels are drawn left to right in this order
    panels = (
        (original_img, 'original'),
        (img_masked, 'masked'),
        (pred_img, 'reconstruction'),
        (img_paste, 'reconstruction + visible'),
    )
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(image, caption)

    plt.savefig(out_file)
    print(f'Images are saved to {out_file}')
|
|
|
|
|
|
def recover_norm(img: torch.Tensor,
                 mean: np.ndarray = imagenet_mean,
                 std: np.ndarray = imagenet_std):
    """Undo the normalization of an image and convert it to integer pixels.

    Args:
        img (torch.Tensor): Normalized image.
        mean (np.ndarray): Mean added back after scaling by ``std``.
            Defaults to ``imagenet_mean``.
        std (np.ndarray): Standard deviation the image is multiplied by.
            Defaults to ``imagenet_std``.

    Returns:
        torch.Tensor: ``(img * std + mean) * 255`` clipped to [0, 255] and
        cast with ``.int()``. If ``mean`` or ``std`` is None, ``img`` is
        returned unchanged.
    """
    # without both statistics there is nothing to undo
    if mean is None or std is None:
        return img

    denormalized = img * std + mean
    return torch.clip(denormalized * 255, 0, 255).int()
|
|
|
|
|
|
def post_process(
    original_img: torch.Tensor,
    pred_img: torch.Tensor,
    mask: torch.Tensor,
    mean: np.ndarray = imagenet_mean,
    std: np.ndarray = imagenet_std
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the four images to visualize from the input and the prediction.

    Args:
        original_img (torch.Tensor): Batch in NCHW layout; it is moved to
            CPU and permuted to NHWC.
        pred_img (torch.Tensor): Predicted batch, combined element-wise with
            the permuted ``original_img``.
        mask (torch.Tensor): Mask where 1 selects the prediction and 0 keeps
            the original content.
        mean (np.ndarray): Accepted for interface compatibility.
            NOTE(review): not forwarded to ``recover_norm``, which is called
            with its own defaults -- confirm this is intended.
        std (np.ndarray): Same as ``mean``: accepted but currently unused.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The
        original image, the masked image, the predicted image and the
        prediction pasted with the visible content, each taken from the
        first sample of the batch and passed through ``recover_norm``.
    """
    # channel conversion
    original_img = torch.einsum('nchw->nhwc', original_img.cpu())

    keep = 1 - mask
    # masked image
    img_masked = original_img * keep
    # reconstructed image pasted with visible patches
    img_paste = original_img * keep + pred_img * mask

    # multiply std and add mean to each image (first sample only)
    original_img, img_masked, pred_img, img_paste = (
        recover_norm(batch[0])
        for batch in (original_img, img_masked, pred_img, img_paste))

    return original_img, img_masked, pred_img, img_paste
|
|
|
|
|
|
def main():
    """Run the reconstruction visualization from the command line.

    Parses the arguments, builds the model, seeds the random number
    generators, loads and preprocesses the image, reconstructs it with the
    model and saves the resulting figure to ``--out-file``.
    """
    parser = ArgumentParser()
    parser.add_argument('config', help='Model config file')
    parser.add_argument('--checkpoint', help='Checkpoint file')
    parser.add_argument('--img-path', help='Image file path')
    parser.add_argument('--out-file', help='The output image file path')
    parser.add_argument(
        '--use-vis-pipeline',
        action='store_true',
        help='Use vis_pipeline defined in config. For some algorithms, such '
        'as SimMIM and MaskFeat, they generate mask in data pipeline, thus '
        'the visualization process applies vis_pipeline in config to obtain '
        'the mask.')
    parser.add_argument(
        '--norm-pix',
        action='store_true',
        help='MAE uses `norm_pix_loss` for optimization in pre-training, thus '
        'the visualization process also need to compute mean and std of each '
        'patch embedding while reconstructing the original images.')
    parser.add_argument(
        '--target-generator',
        action='store_true',
        help='Some algorithms use target_generator for optimization in '
        'pre-training, such as MaskFeat, thus the visualization process could '
        'turn this on to visualize the target instead of RGB image.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='The random seed for visualization')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    print('Reconstruction visualization.')

    # choose the data pipeline: the one from the config, or a minimal
    # load / resize / pack pipeline
    if args.use_vis_pipeline:
        pipeline_cfg = model.cfg.vis_pipeline
    else:
        pipeline_cfg = [
            {
                'type': 'LoadImageFromFile',
                'file_client_args': {'backend': 'disk'},
            },
            {'type': 'Resize', 'scale': (224, 224), 'backend': 'pillow'},
            {'type': 'PackSelfSupInputs', 'meta_keys': ['img_path']},
        ]
    model.cfg.test_dataloader = {'dataset': {'pipeline': pipeline_cfg}}

    # get original image
    vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
    sample = vis_pipeline({'img_path': args.img_path})
    batch = default_collate([sample])
    img, _ = model.data_preprocessor(batch, False)

    if args.norm_pix:
        # for MAE reconstruction
        patches = model.head.patchify(img[0])
        # normalize the target image
        mean = patches.mean(dim=-1, keepdim=True)
        std = (patches.var(dim=-1, keepdim=True) + 1.e-6)**.5
    else:
        mean, std = imagenet_mean, imagenet_std

    # get reconstruction image
    features = inference_model(model, args.img_path)
    results = model.reconstruct(features, mean=mean, std=std)

    if args.target_generator:
        original_target = model.target
    else:
        original_target = img[0]

    visualized = post_process(
        original_target,
        results.pred.value,
        results.mask.value,
        mean=mean,
        std=std)
    original_img, img_masked, pred_img, img_paste = visualized

    save_images(original_img, img_masked, pred_img, img_paste, args.out_file)
|
|
|
|
|
|
# Run the visualization only when this file is executed as a script.
if __name__ == '__main__':
    main()
|