# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest

import numpy as np
from mmengine.data import BaseDataElement

from mmseg.data import SegDataSample
from mmseg.datasets.transforms import PackSegInputs


class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` transform."""

    def setUp(self):
        """Build the fake pipeline results shared by every test method.

        ``unittest`` calls this hook before each test method, so every test
        starts from a freshly generated ``self.results`` dict and the same
        ``self.meta_keys`` tuple.
        """
        self.meta_keys = ('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                          'scale_factor', 'flip', 'flip_direction')

        sample_img = osp.join(
            osp.dirname(__file__), '../../data', 'color.jpg')
        # A seeded generator keeps the fake arrays reproducible. The image is
        # drawn before the segmentation map, so the draw order is fixed.
        generator = np.random.RandomState(0)
        fake_img = generator.rand(300, 400)
        fake_seg_map = generator.rand(300, 400)
        self.results = dict(
            img_path=sample_img,
            ori_shape=(300, 400),
            pad_shape=(600, 800),
            img_shape=(600, 800),
            scale_factor=2.0,
            flip=False,
            flip_direction='horizontal',
            img_norm_cfg=None,
            img=fake_img,
            gt_seg_map=fake_seg_map,
        )

    def test_transform(self):
        """Check the packed output exposes a well-formed ``data_sample``.

        The packed ``data_sample`` must be a ``SegDataSample`` whose
        ``gt_sem_seg`` is a ``BaseDataElement`` with the same shape as the
        sample's ``ori_shape``.
        """
        packer = PackSegInputs(meta_keys=self.meta_keys)
        # Hand the transform a deep copy rather than the fixture itself.
        packed = packer(copy.deepcopy(self.results))

        self.assertIn('data_sample', packed)
        sample = packed['data_sample']
        self.assertIsInstance(sample, SegDataSample)

        sem_seg = sample.gt_sem_seg
        self.assertIsInstance(sem_seg, BaseDataElement)
        self.assertEqual(sample.ori_shape, sem_seg.shape)

    def test_repr(self):
        """Check ``repr`` of the transform reports the configured meta keys."""
        packer = PackSegInputs(meta_keys=self.meta_keys)
        expected = f'PackSegInputs(meta_keys={self.meta_keys})'
        self.assertEqual(repr(packer), expected)