# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import unittest

import numpy as np
from tests.ut_config import DET_DATA_RAW_LOCAL

from easycv.datasets.detection.data_sources.raw import DetSourceRaw
from easycv.file import io


class DetSourceRawTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        data_root = DET_DATA_RAW_LOCAL
        for cache_file in io.glob(os.path.join(data_root, 'labels/*.cache')):
            io.remove(cache_file)

    def _base_test(self, data_source):
        index_list = random.choices(list(range(20)), k=3)
        for idx in index_list:
            data = data_source.get_sample(idx)
            self.assertIn('img_shape', data)
            self.assertIn('ori_img_shape', data)
            self.assertEqual(len(data['img_shape']), 3)
            self.assertEqual(data['img_fields'], ['img'])
            self.assertEqual(data['bbox_fields'], ['gt_bboxes'])
            self.assertEqual(data['gt_bboxes'].shape[-1], 4)
            self.assertGreaterEqual(len(data['gt_labels']), 1)
            self.assertIn('filename', data)
            self.assertEqual(data['img'].shape[-1], 3)

        exclude_idx = [i for i in list(range(20)) if i not in index_list]
        for i in range(3):

            if data_source.cache_at_init:
                self.assertIn('img', data_source.samples_list[exclude_idx[i]])
            if data_source.cache_on_the_fly:
                self.assertNotIn('img',
                                 data_source.samples_list[exclude_idx[i]])

        length = data_source.get_length()
        self.assertEqual(length, 128)

        exists = False
        for idx in range(length):
            result = data_source.get_sample(idx)
            file_name = result.get('filename', '')
            if file_name.endswith('000000000086.jpg'):
                exists = True
                self.assertEqual(result['img_shape'], (640, 512, 3))
                self.assertEqual(result['gt_labels'].tolist(),
                                 np.array([0, 3, 26], dtype=np.int32).tolist())
                self.assertEqual(
                    result['gt_bboxes'].astype(np.int32).tolist(),
                    np.array([[152.11008, 183.09982, 281.26004, 559.0697],
                              [130.13991, 346.51968, 424.40015, 635.15967],
                              [265.10004, 294.5901, 399.7002, 404.7699]],
                             dtype=np.int32).tolist())
        self.assertTrue(exists)

    def test_default(self):
        data_root = DET_DATA_RAW_LOCAL
        data_source = DetSourceRaw(
            img_root_path=os.path.join(data_root, 'images/train2017'),
            label_root_path=os.path.join(data_root, 'labels/train2017'),
        )
        self._base_test(data_source)

    def test_cache_on_the_fly(self):
        data_root = DET_DATA_RAW_LOCAL
        data_source = DetSourceRaw(
            img_root_path=os.path.join(data_root, 'images/train2017'),
            label_root_path=os.path.join(data_root, 'labels/train2017'),
            cache_on_the_fly=True,
            cache_at_init=False,
        )
        self._base_test(data_source)

    def test_cache_at_init(self):
        data_root = DET_DATA_RAW_LOCAL
        data_source = DetSourceRaw(
            img_root_path=os.path.join(data_root, 'images/train2017'),
            label_root_path=os.path.join(data_root, 'labels/train2017'),
            cache_on_the_fly=False,
            cache_at_init=True,
        )
        self._base_test(data_source)


if __name__ == '__main__':
    unittest.main()