# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest

import numpy as np

from mmocr.datasets.transforms import (PadToWidth, PyramidRescale,
                                       RescaleToHeight)


class TestPadToWidth(unittest.TestCase):
    """Unit tests for the ``PadToWidth`` transform."""

    def test_pad_to_width(self):
        """Check argument validation, output shape and ``valid_ratio``."""
        data_info = dict(img=np.random.random((16, 25, 3)))
        # a non-integer target width must be rejected when constructing
        # the transform
        with self.assertRaises(AssertionError):
            PadToWidth(width=10.5)

        # padding a 16x25 image to width 100 keeps the height unchanged;
        # valid_ratio is expected to equal original width / target width
        transform = PadToWidth(width=100)
        results = transform(copy.deepcopy(data_info))
        self.assertTupleEqual(results['img'].shape[:2], (16, 100))
        self.assertEqual(results['valid_ratio'], 25 / 100)

    def test_repr(self):
        """Check the string representation with the default ``pad_cfg``."""
        transform = PadToWidth(width=100)
        self.assertEqual(
            repr(transform),
            ("PadToWidth(width=100, pad_cfg={'type': 'Pad'})"))


class TestPyramidRescale(unittest.TestCase):
    """Unit tests for the ``PyramidRescale`` transform."""

    def setUp(self):
        # a single random 128x100 3-channel image used by the tests below
        self.data_info = dict(img=np.random.random((128, 100, 3)))

    def test_init(self):
        """Check that constructor arguments are validated."""
        # an integer factor is accepted and stored unchanged
        transform = PyramidRescale(factor=4, randomize_factor=False)
        self.assertEqual(transform.factor, 4)

        # each entry: (expected exception, message regex, offending kwargs)
        invalid_cases = [
            (TypeError, '`factor` should be an integer', dict(factor=4.0)),
            (TypeError, '`base_shape` should be a list or tuple',
             dict(base_shape=128)),
            (ValueError, '`base_shape` should contain two integers',
             dict(base_shape=(128, ))),
            (ValueError, '`base_shape` should contain two integers',
             dict(base_shape=(128.0, 2.0))),
            (TypeError, '`randomize_factor` should be a bool',
             dict(randomize_factor=None)),
        ]
        for exc_type, pattern, kwargs in invalid_cases:
            with self.assertRaisesRegex(exc_type, pattern):
                PyramidRescale(**kwargs)

    def test_transform(self):
        """Check the output shape and the ``factor=0`` identity case."""
        src = self.data_info

        # with default arguments the output image keeps the input shape
        out = PyramidRescale()(copy.deepcopy(src))
        self.assertEqual(out['img'].shape, (128, 100, 3))

        # factor=0 without randomization is expected to leave pixels as-is
        identity = PyramidRescale(factor=0, randomize_factor=False)
        out = identity(copy.deepcopy(src))
        self.assertTrue(np.all(out['img'] == src['img']))

    def test_repr(self):
        """Check the string representation of a fully specified instance."""
        transform = PyramidRescale(
            factor=4, base_shape=(128, 512), randomize_factor=False)
        expected = ('PyramidRescale(factor = 4, randomize_factor = False, '
                    'base_w = 128, base_h = 512)')
        self.assertEqual(repr(transform), expected)


class TestRescaleToHeight(unittest.TestCase):
    """Unit tests for the ``RescaleToHeight`` transform."""

    def test_rescale_height(self):
        """Check argument validation and the output width per option."""
        data_info = dict(
            img=np.random.random((16, 25, 3)),
            gt_seg_map=np.random.random((16, 25, 3)),
            gt_bboxes=np.array([[0, 0, 10, 10]]),
            gt_keypoints=np.array([[[10, 10, 1]]]))
        src_h, src_w = 16, 25

        # every non-integer size / divisor argument must be rejected
        for bad_kwargs in (dict(height=20.9),
                           dict(height=20, min_width=20.9),
                           dict(height=20, max_width=20.9),
                           dict(height=20, width_divisor=0.5)):
            with self.assertRaises(AssertionError):
                RescaleToHeight(**bad_kwargs)

        # (extra constructor kwargs, expected output width) at height 32:
        #   plain rescale of the 16x25 input gives width 50,
        #   min_width=60 widens it to 60, max_width=45 narrows it to 45,
        #   and width_divisor=4 gives 48
        target_h = 32
        cases = [
            (dict(), 50),
            (dict(min_width=60), 60),
            (dict(max_width=45), 45),
            (dict(width_divisor=4), 48),
        ]
        for extra_kwargs, want_w in cases:
            transform = RescaleToHeight(height=target_h, **extra_kwargs)
            results = transform(copy.deepcopy(data_info))
            self.assertTupleEqual(results['img'].shape[:2],
                                  (target_h, want_w))
            self.assertTupleEqual(results['scale'], (want_w, target_h))
            self.assertTupleEqual(results['scale_factor'],
                                  (want_w / src_w, target_h / src_h))

    def test_repr(self):
        """Check the string representation with default options."""
        expected = ('RescaleToHeight(height=32, '
                    'min_width=None, max_width=None, '
                    'width_divisor=1, '
                    "resize_cfg={'type': 'Resize', 'scale': 0})")
        self.assertEqual(repr(RescaleToHeight(height=32)), expected)