mirror of https://github.com/open-mmlab/mmcv.git
107 lines
4.4 KiB
Python
107 lines
4.4 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
# Prefer the gradcheck shipped with parrots; fall back to the PyTorch one
# when parrots cannot be imported. ``_USING_PARROTS`` records which of the
# two implementations ended up bound to ``gradcheck``.
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck

    _USING_PARROTS = False
else:
    _USING_PARROTS = True
|
|
|
|
# Absolute path of the directory that contains this test module.
# NOTE(review): not referenced anywhere in the visible part of this file.
_this_file = os.path.abspath(__file__)
cur_dir = os.path.dirname(_this_file)
|
|
|
|
# Input fixture: a nested list of shape (2, 4, 3, 2).
inputs = [
    [
        [[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
        [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
        [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
        [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]],
    ],
    [
        [[-0.5897, 0.7544], [1.0593, 0.8388], [-0.5732, 0.5692]],
        [[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]],
        [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
        [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
    ],
]

# Two shift fixtures, each a (2, 4) nested list of integers. The i-th shift
# is paired with ``outputs[i]`` and ``grads[i]`` below.
shifts = [
    [[1, 0, 1, -2], [-2, 1, -1, 1]],
    [[2, 1, 2, -1], [-1, 2, 0, 2]],
]

# Expected forward results, one (2, 4, 3, 2) entry per shift fixture.
outputs = [
    # Expected result for shifts[0]: the first half equals inputs[0]; the
    # second half is inputs[1][2:] followed by two all-zero frames.
    [
        [
            [[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
            [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
            [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
            [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]],
        ],
        [
            [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
            [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
            [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
            [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
        ],
    ],
    # Expected result for shifts[1]: the first half equals inputs[0]; the
    # second half is inputs[1][1:] followed by one all-zero frame.
    [
        [
            [[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
            [[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
            [[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
            [[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]],
        ],
        [
            [[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]],
            [[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
            [[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
            [[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
        ],
    ],
]

# Expected input gradients when an all-ones tensor is back-propagated
# through the op, one (2, 4, 3, 2) entry per shift fixture.
grads = [
    # Expected gradient for shifts[0].
    [
        [
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
        ],
        [
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[0., 0.], [0., 0.], [0., 0.]],
            [[0., 0.], [0., 0.], [0., 0.]],
        ],
    ],
    # Expected gradient for shifts[1].
    [
        [
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
        ],
        [
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[1., 1.], [1., 1.], [1., 1.]],
            [[0., 0.], [0., 0.], [0., 0.]],
        ],
    ],
]
|
|
|
|
|
|
def _test_tinshift_gradcheck(dtype):
    """Run ``gradcheck`` on ``tin_shift`` for every entry of ``shifts``.

    The test is skipped when ``mmcv.ops.tin_shift`` cannot be imported and
    when ``dtype`` is ``torch.half``.

    Args:
        dtype (torch.dtype): Data type of the CUDA input tensor.
    """
    try:
        from mmcv.ops import tin_shift
    except ModuleNotFoundError:
        pytest.skip('TinShift op is not successfully compiled')

    if dtype == torch.half:
        pytest.skip('"add_cpu/sub_cpu" not implemented for Half')

    # The input fixture is the same for every shift, so convert it once.
    np_input = np.array(inputs)

    for shift_values in shifts:
        np_shift = np.array(shift_values)

        x = torch.tensor(
            np_input, dtype=dtype, device='cuda', requires_grad=True)
        shift = torch.tensor(np_shift, device='cuda').int()

        # Pick the call signature that matches the gradcheck implementation
        # actually imported at module level, as recorded in _USING_PARROTS.
        if _USING_PARROTS:
            gradcheck(tin_shift, (x, shift))
        else:
            gradcheck(tin_shift, (x, shift), atol=1, rtol=0.1)
|
|
|
|
|
|
def _test_tinshift_allclose(dtype):
    """Compare ``tin_shift`` forward and backward results with the fixtures.

    For each ``(shift, output, grad)`` triple taken from ``shifts``,
    ``outputs`` and ``grads``, the op is applied to ``inputs`` on CUDA, an
    all-ones gradient is back-propagated, and both the result and the input
    gradient must match the fixtures within a relative tolerance of 1e-3.

    The test is skipped when ``mmcv.ops.tin_shift`` cannot be imported.

    Args:
        dtype (torch.dtype): Data type of the CUDA input tensor.
    """
    try:
        from mmcv.ops import tin_shift
    except ModuleNotFoundError:
        pytest.skip('TinShift op is not successfully compiled')

    def _as_float_array(tensor):
        # Cast to torch.float and move to the CPU so numpy can compare it.
        return tensor.data.type(torch.float).cpu().numpy()

    for shift_values, expected_output, expected_grad in zip(
            shifts, outputs, grads):
        np_output = np.array(expected_output)
        np_grad = np.array(expected_grad)

        x = torch.tensor(
            np.array(inputs),
            dtype=dtype,
            device='cuda',
            requires_grad=True)
        shift = torch.tensor(np.array(shift_values), device='cuda').int()

        result = tin_shift(x, shift)
        result.backward(torch.ones_like(result))

        assert np.allclose(_as_float_array(result), np_output, rtol=1e-3)
        assert np.allclose(_as_float_array(x.grad), np_grad, rtol=1e-3)
|
|
|
|
|
|
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
def test_tinshift(dtype):
    """Check ``tin_shift`` values first, then its gradients, for ``dtype``.

    Skipped entirely when CUDA is not available.
    """
    checks = (_test_tinshift_allclose, _test_tinshift_gradcheck)
    for check in checks:
        check(dtype=dtype)
|