diff --git a/tests/test_models/test_backbones/test_t2t_vit.py b/tests/test_models/test_backbones/test_t2t_vit.py index cc7e839c..a7e6861c 100644 --- a/tests/test_models/test_backbones/test_t2t_vit.py +++ b/tests/test_models/test_backbones/test_t2t_vit.py @@ -89,7 +89,7 @@ class TestT2TViT(TestCase): os.remove(checkpoint) def test_forward(self): - imgs = torch.randn(3, 3, 224, 224) + imgs = torch.randn(1, 3, 224, 224) # test with_cls_token=False cfg = deepcopy(self.cfg) @@ -106,7 +106,7 @@ class TestT2TViT(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) patch_token = outs[-1] - self.assertEqual(patch_token.shape, (3, 384, 14, 14)) + self.assertEqual(patch_token.shape, (1, 384, 14, 14)) # test with output_cls_token cfg = deepcopy(self.cfg) @@ -115,8 +115,8 @@ class TestT2TViT(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) patch_token, cls_token = outs[-1] - self.assertEqual(patch_token.shape, (3, 384, 14, 14)) - self.assertEqual(cls_token.shape, (3, 384)) + self.assertEqual(patch_token.shape, (1, 384, 14, 14)) + self.assertEqual(cls_token.shape, (1, 384)) # test without output_cls_token cfg = deepcopy(self.cfg) @@ -126,7 +126,7 @@ class TestT2TViT(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) patch_token = outs[-1] - self.assertEqual(patch_token.shape, (3, 384, 14, 14)) + self.assertEqual(patch_token.shape, (1, 384, 14, 14)) # Test forward with multi out indices cfg = deepcopy(self.cfg) @@ -137,13 +137,13 @@ class TestT2TViT(TestCase): self.assertEqual(len(outs), 3) for out in outs: patch_token, cls_token = out - self.assertEqual(patch_token.shape, (3, 384, 14, 14)) - self.assertEqual(cls_token.shape, (3, 384)) + self.assertEqual(patch_token.shape, (1, 384, 14, 14)) + self.assertEqual(cls_token.shape, (1, 384)) # Test forward with dynamic input size - imgs1 = torch.randn(3, 3, 224, 224) - imgs2 = torch.randn(3, 3, 256, 256) - imgs3 = torch.randn(3, 3, 256, 309) + imgs1 = torch.randn(1, 3, 224, 224) + imgs2 = torch.randn(1, 3, 256, 256) + imgs3 = torch.randn(1, 3, 256, 309) cfg = deepcopy(self.cfg) model = T2T_ViT(**cfg) for imgs in [imgs1, imgs2, imgs3]: @@ -153,5 +153,5 @@ class TestT2TViT(TestCase): patch_token, cls_token = outs[-1] expect_feat_shape = (math.ceil(imgs.shape[2] / 16), math.ceil(imgs.shape[3] / 16)) - self.assertEqual(patch_token.shape, (3, 384, *expect_feat_shape)) - self.assertEqual(cls_token.shape, (3, 384)) + self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape)) + self.assertEqual(cls_token.shape, (1, 384))