mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix the unit test of correlation op (#1659)
parent
ccdc61c087
commit
227037fcd3
|
@ -30,10 +30,13 @@ class TestCorrelation:
|
|||
out = layer(input1, input2)
|
||||
out.backward(torch.ones_like(out))
|
||||
|
||||
gt_out = torch.tensor(_gt_out, dtype=dtype)
|
||||
assert_equal_tensor(out.cpu(), gt_out)
|
||||
assert_equal_tensor(input1.grad.detach().cpu(), input2.cpu())
|
||||
assert_equal_tensor(input2.grad.detach().cpu(), input1.cpu())
|
||||
# `eq_cpu` is not implemented for 'Half' in torch1.5.0,
|
||||
# so we need to make a comparison for cuda tensor
|
||||
# rather than cpu tensor
|
||||
gt_out = torch.tensor(_gt_out, dtype=dtype).cuda()
|
||||
assert_equal_tensor(out, gt_out)
|
||||
assert_equal_tensor(input1.grad.detach(), input2)
|
||||
assert_equal_tensor(input2.grad.detach(), input1)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
|
|
Loading…
Reference in New Issue