From 227037fcd35e52703b9e436b57c857f0c1ecc446 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Thu, 13 Jan 2022 14:27:50 +0800 Subject: [PATCH] [Fix] Fix the unit test of correlation op (#1659) --- tests/test_ops/test_correlation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_ops/test_correlation.py b/tests/test_ops/test_correlation.py index 6b75a9f38..6cf5f9f72 100644 --- a/tests/test_ops/test_correlation.py +++ b/tests/test_ops/test_correlation.py @@ -30,10 +30,13 @@ class TestCorrelation: out = layer(input1, input2) out.backward(torch.ones_like(out)) - gt_out = torch.tensor(_gt_out, dtype=dtype) - assert_equal_tensor(out.cpu(), gt_out) - assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu()) - assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu()) + # `eq_cpu` is not implemented for 'Half' in torch1.5.0, + # so we need to make a comparison for cuda tensor + # rather than cpu tensor + gt_out = torch.tensor(_gt_out, dtype=dtype).cuda() + assert_equal_tensor(out, gt_out) + assert_equal_tensor(input1.grad.detach(), input2) + assert_equal_tensor(input2.grad.detach(), input1) @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support')