[Fix] Reduce unit test memory usage of T2T-ViT
parent
73c056b79f
commit
b5193a9029
|
@ -89,7 +89,7 @@ class TestT2TViT(TestCase):
|
|||
os.remove(checkpoint)
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# test with_cls_token=False
|
||||
cfg = deepcopy(self.cfg)
|
||||
|
@ -106,7 +106,7 @@ class TestT2TViT(TestCase):
|
|||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
|
||||
|
||||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
|
@ -115,8 +115,8 @@ class TestT2TViT(TestCase):
|
|||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token, cls_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (1, 384))
|
||||
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
|
@ -126,7 +126,7 @@ class TestT2TViT(TestCase):
|
|||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
patch_token = outs[-1]
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
|
@ -137,13 +137,13 @@ class TestT2TViT(TestCase):
|
|||
self.assertEqual(len(outs), 3)
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
|
||||
self.assertEqual(cls_token.shape, (1, 384))
|
||||
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(3, 3, 224, 224)
|
||||
imgs2 = torch.randn(3, 3, 256, 256)
|
||||
imgs3 = torch.randn(3, 3, 256, 309)
|
||||
imgs1 = torch.randn(1, 3, 224, 224)
|
||||
imgs2 = torch.randn(1, 3, 256, 256)
|
||||
imgs3 = torch.randn(1, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
|
@ -153,5 +153,5 @@ class TestT2TViT(TestCase):
|
|||
patch_token, cls_token = outs[-1]
|
||||
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
|
||||
math.ceil(imgs.shape[3] / 16))
|
||||
self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (3, 384))
|
||||
self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape))
|
||||
self.assertEqual(cls_token.shape, (1, 384))
|
||||
|
|
Loading…
Reference in New Issue