# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest
from typing import Dict, List, Optional

import numpy as np
from shapely.geometry import Polygon

from mmocr.datasets.transforms import ImgAugWrapper, TorchVisionWrapper


class TestImgAug(unittest.TestCase):
    """Unit tests for ``ImgAugWrapper``."""

    def test_init(self):
        """Both an empty ``args`` list and a list holding a bare string are
        expected to raise ``AssertionError``."""
        with self.assertRaises(AssertionError):
            ImgAugWrapper(args=[])
        with self.assertRaises(AssertionError):
            ImgAugWrapper(args=['test'])

    def _create_dummy_data(self):
        """Build a fresh results dict with a random 50x50x3 image and two
        annotated instances.

        Returns:
            dict: Contains ``img``, ``img_shape``, ``gt_polygons``,
            ``gt_bboxes``, ``gt_bboxes_labels``, ``gt_ignored`` and
            ``gt_texts``.
        """
        img = np.random.rand(50, 50, 3)
        poly = np.array([[[0, 0, 50, 0, 50, 50, 0, 50]],
                         [[20, 20, 50, 20, 50, 50, 20, 50]]])
        box = np.array([[0, 0, 50, 50], [20, 20, 50, 50]])
        # It shall always be 0 in MMOCR, but we assign different labels to
        # dummy instances for testing
        labels = np.array([0, 1], dtype=np.int64)
        ignored = np.array([False, True], dtype=bool)
        texts = ['text1', 'text2']
        return dict(
            img=img,
            img_shape=(50, 50),
            gt_polygons=poly,
            gt_bboxes=box,
            gt_bboxes_labels=labels,
            gt_ignored=ignored,
            gt_texts=texts)

    def assertPolyEqual(self, poly1: List[np.ndarray],
                        poly2: List[np.ndarray]) -> None:
        """Assert that two polygon sequences describe the same shapes.

        Each entry is reshaped to ``(-1, 2)`` and compared pairwise with
        ``shapely.geometry.Polygon.equals``.

        Args:
            poly1 (list[np.ndarray]): First sequence of flattened polygons.
            poly2 (list[np.ndarray]): Second sequence of flattened polygons.

        Raises:
            AssertionError: If the sequences differ in length or any pair of
                polygons is not equal.
        """
        # ``zip`` stops at the shorter input, so a missing or extra polygon
        # would go unnoticed without an explicit length check.
        self.assertEqual(
            len(poly1), len(poly2), 'Number of polygons does not match')
        for p1, p2 in zip(poly1, poly2):
            self.assertTrue(
                Polygon(p1.reshape(-1, 2)).equals(Polygon(p2.reshape(-1, 2))))

    def assert_result_equal(self,
                            results: Dict,
                            poly_targets: List[np.ndarray],
                            bbox_targets: np.ndarray,
                            bbox_label_targets: np.ndarray,
                            ignore_targets: np.ndarray,
                            text_targets: Optional[List[str]] = None) -> None:
        """Check every annotation field of ``results`` against its target.

        Also verifies that ``results['img_shape']`` equals the first two
        dimensions of ``results['img']``.

        Args:
            results (dict): Output of the transform under test.
            poly_targets (list[np.ndarray]): Expected ``gt_polygons``.
            bbox_targets (np.ndarray): Expected ``gt_bboxes``.
            bbox_label_targets (np.ndarray): Expected ``gt_bboxes_labels``.
            ignore_targets (np.ndarray): Expected ``gt_ignored``.
            text_targets (list[str], optional): Expected ``gt_texts``.
                Defaults to None.
        """
        self.assertPolyEqual(poly_targets, results['gt_polygons'])
        self.assertTrue(np.array_equal(bbox_targets, results['gt_bboxes']))
        self.assertTrue(
            np.array_equal(bbox_label_targets, results['gt_bboxes_labels']))
        self.assertTrue(np.array_equal(ignore_targets, results['gt_ignored']))
        self.assertEqual(text_targets, results['gt_texts'])
        self.assertEqual(results['img_shape'],
                         (results['img'].shape[0], results['img'].shape[1]))

    def test_transform(self):
        """Exercise the wrapper with no-op, partially-clipping and
        fully-clipping augmentations, and with missing annotation keys."""

        # Test empty transform
        imgaug_transform = ImgAugWrapper()
        results = self._create_dummy_data()
        origin_results = copy.deepcopy(results)
        results = imgaug_transform(results)
        self.assert_result_equal(results, origin_results['gt_polygons'],
                                 origin_results['gt_bboxes'],
                                 origin_results['gt_bboxes_labels'],
                                 origin_results['gt_ignored'],
                                 origin_results['gt_texts'])

        args = [dict(cls='Affine', translate_px=dict(x=-10, y=-10))]
        imgaug_transform = ImgAugWrapper(args)
        results = self._create_dummy_data()
        results = imgaug_transform(results)

        # Polygons and bboxes are partially outside the image after
        # transformation
        poly_target = [
            np.array([0, 0, 40, 0, 40, 40, 0, 40]),
            np.array([10, 10, 40, 10, 40, 40, 10, 40])
        ]
        box_target = np.array([[0, 0, 40, 40], [10, 10, 40, 40]])
        label_target = np.array([0, 1], dtype=np.int64)
        ignored = np.array([False, True], dtype=bool)
        texts = ['text1', 'text2']
        self.assert_result_equal(results, poly_target, box_target,
                                 label_target, ignored, texts)

        # Some polygons and bboxes are no longer inside the image after
        # transformation
        args = [
            dict(cls='Affine', translate_px=dict(x=30, y=30)), ['Fliplr', 1]
        ]
        poly_target = [np.array([0, 30, 20, 30, 20, 50, 0, 50])]
        box_target = np.array([[0, 30, 20, 50]])
        label_target = np.array([0], dtype=np.int64)
        ignored = np.array([False], dtype=bool)
        texts = ['text1']
        imgaug_transform = ImgAugWrapper(args)
        results = self._create_dummy_data()
        results = imgaug_transform(results)
        self.assert_result_equal(results, poly_target, box_target,
                                 label_target, ignored, texts)

        # All polygons and bboxes are moved outside the image by this
        # translation; the wrapper is expected to return None in that case
        args = [dict(cls='Affine', translate_px=dict(x=100, y=100))]
        results = self._create_dummy_data()
        invalid_transform = ImgAugWrapper(args)
        results = invalid_transform(results)
        self.assertIsNone(results)

        # Everything should work well without gt_texts
        results = self._create_dummy_data()
        del results['gt_texts']
        results = imgaug_transform(results)
        self.assertNotIn('gt_texts', results)

        # Everything should work well without keys required from text detection
        results = imgaug_transform(
            dict(
                img=np.random.rand(10, 20, 3),
                img_shape=(10, 20),
                gt_texts=['text1', 'text2']))
        self.assertEqual(results['gt_texts'], ['text1', 'text2'])

    def test_repr(self):
        """``repr`` should echo the class name and the ``args`` it was built
        with."""
        args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]]
        transform = ImgAugWrapper(args)
        self.assertEqual(
            repr(transform),
            ("ImgAugWrapper(args = [['Resize', [0.5, 3.0]], ['Fliplr', 0.5]])"
             ))


class TestTorchVisionWrapper(unittest.TestCase):
    """Unit tests for ``TorchVisionWrapper``."""

    def test_transform(self):
        """Check constructor error handling and the output of a wrapped
        ``Grayscale`` op on a 128x100x3 uint8 image."""
        x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
        # object not found error
        with self.assertRaises(Exception):
            TorchVisionWrapper(op='NonExist')
        # ``op`` is a required argument
        with self.assertRaises(TypeError):
            TorchVisionWrapper()
        f = TorchVisionWrapper('Grayscale')
        # An input dict without the image is expected to be rejected
        with self.assertRaises(AssertionError):
            f({})
        results = f(x)
        # Use unittest assertions rather than bare ``assert`` so the checks
        # are not stripped under ``python -O`` and report both values on
        # failure.
        self.assertEqual(results['img'].shape, (128, 100))
        self.assertEqual(results['img_shape'], (128, 100))

    def test_repr(self):
        """``repr`` should list the op name followed by its keyword
        arguments."""
        f = TorchVisionWrapper('Grayscale', num_output_channels=3)
        self.assertEqual(
            repr(f),
            'TorchVisionWrapper(op = Grayscale, num_output_channels = 3)')