mirror of https://github.com/FoundationVision/GLEE
46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
import unittest
|
|
import torch
|
|
|
|
from detectron2.modeling.meta_arch import GeneralizedRCNN
|
|
from detectron2.utils.registry import _convert_target_to_string, locate
|
|
|
|
|
|
class A:
|
|
class B:
|
|
pass
|
|
|
|
|
|
class TestLocate(unittest.TestCase):
|
|
def _test_obj(self, obj):
|
|
name = _convert_target_to_string(obj)
|
|
newobj = locate(name)
|
|
self.assertIs(obj, newobj)
|
|
|
|
def test_basic(self):
|
|
self._test_obj(GeneralizedRCNN)
|
|
|
|
def test_inside_class(self):
|
|
# requires using __qualname__ instead of __name__
|
|
self._test_obj(A.B)
|
|
|
|
def test_builtin(self):
|
|
self._test_obj(len)
|
|
self._test_obj(dict)
|
|
|
|
def test_pytorch_optim(self):
|
|
# pydoc.locate does not work for it
|
|
self._test_obj(torch.optim.SGD)
|
|
|
|
def test_failure(self):
|
|
with self.assertRaises(ImportError):
|
|
locate("asdf")
|
|
|
|
def test_compress_target(self):
|
|
from detectron2.data.transforms import RandomCrop
|
|
|
|
name = _convert_target_to_string(RandomCrop)
|
|
# name shouldn't contain 'augmentation_impl'
|
|
self.assertEqual(name, "detectron2.data.transforms.RandomCrop")
|
|
self.assertIs(RandomCrop, locate(name))
|