mirror of https://github.com/RE-OWOD/RE-OWOD
60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
|
|
import unittest
|
|
from typing import List, Sequence, Tuple
|
|
import torch
|
|
|
|
from detectron2.structures import ImageList
|
|
from detectron2.utils.env import TORCH_VERSION
|
|
|
|
|
|
class TestImageList(unittest.TestCase):
    """Checks that ``ImageList`` works under ``torch.jit.trace`` and ``torch.jit.script``."""

    def test_imagelist_padding_shape(self):
        """A traced ``ImageList.from_tensors`` must size its output from the runtime inputs.

        The module is traced with one set of image sizes and then called with
        different sizes; the asserted shapes correspond to the call-time
        inputs (heights rounded up to a multiple of 4), not to the sizes seen
        while tracing.
        """

        class TensorToImageList(torch.nn.Module):
            def forward(self, tensors: Sequence[torch.Tensor]):
                image_list = ImageList.from_tensors(tensors, 4)
                return image_list.tensor

        # Case 1: a single image. Trace at 10x10, run at 15x20.
        single_trace_inputs = [torch.ones((3, 10, 10), dtype=torch.float32)]
        traced_single = torch.jit.trace(TensorToImageList(), (single_trace_inputs,))
        single_out = traced_single([torch.ones((3, 15, 20), dtype=torch.float32)])
        self.assertEqual(list(single_out.shape), [1, 3, 16, 20], str(single_out.shape))

        # Case 2: two images of different sizes. Trace with one pair, run with another.
        pair_trace_inputs = [
            torch.ones((3, 16, 10), dtype=torch.float32),
            torch.ones((3, 13, 11), dtype=torch.float32),
        ]
        traced_pair = torch.jit.trace(TensorToImageList(), (pair_trace_inputs,))
        pair_run_inputs = [
            torch.ones((3, 25, 20), dtype=torch.float32),
            torch.ones((3, 10, 10), dtype=torch.float32),
        ]
        pair_out = traced_pair(pair_run_inputs)
        # NOTE: the traced module does not support calling with a different
        # number of images, so it is invoked with two images, as when traced.
        self.assertEqual(list(pair_out.shape), [2, 3, 28, 20], str(pair_out.shape))

    @unittest.skipIf(TORCH_VERSION < (1, 6), "Insufficient pytorch version")
    def test_imagelist_scriptability(self):
        """Constructing an ``ImageList`` inside a scripted function matches eager mode."""
        num_images = 2
        batch = torch.randn((num_images, 10, 20), dtype=torch.float32)
        sizes = [(10, 20) for _ in range(num_images)]

        def make_image_list(tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
            return ImageList(tensor, image_sizes)

        eager_result = make_image_list(batch, sizes)
        scripted_fn = torch.jit.script(make_image_list)
        scripted_result = scripted_fn(batch, sizes)

        # Both results must hold the same number of images ...
        self.assertEqual(len(eager_result), len(scripted_result))
        # ... and each indexed image must be element-wise identical.
        for idx in range(num_images):
            eager_image = eager_result[idx]
            scripted_image = scripted_result[idx]
            self.assertTrue(torch.equal(eager_image, scripted_image))
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|