# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any
from unittest import TestCase

import matplotlib.pyplot as plt
import numpy as np
import pytest
import torch
import torch.nn as nn

from mmengine import VISBACKENDS, Config
from mmengine.visualization import Visualizer


@VISBACKENDS.register_module()
class MockVisBackend:

    def __init__(self, save_dir: str):
        self._save_dir = save_dir
        self._close = False

    @property
    def experiment(self) -> Any:
        return self

    def add_config(self, config, **kwargs) -> None:
        self._add_config = True

    def add_graph(self, model, data_batch, **kwargs) -> None:
        self._add_graph = True

    def add_image(self, name, image, step=0, **kwargs) -> None:
        self._add_image = True

    def add_scalar(self, name, value, step=0, **kwargs) -> None:
        self._add_scalar = True

    def add_scalars(self,
                    scalar_dict,
                    step=0,
                    file_path=None,
                    **kwargs) -> None:
        self._add_scalars = True

    def close(self) -> None:
        """close an opened object."""
        self._close = True


class TestVisualizer(TestCase):

    def setUp(self):
        """Setup the demo image in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        self.image = np.random.randint(
            0, 256, size=(10, 10, 3)).astype('uint8')
        self.vis_backend_cfg = [
            dict(type='MockVisBackend', name='mock1'),
            dict(type='MockVisBackend', name='mock2')
        ]

    def test_init(self):
        visualizer = Visualizer(image=self.image)
        visualizer.get_image()

        # test save_dir
        with pytest.warns(
                Warning,
                match='`Visualizer` backend is not initialized '
                'because save_dir is None.'):
            Visualizer()

        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg))
        assert visualizer.get_backend('mock1') is None

        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')
        assert isinstance(visualizer.get_backend('mock1'), MockVisBackend)
        assert len(visualizer._vis_backends) == 2

        # test empty list
        with pytest.raises(AssertionError):
            Visualizer(vis_backends=[], save_dir='temp_dir')

        # test name
        # If one of them has a name attribute, all backends must
        # use the name attribute
        with pytest.raises(RuntimeError):
            Visualizer(
                vis_backends=[
                    dict(type='MockVisBackend'),
                    dict(type='MockVisBackend', name='mock2')
                ],
                save_dir='temp_dir')

        # The name fields cannot be the same
        with pytest.raises(RuntimeError):
            Visualizer(
                vis_backends=[
                    dict(type='MockVisBackend'),
                    dict(type='MockVisBackend')
                ],
                save_dir='temp_dir')

        with pytest.raises(RuntimeError):
            Visualizer(
                vis_backends=[
                    dict(type='MockVisBackend', name='mock1'),
                    dict(type='MockVisBackend', name='mock1')
                ],
                save_dir='temp_dir')

        # test global init
        visualizer = Visualizer.get_instance(
            'visualizer',
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')
        assert len(visualizer._vis_backends) == 2
        visualizer_any = Visualizer.get_instance('visualizer')
        assert visualizer_any == visualizer

    def test_set_image(self):
        visualizer = Visualizer()
        visualizer.set_image(self.image)
        with pytest.raises(AssertionError):
            visualizer.set_image(None)

    def test_get_image(self):
        visualizer = Visualizer(image=self.image)
        visualizer.get_image()

    def test_draw_bboxes(self):
        visualizer = Visualizer(image=self.image)

        # only support 4 or nx4 tensor and numpy
        visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2]))
        # valid bbox
        visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2]))
        bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]])
        visualizer.draw_bboxes(
            bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-')
        bboxes = bboxes.numpy()
        visualizer.draw_bboxes(bboxes)

        # test invalid bbox
        with pytest.raises(AssertionError):
            # x1 > x2
            visualizer.draw_bboxes(torch.tensor([5, 1, 2, 2]))

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The bbox is out of bounds,'
                ' the drawn bbox may not be in the image'):
            visualizer.draw_bboxes(torch.tensor([1, 1, 20, 2]))

        # test incorrect bbox format
        with pytest.raises(TypeError):
            visualizer.draw_bboxes([1, 1, 2, 2])

    def test_close(self):
        visualizer = Visualizer(
            image=self.image,
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')
        fig_num = visualizer.fig_save_num
        assert fig_num in plt.get_fignums()
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._close is False
        visualizer.close()
        assert fig_num not in plt.get_fignums()
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._close is True

    def test_draw_points(self):
        visualizer = Visualizer(image=self.image)

        with pytest.raises(TypeError):
            visualizer.draw_points(positions=[1, 2])
        with pytest.raises(AssertionError):
            visualizer.draw_points(positions=np.array([1, 2, 3]))
        # test color
        visualizer.draw_points(
            positions=torch.tensor([[1, 1], [3, 3]]),
            colors=['g', (255, 255, 0)])
        visualizer.draw_points(
            positions=torch.tensor([[1, 1], [3, 3]]),
            colors=['g', (255, 255, 0)],
            marker='.',
            sizes=[1, 5])

    def test_draw_texts(self):
        visualizer = Visualizer(image=self.image)

        # only support tensor and numpy
        visualizer.draw_texts(
            'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0))
        visualizer.draw_texts(['text1', 'text2'],
                              positions=torch.tensor([[5, 5], [3, 3]]),
                              colors=[(255, 0, 0), (255, 0, 0)])
        visualizer.draw_texts('text1', positions=np.array([5, 5]))
        visualizer.draw_texts(['text1', 'text2'],
                              positions=np.array([[5, 5], [3, 3]]))
        visualizer.draw_texts(
            'text1',
            positions=torch.tensor([5, 5]),
            bboxes=dict(facecolor='r', alpha=0.6))
        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The text is out of bounds,'
                ' the drawn text may not be in the image'):
            visualizer.draw_texts('text1', positions=torch.tensor([15, 5]))

        # test incorrect format
        with pytest.raises(TypeError):
            visualizer.draw_texts('text', positions=[5, 5])

        # test length mismatch
        with pytest.raises(AssertionError):
            visualizer.draw_texts(['text1', 'text2'],
                                  positions=torch.tensor([5, 5]))
        with pytest.raises(AssertionError):
            visualizer.draw_texts(
                'text1', positions=torch.tensor([[5, 5], [3, 3]]))
        with pytest.raises(AssertionError):
            visualizer.draw_texts(['text1', 'test2'],
                                  positions=torch.tensor([[5, 5], [3, 3]]),
                                  colors=['r'])
        with pytest.raises(AssertionError):
            visualizer.draw_texts(['text1', 'test2'],
                                  positions=torch.tensor([[5, 5], [3, 3]]),
                                  vertical_alignments=['top'])
        with pytest.raises(AssertionError):
            visualizer.draw_texts(['text1', 'test2'],
                                  positions=torch.tensor([[5, 5], [3, 3]]),
                                  horizontal_alignments=['left'])
        with pytest.raises(AssertionError):
            visualizer.draw_texts(['text1', 'test2'],
                                  positions=torch.tensor([[5, 5], [3, 3]]),
                                  font_sizes=[1])

        # test type valid
        with pytest.raises(TypeError):
            visualizer.draw_texts(['text1', 'test2'],
                                  positions=torch.tensor([[5, 5], [3, 3]]),
                                  font_sizes='b')

    def test_draw_lines(self):
        visualizer = Visualizer(image=self.image)

        # only support tensor and numpy
        visualizer.draw_lines(
            x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6]))
        visualizer.draw_lines(
            x_datas=np.array([[1, 5], [2, 4]]),
            y_datas=np.array([[2, 6], [4, 7]]))
        visualizer.draw_lines(
            x_datas=np.array([[1, 5], [2, 4]]),
            y_datas=np.array([[2, 6], [4, 7]]),
            colors='r',
            line_styles=['-', '-.'],
            line_widths=[1, 2])
        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The line is out of bounds,'
                ' the drawn line may not be in the image'):
            visualizer.draw_lines(
                x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6]))

        # test incorrect format
        with pytest.raises(TypeError):
            visualizer.draw_lines(x_datas=[5, 5], y_datas=torch.tensor([2, 6]))
        with pytest.raises(TypeError):
            visualizer.draw_lines(y_datas=[5, 5], x_datas=torch.tensor([2, 6]))

        # test length mismatch
        with pytest.raises(AssertionError):
            visualizer.draw_lines(
                x_datas=torch.tensor([1, 5]),
                y_datas=torch.tensor([[2, 6], [4, 7]]))

    def test_draw_circles(self):
        visualizer = Visualizer(image=self.image)

        # only support tensor and numpy
        visualizer.draw_circles(torch.tensor([1, 5]), torch.tensor([1]))
        visualizer.draw_circles(np.array([1, 5]), np.array([1]))
        visualizer.draw_circles(
            torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]))

        # test face_colors
        visualizer.draw_circles(
            torch.tensor([[1, 5], [2, 6]]),
            radius=torch.tensor([1, 2]),
            face_colors=(255, 0, 0),
            edge_colors=(255, 0, 0))

        # test config
        visualizer.draw_circles(
            torch.tensor([[1, 5], [2, 6]]),
            radius=torch.tensor([1, 2]),
            edge_colors=['g', 'r'],
            line_styles=['-', '-.'],
            line_widths=[1, 2])

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The circle is out of bounds,'
                ' the drawn circle may not be in the image'):
            visualizer.draw_circles(
                torch.tensor([12, 5]), radius=torch.tensor([1]))
            visualizer.draw_circles(
                torch.tensor([1, 5]), radius=torch.tensor([10]))

        # test incorrect format
        with pytest.raises(TypeError):
            visualizer.draw_circles([1, 5], radius=torch.tensor([1]))
        with pytest.raises(TypeError):
            visualizer.draw_circles(np.array([1, 5]), radius=10)

        # test length mismatch
        with pytest.raises(AssertionError):
            visualizer.draw_circles(
                torch.tensor([[1, 5]]), radius=torch.tensor([1, 2]))

    def test_draw_polygons(self):
        visualizer = Visualizer(image=self.image)
        # shape Nx2 or list[Nx2]
        visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]]))
        visualizer.draw_polygons(np.array([[1, 1], [2, 2], [3, 4]]))
        visualizer.draw_polygons([
            np.array([[1, 1], [2, 2], [3, 4]]),
            torch.tensor([[1, 1], [2, 2], [3, 4]])
        ])
        visualizer.draw_polygons(
            polygons=[
                np.array([[1, 1], [2, 2], [3, 4]]),
                torch.tensor([[1, 1], [2, 2], [3, 4]])
            ],
            face_colors=(255, 0, 0),
            edge_colors=(255, 0, 0))
        visualizer.draw_polygons(
            polygons=[
                np.array([[1, 1], [2, 2], [3, 4]]),
                torch.tensor([[1, 1], [2, 2], [3, 4]])
            ],
            edge_colors=['r', 'g'],
            line_styles='-',
            line_widths=[2, 1])

        # test out of bounds
        with pytest.warns(
                UserWarning,
                match='Warning: The polygon is out of bounds,'
                ' the drawn polygon may not be in the image'):
            visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [16, 4]]))

    def test_draw_binary_masks(self):
        binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(np.bool)
        visualizer = Visualizer(image=self.image)
        visualizer.draw_binary_masks(binary_mask)
        visualizer.draw_binary_masks(torch.from_numpy(binary_mask))
        # multi binary
        binary_mask = np.random.randint(0, 2, size=(2, 10, 10)).astype(np.bool)
        visualizer = Visualizer(image=self.image)
        visualizer.draw_binary_masks(binary_mask, colors=['r', (0, 255, 0)])
        # test the error that the size of mask and image are different.
        with pytest.raises(AssertionError):
            binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(np.bool)
            visualizer.draw_binary_masks(binary_mask)

        # test non binary mask error
        binary_mask = np.random.randint(0, 2, size=(10, 10, 3)).astype(np.bool)
        with pytest.raises(AssertionError):
            visualizer.draw_binary_masks(binary_mask)

        # test color dim
        with pytest.raises(AssertionError):
            visualizer.draw_binary_masks(
                binary_mask, colors=np.array([1, 22, 4, 45]))
        binary_mask = np.random.randint(0, 2, size=(10, 10))
        with pytest.raises(AssertionError):
            visualizer.draw_binary_masks(binary_mask)

    def test_draw_featmap(self):
        visualizer = Visualizer()
        image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8')

        # must be Tensor
        with pytest.raises(
                AssertionError,
                match='`featmap` should be torch.Tensor, but got '
                "<class 'numpy.ndarray'>"):
            visualizer.draw_featmap(np.ones((3, 3, 3)))

        # test tensor format
        with pytest.raises(
                AssertionError, match='Input dimension must be 3, but got 4'):
            visualizer.draw_featmap(torch.randn(1, 1, 3, 3))

        # test overlaid_image shape
        with pytest.warns(Warning):
            visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image)

        # test resize_shape
        featmap = visualizer.draw_featmap(
            torch.randn(1, 4, 3), resize_shape=(6, 7))
        assert featmap.shape[:2] == (6, 7)
        featmap = visualizer.draw_featmap(
            torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7))
        assert featmap.shape[:2] == (6, 7)

        # test channel_reduction parameter
        # mode only supports 'squeeze_mean' and 'select_max'
        with pytest.raises(AssertionError):
            visualizer.draw_featmap(
                torch.randn(2, 3, 3), channel_reduction='xx')

        featmap = visualizer.draw_featmap(
            torch.randn(2, 3, 3), channel_reduction='squeeze_mean')
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(2, 3, 3), channel_reduction='select_max')
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(2, 4, 3),
            overlaid_image=image,
            channel_reduction='select_max')
        assert featmap.shape[:2] == (3, 3)

        # test topk parameter
        with pytest.raises(
                AssertionError,
                match='The input tensor channel dimension must be 1 or 3 '
                'when topk is less than 1, but the channel '
                'dimension you input is 6, you can use the '
                'channel_reduction parameter or set topk '
                'greater than 0 to solve the error'):
            visualizer.draw_featmap(
                torch.randn(6, 3, 3), channel_reduction=None, topk=0)

        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3), channel_reduction='select_max', topk=10)
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
        assert featmap.shape[:2] == (4, 3)

        featmap = visualizer.draw_featmap(
            torch.randn(3, 4, 3),
            overlaid_image=image,
            channel_reduction=None,
            topk=-1)
        assert featmap.shape[:2] == (3, 3)
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3),
            channel_reduction=None,
            topk=4,
            arrangement=(2, 2))
        assert featmap.shape[:2] == (6, 6)
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3),
            channel_reduction=None,
            topk=4,
            arrangement=(1, 4))
        assert featmap.shape[:2] == (3, 12)
        with pytest.raises(
                AssertionError,
                match='The product of row and col in the `arrangement` '
                'is less than topk, please set '
                'the `arrangement` correctly'):
            visualizer.draw_featmap(
                torch.randn(6, 3, 3),
                channel_reduction=None,
                topk=4,
                arrangement=(1, 2))

        # test gray
        featmap = visualizer.draw_featmap(
            torch.randn(6, 3, 3),
            overlaid_image=np.random.randint(
                0, 256, size=(3, 3), dtype='uint8'),
            channel_reduction=None,
            topk=4,
            arrangement=(2, 2))
        assert featmap.shape[:2] == (6, 6)

    def test_chain_call(self):
        visualizer = Visualizer(image=self.image)
        binary_mask = np.random.randint(0, 2, size=(10, 10)).astype(np.bool)
        visualizer.draw_bboxes(torch.tensor([1, 1, 2, 2])). \
            draw_texts('test', torch.tensor([5, 5])). \
            draw_lines(x_datas=torch.tensor([1, 5]),
                       y_datas=torch.tensor([2, 6])). \
            draw_circles(torch.tensor([1, 5]), radius=torch.tensor([2])). \
            draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])). \
            draw_binary_masks(binary_mask)

    def test_get_backend(self):
        visualizer = Visualizer(
            image=self.image,
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')
        for name in ['mock1', 'mock2']:
            assert isinstance(visualizer.get_backend(name), MockVisBackend)

    def test_add_config(self):
        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')

        cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        visualizer.add_config(cfg)
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._add_config is True

    def test_add_graph(self):
        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x, y=None):
                return self.conv(x)

        visualizer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._add_graph is True

    def test_add_image(self):
        image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')

        visualizer.add_image('img', image)
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._add_image is True

    def test_add_scalar(self):
        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')
        visualizer.add_scalar('map', 0.9, step=0)
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._add_scalar is True

    def test_add_scalars(self):
        visualizer = Visualizer(
            vis_backends=copy.deepcopy(self.vis_backend_cfg),
            save_dir='temp_dir')
        input_dict = {'map': 0.7, 'acc': 0.9}
        visualizer.add_scalars(input_dict)
        for name in ['mock1', 'mock2']:
            assert visualizer.get_backend(name)._add_scalars is True

    def test_get_instance(self):

        class DetLocalVisualizer(Visualizer):

            def __init__(self, name):
                super().__init__(name)

        visualizer1 = DetLocalVisualizer.get_instance('name1')
        visualizer2 = Visualizer.get_current_instance()
        visualizer3 = DetLocalVisualizer.get_current_instance()
        assert id(visualizer1) == id(visualizer2) == id(visualizer3)

    def test_data_info(self):
        visualizer = Visualizer()
        visualizer.dataset_meta = {'class': 'cat'}
        assert visualizer.dataset_meta['class'] == 'cat'