# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

import torch
from mmengine.data import PixelData

from mmseg.engine.hooks import SegVisualizationHook
from mmseg.structures import SegDataSample
from mmseg.visualization import SegLocalVisualizer


class TestVisualizationHook(TestCase):

    def setUp(self) -> None:

        h = 288
        w = 512
        num_class = 2

        SegLocalVisualizer.get_instance('visualizer')
        SegLocalVisualizer.dataset_meta = dict(
            classes=('background', 'foreground'),
            palette=[[120, 120, 120], [6, 230, 230]])

        data_sample = SegDataSample()
        data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
        self.data_batch = [{'data_sample': data_sample}] * 2

        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        pred_sem_seg = PixelData(**pred_sem_seg_data)
        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
        self.outputs = [pred_seg_data_sample] * 2

    def test_after_iter(self):
        runner = Mock()
        runner.iter = 1
        hook = SegVisualizationHook(draw=True, interval=1)
        hook._after_iter(
            runner, 1, self.data_batch, self.outputs, mode='train')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test')

    def test_after_val_iter(self):
        runner = Mock()
        runner.iter = 2
        hook = SegVisualizationHook(interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        hook = SegVisualizationHook(
            draw=True, interval=1, show=True, wait_time=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        runner = Mock()
        runner.iter = 3
        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)