# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import tqdm
from scipy.stats import norm
from sklearn import metrics

from .file_io import PathManager


class Visualizer:
    r"""Visualize images(activation map) ranking list of features generated by reid models."""

    def __init__(self, dataset):
        """Remember the dataset that images are fetched from when plotting.

        Args:
            dataset: indexable collection; ``save_rank_result`` reads the
                ``'images'``, ``'camid'`` and ``'img_path'`` entries of
                ``dataset[i]``.
        """
        self.dataset = dataset

    def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids):
        """Store evaluation results and precompute the per-query ranking.

        Every argument except ``all_ap`` is coerced with ``np.asarray`` so
        plain Python sequences are accepted as well as NumPy arrays.

        Args:
            all_ap: per-query AP values; stored as given.
            dist: 2-D distance matrix with one row per query and one column
                per gallery entry.
            q_pids: identity label of each query.
            g_pids: identity label of each gallery entry.
            q_camids: camera id of each query.
            g_camids: camera id of each gallery entry.
        """
        # Fancy indexing and broadcasting below require real arrays.
        dist = np.asarray(dist)
        q_pids = np.asarray(q_pids)
        g_pids = np.asarray(g_pids)
        q_camids = np.asarray(q_camids)
        g_camids = np.asarray(g_camids)

        self.all_ap = all_ap
        self.dist = dist
        self.sim = 1 - dist
        self.q_pids = q_pids
        self.g_pids = g_pids
        self.q_camids = q_camids
        self.g_camids = g_camids

        # Row i lists gallery indices from smallest to largest distance to query i.
        self.indices = np.argsort(dist, axis=1)
        # 1 where the ranked gallery entry shares the query's identity, else 0.
        self.matches = (g_pids[self.indices] == q_pids[:, np.newaxis]).astype(np.int32)

        self.num_query = len(q_pids)

    def get_matched_result(self, q_index):
        """Return the ranking of one query without same-id/same-camera entries.

        Args:
            q_index: row index of the query.

        Returns:
            tuple: ``(cmc, sort_idx)`` where ``sort_idx`` holds the gallery
            indices in ranked order after dropping entries that share both
            the identity and the camera id of the query, and ``cmc`` holds
            the matching 0/1 match flags for those entries.
        """
        ranking = self.indices[q_index]
        same_identity = self.g_pids[ranking] == self.q_pids[q_index]
        same_camera = self.g_camids[ranking] == self.q_camids[q_index]
        # Keep everything that is not (same identity AND same camera).
        kept = np.logical_not(same_identity & same_camera)
        return self.matches[q_index][kept], ranking[kept]

    def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, label_sort='ascending',
                         actmap=False):
        """Render and save one ranking-list image per query.

        Each figure shows the query image followed by its best-ranked gallery
        images (gallery entries sharing both identity and camera with the
        query are skipped, see ``get_matched_result``). A red frame marks a
        correct match, a blue frame a wrong one. With ``vis_label`` a second
        row shows true matches of the query.

        Args:
            query_indices: iterable of query indices to visualize.
            output (str): directory receiving ``<n>.jpg``, where ``n`` is the
                position of the query inside ``query_indices``. It is created
                when missing.
            max_rank (int): maximum number of gallery images per row. Fewer
                are drawn when the filtered ranking is shorter.
            vis_label (bool): also draw a second row with true matches.
            label_sort (str): with ``"ascending"`` the true matches are taken
                from the end of the ranking (worst-ranked first); any other
                value keeps them in ranked order.
            actmap (bool): accepted for interface compatibility but currently
                ignored; no activation-map overlay is drawn.
        """
        if output:
            os.makedirs(output, exist_ok=True)

        match_color, miss_color = (1, 0, 0), (0, 0, 1)
        n_rows = 2 if vis_label else 1
        n_cols = max_rank + 1

        def to_plot_array(img):
            # Moves axis 0 to the end and casts to uint8 for imshow;
            # presumably the dataset yields channel-first images - confirm.
            return np.rollaxis(np.asarray(img, dtype=np.uint8), 0, 3)

        def show(ax, img, title, frame_color=None):
            # Draw one image cell, optionally outlined with a coloured frame.
            if frame_color is not None:
                ax.add_patch(plt.Rectangle(xy=(0, 0), width=img.shape[1] - 1, height=img.shape[0] - 1,
                                           edgecolor=frame_color, fill=False, linewidth=5))
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")

        # One figure is reused for every query and released at the end.
        fig = plt.figure(figsize=(3 * max_rank, 6 * n_rows))
        try:
            for cnt, q_idx in enumerate(tqdm.tqdm(query_indices)):
                cmc, sort_idx = self.get_matched_result(q_idx)
                query_info = self.dataset[q_idx]
                query_name = os.path.basename(query_info['img_path'])
                fig.clf()

                # The query uses a single-row grid so that it spans the full
                # figure height even when a second row is drawn.
                show(fig.add_subplot(1, n_cols, 1), to_plot_array(query_info['images']),
                     '{}/{:.2f}/cam{}'.format(query_name, self.all_ap[q_idx], query_info['camid']))

                # Guard against rankings shorter than max_rank.
                for i in range(min(max_rank, len(sort_idx))):
                    # Gallery items are looked up at an offset of num_query
                    # in the dataset.
                    gallery_info = self.dataset[self.num_query + sort_idx[i]]
                    cam_id = gallery_info['camid']
                    is_match = cmc[i] == 1
                    label = 'true' if is_match else 'false'
                    show(fig.add_subplot(n_rows, n_cols, i + 2), to_plot_array(gallery_info['images']),
                         f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}',
                         match_color if is_match else miss_color)

                if vis_label:
                    label_indice = np.where(cmc == 1)[0]
                    if label_sort == "ascending":
                        label_indice = label_indice[::-1]
                    for i, j in enumerate(label_indice[:max_rank]):
                        gallery_info = self.dataset[self.num_query + sort_idx[j]]
                        cam_id = gallery_info['camid']
                        # Second row; its first column stays empty (under the query).
                        show(fig.add_subplot(2, n_cols, max_rank + 3 + i), to_plot_array(gallery_info['images']),
                             f'{self.sim[q_idx, sort_idx[j]]:.3f}/cam{cam_id}', match_color)

                fig.tight_layout()
                fig.savefig(os.path.join(output, "{}.jpg".format(cnt)))
        finally:
            plt.close(fig)

    def vis_rank_list(self, output, vis_label, num_vis=100, rank_sort="ascending", label_sort="ascending", max_rank=5,
                      actmap=False):
        r"""Save ranking-list visualizations for queries selected by AP.

        Queries are ordered by their AP value and the first ``num_vis`` of
        that ordering are handed to ``save_rank_result``.

        Args:
            output (str): directory in which the rank list images are saved.
            vis_label (bool): whether to also visualize the true matches of
                each query.
            num_vis (int): maximum number of queries to visualize.
            rank_sort (str): ``"ascending"`` visualizes queries from low AP
                to high AP, ``"descending"`` the other way round.
            label_sort (str): forwarded unchanged to ``save_rank_result``.
            max_rank (int): maximum number of ranked results shown per query.
            actmap (bool): forwarded unchanged to ``save_rank_result``.

        Raises:
            AssertionError: if ``rank_sort`` is neither ``"ascending"`` nor
                ``"descending"``.
        """
        assert rank_sort in ['ascending', 'descending'], "{} not match [ascending, descending]".format(rank_sort)

        by_ap = np.argsort(self.all_ap)
        ordered = by_ap[::-1] if rank_sort == 'descending' else by_ap
        self.save_rank_result(ordered[:num_vis], output, max_rank, vis_label, label_sort, actmap)

    def vis_roc_curve(self, output):
        """Plot and save the ROC curve and the pos/neg pair distance histogram.

        Distances of all query-gallery pairs (after dropping same-id,
        same-camera gallery entries) are split into matching (``pos``) and
        non-matching (``neg``) pairs. ``roc.jpg`` and ``pos_neg_dist.jpg``
        are written into ``output`` and both figures are closed afterwards.

        Args:
            output (str): directory that receives the two images; created via
                ``PathManager.mkdirs``.

        Returns:
            tuple: ``(fpr, tpr, pos, neg)`` - the arrays returned by
            ``sklearn.metrics.roc_curve`` and the lists of positive and
            negative pair distances.
        """
        PathManager.mkdirs(output)
        pos, neg = [], []
        for q_index in range(len(self.q_pids)):
            cmc, sort_idx = self.get_matched_result(q_index)  # remove same id in same camera
            q_dist = self.dist[q_index]
            pos.extend(q_dist[sort_idx[cmc == 1]])
            neg.extend(q_dist[sort_idx[cmc == 0]])

        scores = np.hstack((pos, neg))
        # Scores are distances, so the non-matching pairs carry label 1 here.
        labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))

        fpr, tpr, _ = metrics.roc_curve(labels, scores)

        # Save through the returned figure handles rather than pyplot's
        # implicit "current figure", and release each figure once written.
        roc_fig = self.plot_roc_curve(fpr, tpr)
        roc_fig.savefig(os.path.join(output, "roc.jpg"))
        plt.close(roc_fig)

        dist_fig = self.plot_distribution(pos, neg)
        dist_fig.savefig(os.path.join(output, "pos_neg_dist.jpg"))
        plt.close(dist_fig)
        return fpr, tpr, pos, neg

    @staticmethod
    def plot_roc_curve(fpr, tpr, name='model', fig=None):
        """Draw a ROC curve with a log-scaled x axis and return its figure.

        Args:
            fpr: false positive rates (x values).
            tpr: true positive rates (y values).
            name (str): model name used in the legend entry.
            fig: matplotlib figure to draw on. When ``None`` a new figure is
                created and a dashed red "Random guess" diagonal is added
                first. A supplied figure is drawn on directly (via its
                current axes), whether or not it is pyplot's current figure.

        Returns:
            The figure that was drawn on.
        """
        created = fig is None
        if created:
            fig = plt.figure()
        ax = fig.gca()
        if created:
            diagonal = np.arange(0, 1, 0.01)
            ax.semilogx(diagonal, diagonal, 'r', linestyle='--', label='Random guess')
        # Random colour so that curves of several models can share one figure.
        ax.semilogx(fpr, tpr, color=(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)),
                    label='ROC curve with {}'.format(name))
        ax.set_title('Receiver Operating Characteristic')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.legend(loc='best')
        return fig

    @staticmethod
    def plot_distribution(pos, neg, name='model', fig=None):
        """Plot normalized histograms of positive and negative pair values.

        Each histogram (80 bins, density-normalized) is overlaid with a
        normal pdf using the sample mean and standard deviation, drawn in
        the same random colour as the histogram.

        Args:
            pos: values of the positive pairs.
            neg: values of the negative pairs.
            name (str): model name used in the legend entries.
            fig: matplotlib figure to draw on. When ``None`` a new figure is
                created. A supplied figure is drawn on directly (via its
                current axes), whether or not it is pyplot's current figure.

        Returns:
            The figure that was drawn on.
        """
        if fig is None:
            fig = plt.figure()
        ax = fig.gca()

        for values, kind, alpha in ((pos, 'positive', 0.7), (neg, 'negative', 0.5)):
            color = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
            _, bins, _ = ax.hist(values, bins=80, alpha=alpha, density=True,
                                 color=color,
                                 label='{} with {}'.format(kind, name))
            fitted = norm.pdf(bins, np.mean(values), np.std(values))  # fitting curve
            ax.plot(bins, fitted, color=color)

        ax.set_xticks(np.arange(0, 1.5, 0.1))
        ax.set_title('positive and negative pairs distribution')
        ax.legend(loc='best')
        return fig

    @staticmethod
    def save_roc_info(output, fpr, tpr, pos, neg):
        results = {
            "fpr": np.asarray(fpr),
            "tpr": np.asarray(tpr),
            "pos": np.asarray(pos),
            "neg": np.asarray(neg),
        }
        with open(os.path.join(output, "roc_info.pickle"), "wb") as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load_roc_info(path):
        with open(path, 'rb') as handle: res = pickle.load(handle)
        return res

    # def plot_camera_dist(self):
    #     same_cam, diff_cam = [], []
    #     for i, q in enumerate(self.q_pids):
    #         q_camid = self.q_camids[i]
    #
    #         order = self.indices[i]
    #         same = (self.g_pids[order] == q) & (self.g_camids[order] == q_camid)
    #         diff = (self.g_pids[order] == q) & (self.g_camids[order] != q_camid)
    #         sameCam_idx = order[same]
    #         diffCam_idx = order[diff]
    #
    #         same_cam.extend(self.sim[i, sameCam_idx])
    #         diff_cam.extend(self.sim[i, diffCam_idx])
    #
    #     fig = plt.figure(figsize=(10, 5))
    #     plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
    #     plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
    #     plt.xticks(np.arange(0.1, 1.0, 0.1))
    #     plt.title('positive and negative pair distribution')
    #     return fig

    # def get_actmap(self, features, sz):
    #     """
    #     :param features: (1, 2048, 16, 8) activation map
    #     :return:
    #     """
    #     features = (features ** 2).sum(1)  # (1, 16, 8)
    #     b, h, w = features.size()
    #     features = features.view(b, h * w)
    #     features = nn.functional.normalize(features, p=2, dim=1)
    #     acts = features.view(b, h, w)
    #     all_acts = []
    #     for i in range(b):
    #         act = acts[i].numpy()
    #         act = cv2.resize(act, (sz[1], sz[0]))
    #         act = 255 * (act - act.max()) / (act.max() - act.min() + 1e-12)
    #         act = np.uint8(np.floor(act))
    #         all_acts.append(act)
    #     return all_acts