# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np
from mmengine.data import LabelData

from mmocr.data import TextRecogDataSample
from mmocr.visualization import TextRecogLocalVisualizer


class TestTextDetLocalVisualizer(unittest.TestCase):

    def test_add_datasample(self):
        h, w = 64, 128
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # test gt_text
        gt_recog_data_sample = TextRecogDataSample()
        img_meta = dict(img_shape=(12, 10, 3))
        gt_text = LabelData(metainfo=img_meta)
        gt_text.item = 'mmocr'
        gt_recog_data_sample.gt_text = gt_text

        recog_local_visualizer = TextRecogLocalVisualizer()
        recog_local_visualizer.add_datasample('image', image,
                                              gt_recog_data_sample)

        # test gt_text and pred_text
        pred_recog_data_sample = TextRecogDataSample()
        pred_text = LabelData(metainfo=img_meta)
        pred_text.item = 'MMOCR'
        pred_recog_data_sample.pred_text = pred_text

        with tempfile.TemporaryDirectory() as tmp_dir:
            # test out
            out_file = osp.join(tmp_dir, 'out_file.jpg')

            # draw_gt = True + gt_sample
            recog_local_visualizer.add_datasample(
                'image', image, gt_recog_data_sample, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # draw_gt = True + gt_sample + pred_sample
            recog_local_visualizer.add_datasample(
                'image',
                image,
                gt_recog_data_sample,
                pred_recog_data_sample,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 3, w, 3))

            # draw_gt = False + gt_sample + pred_sample
            recog_local_visualizer.add_datasample(
                'image',
                image,
                gt_recog_data_sample,
                pred_recog_data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

    def _assert_image_and_shape(self, out_file, out_shape):
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)
        self.assertTrue(drawn_img.shape == out_shape)