# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest
from typing import List, Sequence, Tuple
import torch

from detectron2.structures import ImageList
from detectron2.utils.env import TORCH_VERSION


class TestImageList(unittest.TestCase):
    """Checks ImageList padding under torch.jit.trace and ImageList scriptability."""

    def test_imagelist_padding_shape(self):
        """A traced ``ImageList.from_tensors(..., 4)`` pads using call-time sizes.

        Each case traces the module on one set of image sizes and then calls the
        traced module with different sizes. The output batch shape must reflect
        the call-time inputs rounded up to multiples of 4, not the trace-time ones.
        """

        class TensorToImageList(torch.nn.Module):
            def forward(self, tensors: Sequence[torch.Tensor]):
                # The second argument (4) is the padding divisor: per the expected
                # shapes below, padded H and W are rounded up to multiples of 4.
                return ImageList.from_tensors(tensors, 4).tensor

        def make_images(sizes):
            # One 3-channel float32 tensor of ones per (H, W) pair, in order.
            return [torch.ones((3, h, w), dtype=torch.float32) for h, w in sizes]

        # Each entry is (trace-time sizes, call-time sizes, expected batch shape).
        # A traced function does not support being called with a different number
        # of images, so both sides of each case use the same image count.
        cases = [
            ([(10, 10)], [(15, 20)], [1, 3, 16, 20]),
            ([(16, 10), (13, 11)], [(25, 20), (10, 10)], [2, 3, 28, 20]),
        ]
        for trace_sizes, call_sizes, expected in cases:
            traced = torch.jit.trace(TensorToImageList(), (make_images(trace_sizes),))
            padded = traced(make_images(call_sizes))
            self.assertEqual(list(padded.shape), expected, str(padded.shape))

    @unittest.skipIf(TORCH_VERSION < (1, 6), "Insufficient pytorch version")
    def test_imagelist_scriptability(self):
        """An ImageList built by a scripted function matches one built eagerly."""
        num_images = 2
        batch = torch.randn((num_images, 10, 20), dtype=torch.float32)
        sizes = [(10, 20) for _ in range(num_images)]

        # TorchScript needs the explicit List[Tuple[int, int]] annotation here;
        # the unannotated tensor argument defaults to Tensor when scripted.
        def build(image_tensor, image_shape: List[Tuple[int, int]]):
            return ImageList(image_tensor, image_shape)

        eager = build(batch, sizes)
        scripted = torch.jit.script(build)(batch, sizes)

        self.assertEqual(len(eager), len(scripted))
        # Compare image-by-image through ImageList's indexing on both objects.
        for idx in range(num_images):
            self.assertTrue(torch.equal(eager[idx], scripted[idx]))


if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()