# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
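# Example command (config/checkpoint/image paths are illustrative):
#   python mae_visualization.py ${CONFIG} --checkpoint ${CHECKPOINT} \
#       --img-path ${IMAGE} --out-file ${OUTPUT} --norm-pix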
import random
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(img: torch.Tensor, title: str = '') -> None:
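    """Plot a single [H, W, 3] image on the current matplotlib axes."""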
    # image is [H, W, 3]
    assert img.shape[2] == 3

    plt.imshow(img)
    plt.title(title, fontsize=16)
    plt.axis('off')


def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,
                pred_img: torch.Tensor, img_paste: torch.Tensor,
                out_file: str) -> None:
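    """Draw the four panels side by side and save the figure to out_file."""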
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(original_img, 'original')

    plt.subplot(1, 4, 2)
    show_image(img_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(pred_img, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(img_paste, 'reconstruction + visible')

    plt.savefig(out_file)
    print(f'Images are saved to {out_file}')


def recover_norm(img: torch.Tensor,
                 mean: np.ndarray = imagenet_mean,
                 std: np.ndarray = imagenet_std) -> torch.Tensor:
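    """Undo ImageNet normalization: map `img * std + mean` back to [0, 255]."""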
    if mean is not None and std is not None:
        img = torch.clip((img * std + mean) * 255, 0, 255).int()
    return img


def post_process(
    original_img: torch.Tensor,
    pred_img: torch.Tensor,
    mask: torch.Tensor,
    mean: np.ndarray = imagenet_mean,
    std: np.ndarray = imagenet_std
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
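    """Convert model outputs into four displayable images.

    Note: `mean` and `std` are currently unused here; de-normalization
    uses the module-level ImageNet statistics, while the per-patch
    statistics from `--norm-pix` are consumed by `model.reconstruct`.
    """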
    # channel conversion
    original_img = torch.einsum('nchw->nhwc', original_img.cpu())
    # masked image
    img_masked = original_img * (1 - mask)
    # reconstructed image pasted with visible patches
    img_paste = original_img * (1 - mask) + pred_img * mask

    # multiply by std and add mean to de-normalize each image
    original_img = recover_norm(original_img[0])
    img_masked = recover_norm(img_masked[0])

    pred_img = recover_norm(pred_img[0])
    img_paste = recover_norm(img_paste[0])

    return original_img, img_masked, pred_img, img_paste


def main():
    parser = ArgumentParser()
    parser.add_argument('config', help='Model config file')
    parser.add_argument('--checkpoint', help='Checkpoint file')
    parser.add_argument('--img-path', help='Image file path')
    parser.add_argument('--out-file', help='The output image file path')
    parser.add_argument(
        '--use-vis-pipeline',
        action='store_true',
        help='Use the vis_pipeline defined in the config. Some algorithms, '
        'such as SimMIM and MaskFeat, generate the mask in the data '
        'pipeline, so the visualization process applies the vis_pipeline '
        'from the config to obtain the mask.')
    parser.add_argument(
        '--norm-pix',
        action='store_true',
        help='MAE uses `norm_pix_loss` for optimization in pre-training, so '
        'the visualization process also needs to compute the mean and std '
        'of each patch while reconstructing the original images.')
    parser.add_argument(
        '--target-generator',
        action='store_true',
        help='Some algorithms, such as MaskFeat, use a target_generator for '
        'optimization in pre-training. Turn this on to visualize the '
        'generated target instead of the RGB image.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='The random seed for visualization')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make the random mask reproducible (comment out to get a different
    # mask each run)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print('Reconstruction visualization.')

    if args.use_vis_pipeline:
        model.cfg.test_dataloader = dict(
            dataset=dict(pipeline=model.cfg.vis_pipeline))
    else:
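        # fall back to a minimal pipeline: load the image, resize it to the
        # model input size and pack it; for algorithms like MAE the mask is
        # generated inside the model rather than in the data pipeline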
        model.cfg.test_dataloader = dict(
            dataset=dict(pipeline=[
                dict(
                    type='LoadImageFromFile',
                    file_client_args=dict(backend='disk')),
                dict(type='Resize', scale=(224, 224), backend='pillow'),
                dict(type='PackSelfSupInputs', meta_keys=['img_path'])
            ]))

    # get original image
    vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
    data = dict(img_path=args.img_path)
    data = vis_pipeline(data)
    data = default_collate([data])
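    # the data_preprocessor moves the batch to the model device and
    # normalizes it; the second return value (data samples) is unused here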
    img, _ = model.data_preprocessor(data, False)

    if args.norm_pix:
        # for MAE reconstruction with `norm_pix_loss`: the pre-training
        # target was normalized per patch, so recover per-patch statistics
        img_embedding = model.head.patchify(img[0])
        # compute the mean and std of each patch
        mean = img_embedding.mean(dim=-1, keepdim=True)
        std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5
    else:
        mean = imagenet_mean
        std = imagenet_std

    # get reconstruction image
    features = inference_model(model, args.img_path)
    results = model.reconstruct(features, mean=mean, std=std)
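    # `results` holds the reconstruction: `pred.value` is the unpatchified
    # prediction and `mask.value` marks masked patches in image space
    # (1 = masked, 0 = visible)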

    original_target = model.target if args.target_generator else img[0]

    original_img, img_masked, pred_img, img_paste = post_process(
        original_target,
        results.pred.value,
        results.mask.value,
        mean=mean,
        std=std)

    save_images(original_img, img_masked, pred_img, img_paste, args.out_file)


if __name__ == '__main__':
    main()