[Fix] Fix the unit test of correlation op (#1659)

pull/1666/head
Zaida Zhou 2022-01-13 14:27:50 +08:00 committed by GitHub
parent ccdc61c087
commit 227037fcd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 4 deletions

View File

@ -30,10 +30,13 @@ class TestCorrelation:
out = layer(input1, input2)
out.backward(torch.ones_like(out))
gt_out = torch.tensor(_gt_out, dtype=dtype)
assert_equal_tensor(out.cpu(), gt_out)
assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())
# `eq_cpu` is not implemented for 'Half' in torch1.5.0,
# so we need to make a comparison for cuda tensor
# rather than cpu tensor
gt_out = torch.tensor(_gt_out, dtype=dtype).cuda()
assert_equal_tensor(out, gt_out)
assert_equal_tensor(input1.grad.detach(), input2)
assert_equal_tensor(input2.grad.detach(), input1)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')