# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy
from unittest import TestCase
from unittest.mock import MagicMock, patch

from mmocr.datasets import ConcatDataset, OCRDataset
from mmocr.registry import TRANSFORMS
from mmocr.utils import register_all_modules


class TestConcatDataset(TestCase):
    """Unit tests for :class:`mmocr.datasets.ConcatDataset`."""

    @TRANSFORMS.register_module()
    class MockTransform:
        """Transform stub that ignores its input and returns a fixed value.

        The distinct return values let the tests tell which pipeline (a
        sub-dataset's or the ConcatDataset's) produced a sample.
        """

        def __init__(self, return_value):
            self.return_value = return_value

        def __call__(self, *args, **kwargs):
            return self.return_value

    def _build_dataset(self, pipeline=None, **kwargs):
        """Build an OCRDataset over the det toy dataset annotation file.

        Args:
            pipeline (list[dict], optional): Pipeline config. When None the
                argument is not forwarded, so OCRDataset's default is used.
            **kwargs: Extra keyword arguments forwarded to OCRDataset.

        Returns:
            OCRDataset: The constructed dataset.
        """
        if pipeline is not None:
            kwargs['pipeline'] = pipeline
        return OCRDataset(
            data_root=osp.join(
                osp.dirname(__file__), '../data/det_toy_dataset'),
            data_prefix=dict(img_path='imgs'),
            ann_file='instances_test.json',
            **kwargs)

    def setUp(self):
        register_all_modules()

        # Stub OCRDataset.parse_data_info so every sample of a dataset built
        # while the stub is active equals the configured data_info. patch()
        # + addCleanup restores the real method after each test instead of
        # leaking the mock into unrelated tests.
        self.mock_parse_data_info = MagicMock()
        patcher = patch.object(
            OCRDataset, 'parse_data_info', new=self.mock_parse_data_info)
        patcher.start()
        self.addCleanup(patcher.stop)

        # create dataset_a
        self.mock_parse_data_info.return_value = dict(
            filename='img_1.jpg', height=720, width=1280)
        self.dataset_a = self._build_dataset()
        self.dataset_a_with_pipeline = self._build_dataset(
            pipeline=[dict(type='MockTransform', return_value=1)])

        # create dataset_b; the stub keeps returning this data_info for the
        # rest of the test, so datasets built inside a test look like b.
        self.mock_parse_data_info.return_value = dict(
            filename='img_2.jpg', height=720, width=1280)
        self.dataset_b = self._build_dataset()
        self.dataset_b_with_pipeline = self._build_dataset(
            pipeline=[dict(type='MockTransform', return_value=2)])

    def test_init(self):
        # Entries that are not datasets are rejected.
        with self.assertRaises(TypeError):
            ConcatDataset(datasets=[0])

        # A ConcatDataset-level pipeline is rejected when a sub-dataset
        # already has its own pipeline and force_apply is not set.
        with self.assertRaises(ValueError):
            ConcatDataset(
                datasets=[
                    deepcopy(self.dataset_a_with_pipeline),
                    deepcopy(self.dataset_b)
                ],
                pipeline=[dict(type='MockTransform', return_value=3)])
        with self.assertRaises(ValueError):
            ConcatDataset(
                datasets=[
                    deepcopy(self.dataset_a),
                    deepcopy(self.dataset_b_with_pipeline)
                ],
                pipeline=[dict(type='MockTransform', return_value=3)])

        # Sub-datasets whose metainfo differs are rejected. The datasets are
        # built outside assertRaises so a ValueError raised during dataset
        # construction cannot make this check pass spuriously.
        dataset_a = deepcopy(self.dataset_a)
        dataset_b = self._build_dataset(metainfo=dict(dummy='dummy'))
        with self.assertRaises(ValueError):
            ConcatDataset(datasets=[dataset_a, dataset_b])

        # test lazy init
        ConcatDataset(
            datasets=[deepcopy(self.dataset_a),
                      deepcopy(self.dataset_b)],
            pipeline=[dict(type='MockTransform', return_value=3)],
            lazy_init=True)

    def test_getitem(self):
        # Sub-datasets without pipelines: the ConcatDataset pipeline is used.
        cat_datasets = ConcatDataset(
            datasets=[deepcopy(self.dataset_a),
                      deepcopy(self.dataset_b)],
            pipeline=[dict(type='MockTransform', return_value=3)])
        for datum in cat_datasets:
            self.assertEqual(datum, 3)

        # force_apply=True: the ConcatDataset pipeline is accepted even
        # though dataset_a has its own, and every sample yields its value.
        cat_datasets = ConcatDataset(
            datasets=[
                deepcopy(self.dataset_a_with_pipeline),
                deepcopy(self.dataset_b)
            ],
            pipeline=[dict(type='MockTransform', return_value=3)],
            force_apply=True)
        for datum in cat_datasets:
            self.assertEqual(datum, 3)

        # No ConcatDataset pipeline: each sub-dataset's own pipeline applies.
        cat_datasets = ConcatDataset(datasets=[
            deepcopy(self.dataset_a_with_pipeline),
            deepcopy(self.dataset_b_with_pipeline)
        ])
        self.assertEqual(cat_datasets[0], 1)
        self.assertEqual(cat_datasets[-1], 2)