# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import mmcv
import numpy as np
import torch
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon

from easycv.core.visualization.image import imshow_bboxes
from easycv.predictors.builder import PREDICTORS
from .base import PredictorV2


@PREDICTORS.register_module()
class SegmentationPredictor(PredictorV2):
    """Predictor for semantic segmentation models.

    Args:
        model_path (str): Path of the model file.
        config_file (Optional[str]): Config file path used to initialise the
            model and processor.
        batch_size (int): Batch size for forward.
        device (str): Support 'cuda' or 'cpu', if is None, detect device
            automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when
            `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
    """

    def __init__(self,
                 model_path,
                 config_file,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 *args,
                 **kwargs):
        # Everything is handed straight to the base predictor; this class only
        # adds the class names / palette lookups below.
        super(SegmentationPredictor, self).__init__(
            model_path,
            config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            *args,
            **kwargs)

        # Cached from the loaded config for use by the drawing helpers.
        self.CLASSES = self.cfg.CLASSES
        self.PALETTE = self.cfg.PALETTE

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
        """Overlay a colourised segmentation map on an image.

        Args:
            img (str or ndarray): Image (or path to it), loaded through
                `mmcv.imread`.
            result (sequence): Prediction container; `result[0]` is used as
                the 2-D label map drawn over `img`.
            palette (list[list[int]] | np.ndarray | None): One RGB colour per
                class. When None, `self.PALETTE` is used; when that is None
                too, a random palette is drawn with a fixed seed.
                Default: None.
            win_name (str): The window name.
            show (bool): Whether to show the image. Ignored when `out_file`
                is given. Default: False.
            wait_time (int): Value of waitKey param. Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.
            opacity (float): Opacity of painted segmentation map, must be in
                (0, 1]. Default: 0.5.

        Returns:
            np.ndarray: The blended uint8 image, returned only when the image
            was neither shown nor given an `out_file`.
        """
        canvas = mmcv.imread(img).copy()
        label_map = result[0]

        if palette is None:
            palette = self.PALETTE
            if palette is None:
                # Seed temporarily so the fallback palette is identical on
                # every call, then put the global RNG back exactly where it
                # was so callers do not lose their randomness.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                saved_rng_state = np.random.get_state()
                np.random.seed(42)
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
                np.random.set_state(saved_rng_state)

        palette = np.array(palette)
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        assert 0 < opacity <= 1.0

        height, width = label_map.shape[0], label_map.shape[1]
        overlay = np.zeros((height, width, 3), dtype=np.uint8)
        for class_idx, rgb in enumerate(palette):
            overlay[label_map == class_idx, :] = rgb
        # palette is RGB, the image is BGR
        overlay = overlay[..., ::-1]

        blended = canvas * (1 - opacity) + overlay * opacity
        blended = blended.astype(np.uint8)

        # Writing to a file takes precedence over opening a window.
        display = show and out_file is None
        if display:
            mmcv.imshow(blended, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(blended, out_file)

        if not display and not out_file:
            return blended


@PREDICTORS.register_module()
class Mask2formerPredictor(SegmentationPredictor):
    """Predictor for Mask2former.

    Args:
        model_path (str): Path of the model file.
        config_file (Optional[str]): config file path for model and processor
            to init. Defaults to None.
        batch_size (int): batch size for forward.
        device (str): Support 'cuda' or 'cpu', if is None, detect device
            automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when
            `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        task_mode (str): Which outputs `postprocess_single` extracts; one of
            'panoptic', 'instance' or 'semantic'. Defaults to 'panoptic'.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 task_mode='panoptic',
                 *args,
                 **kwargs):
        super(Mask2formerPredictor, self).__init__(
            model_path,
            config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            *args,
            **kwargs)
        self.task_mode = task_mode
        self.class_name = self.cfg.CLASSES
        self.PALETTE = self.cfg.PALETTE

    def forward(self, inputs):
        """Run the model in test mode with gradient tracking disabled.

        Args:
            inputs (dict): Keyword arguments forwarded to the model.

        Returns:
            The raw model outputs.
        """
        with torch.no_grad():
            outputs = self.model.forward(**inputs, mode='test', encode=False)
        return outputs

    def postprocess_single(self, inputs, *args, **kwargs):
        """Pick the outputs relevant to `self.task_mode` for one sample.

        Args:
            inputs (dict): Model outputs for a single sample.

        Returns:
            dict: For 'panoptic': `masks` (list of integer masks), `labels`
            (class names) and `labels_ids` (class indices). For 'instance':
            `segms`, `bboxes`, `scores` and `labels`. For 'semantic':
            `seg_pred`.

        Raises:
            ValueError: If `self.task_mode` is not one of the supported modes.
        """
        output = {}
        if self.task_mode == 'panoptic':
            pan_results = inputs['pan_results']
            # keep objects ahead
            ids = np.unique(pan_results)[::-1]
            ids = ids[ids != len(self.CLASSES)]  # drop the VOID label
            # The class index is stored in the low three decimal digits.
            labels = (ids % 1000).astype(np.int64)
            segms = (pan_results[None] == ids[:, None, None])
            # `np.int` was removed in NumPy 1.24; the builtin `int` is what
            # that alias always resolved to.
            masks = [segm.astype(int) for segm in segms]
            labels_txt = np.array(self.CLASSES)[labels].tolist()

            output['masks'] = masks
            output['labels'] = labels_txt
            output['labels_ids'] = labels
        elif self.task_mode == 'instance':
            output['segms'] = inputs['detection_masks']
            output['bboxes'] = inputs['detection_boxes']
            output['scores'] = inputs['detection_scores']
            output['labels'] = inputs['detection_classes']
        elif self.task_mode == 'semantic':
            output['seg_pred'] = inputs['seg_pred']
        else:
            raise ValueError(f'Not support model {self.task_mode}')
        return output

    def show_panoptic(self, img, masks, labels):
        """Draw panoptic masks on `img`, coloured by class.

        Args:
            img (ndarray): Image to draw on; it is passed to `draw_masks`.
            masks (sequence): One mask per segment.
            labels (ndarray): Integer label per segment; `labels % 1000` is
                used to index the palette.

        Returns:
            ndarray: The image returned by `draw_masks`.
        """
        # Use the instance palette, like the other show_* methods, so that a
        # palette overridden on the predictor is honoured here as well.
        palette = np.asarray(self.PALETTE)
        palette = palette[labels % 1000]
        panoptic_result = draw_masks(img, masks, palette)
        return panoptic_result

    def show_instance(self, img, segms, bboxes, scores, labels, score_thr=0.5):
        """Draw instance masks and their bounding boxes on `img`.

        Args:
            img (ndarray): Image to draw on.
            segms (ndarray): Instance masks, indexed along the first axis.
            bboxes (ndarray): 2-D array with one row per instance.
            scores (ndarray): Confidence per instance.
            labels (ndarray): Integer class index per instance.
            score_thr (float): Instances scoring at or below this value are
                dropped; no filtering happens when it is not positive.
                Default: 0.5.

        Returns:
            The result of `imshow_bboxes` applied to the mask overlay.
        """
        if score_thr > 0:
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            segms = segms[inds, ...]
            labels = labels[inds]
        palette = np.asarray(self.PALETTE)
        palette = palette[labels]

        instance_result = draw_masks(img, segms, palette)
        class_name = np.array(self.CLASSES)
        instance_result = imshow_bboxes(
            instance_result, bboxes, class_name[labels], show=False)
        return instance_result

    def show_semantic(self, img, seg_pred, alpha=0.5, palette=None):
        """Blend a colourised semantic label map onto `img`.

        Args:
            img (ndarray): Image to blend with; it is not modified.
            seg_pred (ndarray): 2-D map of class indices.
            alpha (float): Opacity of the painted map, must be in (0, 1].
                Default: 0.5.
            palette (list[list[int]] | np.ndarray | None): One RGB colour per
                class. When None, `self.PALETTE` is used; when that is None
                too, a random palette is drawn with a fixed seed.

        Returns:
            ndarray: The blended uint8 image.
        """
        if palette is None:
            if self.PALETTE is None:
                # Get random state before set seed,
                # and restore random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        # Check the rank first so a malformed palette trips an assertion
        # rather than an IndexError on shape[1].
        assert len(palette.shape) == 2
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert 0 < alpha <= 1.0
        color_seg = np.zeros((seg_pred.shape[0], seg_pred.shape[1], 3),
                             dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg_pred == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        img = img * (1 - alpha) + color_seg * alpha
        img = img.astype(np.uint8)

        return img


def _get_bias_color(base, max_dist=30):
    """Return a randomly perturbed variant of a category colour.

    Each channel of ``base`` is shifted by an independent random integer in
    ``[-max_dist, max_dist]`` and the result is clamped to ``[0, 255]``, so
    that several masks of the same category can be told apart.

    Args:
        base (ndarray): The base category color with the shape of (3, ).
        max_dist (int): Largest per-channel shift. Default: 30.

    Returns:
        ndarray: The new color for a mask with the shape of (3, ).
    """
    jitter = np.random.randint(low=-max_dist, high=max_dist + 1, size=3)
    shifted = base + jitter
    # Clamp in place; `shifted` is a fresh array, `base` is left untouched.
    np.clip(shifted, 0, 255, out=shifted)
    return shifted


def draw_masks(img, masks, color=None, with_edge=True, alpha=0.8):
    """Blend masks onto an image, in place.

    Args:
        img (ndarray): The image with the shape of (h, w, 3). It is modified
            in place and also returned.
        masks (ndarray): The masks with the shape of (n, h, w). Any sequence
            of array-convertible (h, w) masks is accepted.
        color (ndarray): The colors for each masks with the shape
            of (n, 3). Random colors are drawn when None. Default: None.
        with_edge (bool): Accepted for backward compatibility. Mask contours
            are not rendered onto the returned image, so this flag has no
            effect on the output. Default: True.
        alpha (float): Opacity of the mask color. Default: 0.8.

    Returns:
        ndarray: The result image (the same object as ``img``).
    """
    # Reserve black as a single colour tuple. `set([0, 0, 0])` would collapse
    # to `{0}` and never match any colour.
    taken_colors = {(0, 0, 0)}
    if color is None:
        # len() works for ndarrays, tensors and lists alike.
        color = np.random.randint(0, 255,
                                  (len(masks), 3)).astype(np.uint8)
    else:
        # Lets plain list-of-lists palettes take part in the arithmetic below.
        color = np.asarray(color)

    for i, mask in enumerate(masks):
        color_mask = color[i]
        # Nudge the colour until it differs from every colour already used,
        # so two instances of the same class stay distinguishable.
        while tuple(color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(color_mask))

        mask = np.asarray(mask).astype(bool)
        img[mask] = img[mask] * (1 - alpha) + color_mask * alpha

    return img


def bitmap_to_polygon(bitmap):
    """Trace the contours of a bitmap mask.

    Args:
        bitmap (ndarray): masks in bitmap representation.

    Return:
        list[ndarray]: one (num_points, 2) array per contour found.
        bool: whether any contour has a parent, i.e. the mask has holes.
    """
    mask_u8 = np.ascontiguousarray(bitmap).astype(np.uint8)
    # RETR_CCOMP builds a two-level hierarchy: outer boundaries of components
    # on top, boundaries of their holes below. A contour nested inside a hole
    # is placed at the top level again.
    # CHAIN_APPROX_NONE keeps every contour point.
    found = cv2.findContours(mask_u8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    # Index from the end so this works whether findContours returns two or
    # three values.
    raw_contours, hierarchy = found[-2], found[-1]
    if hierarchy is None:
        return [], False

    # Each hierarchy row is (next, previous, first_child, parent), with -1
    # meaning "none"; a non-negative parent marks a hole boundary.
    parent_index = hierarchy.reshape(-1, 4)[:, 3]
    with_hole = (parent_index >= 0).any()
    polygons = [contour.reshape(-1, 2) for contour in raw_contours]
    return polygons, with_hole