# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock

import torch

from mmengine.data import BaseDataElement
from mmengine.hooks import NaiveVisualizationHook


class TestNaiveVisualizationHook:
    """Smoke test for ``NaiveVisualizationHook``."""

    def test_after_train_iter(self):
        """Call ``after_test_iter`` once per ground-truth metainfo variant.

        Every call shares the same mocked runner, the same random input
        batch and the same batch index; only the metainfo attached to the
        single ground-truth sample changes. No assertions are made, so the
        test passes as long as none of the calls raises.

        NOTE(review): the method name says "train" but the hook method
        exercised here is ``after_test_iter`` -- confirm whether the test
        should be renamed.
        """
        hook = NaiveVisualizationHook()
        runner = Mock(iter=1)
        # Explicit stand-in for the writer's image sink; its calls are not
        # inspected anywhere in this test.
        runner.writer.add_image = Mock()
        inputs = torch.randn(1, 3, 15, 15)
        batch_idx = 10

        # One entry per scenario, in the order they are exercised.
        metainfo_variants = [
            # normalize, resize, pad
            {
                'img_norm_cfg': {
                    'mean': (0, 0, 0),
                    'std': (0.5, 0.5, 0.5),
                    'to_bgr': True,
                },
                'scale': (10, 10),
                'pad_shape': (15, 15, 3),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # resize, pad
            {
                'scale': (10, 10),
                'pad_shape': (15, 15, 3),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # only resize
            {
                'scale': (15, 15),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # only pad
            {
                'pad_shape': (15, 15, 3),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # no transform
            {
                'ori_height': 15,
                'ori_width': 15,
                'img_path': 'tmp.jpg',
            },
        ]

        for metainfo in metainfo_variants:
            # Fresh ground-truth and prediction samples for every call.
            gt_samples = [BaseDataElement(metainfo=metainfo)]
            pred_samples = [BaseDataElement()]
            hook.after_test_iter(runner, batch_idx, (inputs, gt_samples),
                                 pred_samples)