# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import time
from unittest import TestCase
from unittest.mock import Mock

import torch
from mmengine.structures import InstanceData

from mmocr.engine.hooks import VisualizationHook
from mmocr.structures import TextDetDataSample
from mmocr.visualization import TextDetLocalVisualizer


def _rand_bboxes(num_boxes, h, w):
    """Sample random axis-aligned boxes that lie inside an ``h`` x ``w`` area.

    Each box is drawn as a random centre and a random relative size in
    ``[0, 1)``, scaled by ``w`` / ``h``, and its corners are clamped to
    ``[0, w]`` horizontally and ``[0, h]`` vertically.

    Args:
        num_boxes: Number of boxes to generate.
        h: Height used to scale and clamp the y coordinates.
        w: Width used to scale and clamp the x coordinates.

    Returns:
        A tensor of shape ``(num_boxes, 4)`` whose columns are
        ``(tl_x, tl_y, br_x, br_y)``.
    """
    # One row per box: (centre x, centre y, box width, box height), all as
    # fractions of the full extent.
    samples = torch.rand(num_boxes, 4)
    cx, cy, bw, bh = samples.T

    centre_x = cx * w
    centre_y = cy * h
    half_w = w * bw / 2
    half_h = h * bh / 2

    corners = [
        (centre_x - half_w).clamp(0, w),
        (centre_y - half_h).clamp(0, h),
        (centre_x + half_w).clamp(0, w),
        (centre_y + half_h).clamp(0, h),
    ]
    # Stack as (4, num_boxes) and transpose so each row describes one box.
    return torch.stack(corners, dim=0).T


class TestVisualizationHook(TestCase):
    """Tests for ``VisualizationHook.after_val_iter`` / ``after_test_iter``.

    Each test registers a ``TextDetLocalVisualizer`` whose ``save_dir`` is a
    timestamp-named path relative to the current working directory, runs the
    hook on fake outputs and checks whether that directory exists afterwards.
    """

    def setUp(self) -> None:
        """Build the fake outputs and data batch passed to the hook."""
        data_sample = TextDetDataSample()
        data_sample.set_metainfo({
            'img_path':
            osp.join(
                osp.dirname(__file__),
                '../../data/det_toy_dataset/imgs/test/img_1.jpg')
        })

        # Five random predictions: boxes, integer labels in {0, 1} and
        # scores in [0, 1).
        pred_instances = InstanceData()
        pred_instances.bboxes = _rand_bboxes(5, 10, 12)
        pred_instances.labels = torch.randint(0, 2, (5, ))
        pred_instances.scores = torch.rand((5, ))

        data_sample.pred_instances = pred_instances
        # The same sample object is used twice to form a batch of two outputs.
        self.outputs = [data_sample] * 2
        self.data_batch = None

    def _register_visualizer(self, name):
        """Create a visualizer instance saving into a timestamp-named dir.

        The directory is scheduled for removal via ``addCleanup`` so it is
        deleted even when a later assertion fails or the hook raises;
        otherwise a failing test would leave it behind in the working
        directory.

        Args:
            name: Instance name passed to
                ``TextDetLocalVisualizer.get_instance``.

        Returns:
            The ``save_dir`` path handed to the visualizer.
        """
        save_dir = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        TextDetLocalVisualizer.get_instance(
            name,
            vis_backends=[dict(type='LocalVisBackend', img_save_dir='')],
            save_dir=save_dir)
        # ignore_errors: the directory may legitimately not exist if the test
        # failed before the hook produced any output.
        self.addCleanup(shutil.rmtree, save_dir, ignore_errors=True)
        return save_dir

    def test_after_val_iter(self):
        """An enabled hook with ``interval=1`` creates the save directory."""
        timestamp = self._register_visualizer('visualizer_val')
        runner = Mock()
        runner.iter = 1
        hook = VisualizationHook(enable=True, interval=1)
        # Nothing must exist before the hook has run.
        self.assertFalse(osp.exists(timestamp))
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
        self.assertTrue(osp.exists(timestamp))

    def test_after_test_iter(self):
        """A disabled hook writes nothing; an enabled one creates the dir."""
        timestamp = self._register_visualizer('visualizer_test')
        runner = Mock()
        runner.iter = 1

        hook = VisualizationHook(enable=False)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        self.assertFalse(osp.exists(timestamp))

        hook = VisualizationHook(enable=True)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        self.assertTrue(osp.exists(timestamp))