[Fix] Reduce unit test memory usage of T2T-ViT

pull/787/merge
mzr1996 2022-05-16 17:01:30 +08:00
parent 73c056b79f
commit b5193a9029
1 changed files with 12 additions and 12 deletions

View File

@ -89,7 +89,7 @@ class TestT2TViT(TestCase):
os.remove(checkpoint)
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
imgs = torch.randn(1, 3, 224, 224)
# test with_cls_token=False
cfg = deepcopy(self.cfg)
@ -106,7 +106,7 @@ class TestT2TViT(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
# test with output_cls_token
cfg = deepcopy(self.cfg)
@ -115,8 +115,8 @@ class TestT2TViT(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
self.assertEqual(cls_token.shape, (3, 384))
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
self.assertEqual(cls_token.shape, (1, 384))
# test without output_cls_token
cfg = deepcopy(self.cfg)
@ -126,7 +126,7 @@ class TestT2TViT(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
@ -137,13 +137,13 @@ class TestT2TViT(TestCase):
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token = out
self.assertEqual(patch_token.shape, (3, 384, 14, 14))
self.assertEqual(cls_token.shape, (3, 384))
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
self.assertEqual(cls_token.shape, (1, 384))
# Test forward with dynamic input size
imgs1 = torch.randn(3, 3, 224, 224)
imgs2 = torch.randn(3, 3, 256, 256)
imgs3 = torch.randn(3, 3, 256, 309)
imgs1 = torch.randn(1, 3, 224, 224)
imgs2 = torch.randn(1, 3, 256, 256)
imgs3 = torch.randn(1, 3, 256, 309)
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
@ -153,5 +153,5 @@ class TestT2TViT(TestCase):
patch_token, cls_token = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape))
self.assertEqual(cls_token.shape, (3, 384))
self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape))
self.assertEqual(cls_token.shape, (1, 384))