import torch
import numpy as np
import cv2
import tempfile
import matplotlib.pyplot as plt
from cog import BasePredictor, Path, Input, BaseModel

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse


class Predictor(BasePredictor):
    def setup(self):
        opt_path_denoise = "options/test/SIDD/NAFNet-width64.yml"
        opt_denoise = parse(opt_path_denoise, is_train=False)
        opt_denoise["dist"] = False

        opt_path_deblur = "options/test/GoPro/NAFNet-width64.yml"
        opt_deblur = parse(opt_path_deblur, is_train=False)
        opt_deblur["dist"] = False

        opt_path_stereo = "options/test/NAFSSR/NAFSSR-L_4x.yml"
        opt_stereo = parse(opt_path_stereo, is_train=False)
        opt_stereo["dist"] = False

        self.models = {
            "Image Denoising": create_model(opt_denoise),
            "Image Debluring": create_model(opt_deblur),
            "Stereo Image Super-Resolution": create_model(opt_stereo),
        }

    def predict(
        self,
        task_type: str = Input(
            choices=[
                "Image Denoising",
                "Image Debluring",
                "Stereo Image Super-Resolution",
            ],
            default="Image Debluring",
            description="Choose task type.",
        ),
        image: Path = Input(
            description="Input image. Stereo Image Super-Resolution, upload the left image here.",
        ),
        image_r: Path = Input(
            default=None,
            description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
            " Image Super-Resolution task.",
        ),
    ) -> Path:

        out_path = Path(tempfile.mkdtemp()) / "output.png"

        model = self.models[task_type]
        if task_type == "Stereo Image Super-Resolution":
            assert image_r is not None, (
                "Please provide both left and right input image for "
                "Stereo Image Super-Resolution task."
            )

            img_l = imread(str(image))
            inp_l = img2tensor(img_l)
            img_r = imread(str(image_r))
            inp_r = img2tensor(img_r)
            stereo_image_inference(model, inp_l, inp_r, str(out_path))

        else:

            img_input = imread(str(image))
            inp = img2tensor(img_input)
            out_path = Path(tempfile.mkdtemp()) / "output.png"
            single_image_inference(model, inp, str(out_path))

        return out_path


def imread(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def img2tensor(img, bgr2rgb=False, float32=True):
    img = img.astype(np.float32) / 255.0
    return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)


def single_image_inference(model, img, save_path):
    model.feed_data(data={"lq": img.unsqueeze(dim=0)})

    if model.opt["val"].get("grids", False):
        model.grids()

    model.test()

    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals["result"]])
    imwrite(sr_img, save_path)


def stereo_image_inference(model, img_l, img_r, out_path):
    img = torch.cat([img_l, img_r], dim=0)
    model.feed_data(data={"lq": img.unsqueeze(dim=0)})

    if model.opt["val"].get("grids", False):
        model.grids()

    model.test()

    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    img_L = visuals["result"][:, :3]
    img_R = visuals["result"][:, 3:]
    img_L, img_R = tensor2img([img_L, img_R], rgb2bgr=False)

    # save_stereo_image
    h, w = img_L.shape[:2]
    fig = plt.figure(figsize=(w // 40, h // 40))
    ax1 = fig.add_subplot(2, 1, 1)
    plt.title("NAFSSR output (Left)", fontsize=14)
    ax1.axis("off")
    ax1.imshow(img_L)

    ax2 = fig.add_subplot(2, 1, 2)
    plt.title("NAFSSR output (Right)", fontsize=14)
    ax2.axis("off")
    ax2.imshow(img_R)

    plt.subplots_adjust(hspace=0.08)
    plt.savefig(str(out_path), bbox_inches="tight", dpi=600)