# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

_USING_PARROTS = True
_IS_AUTOCAST_AVAILABLE = True
try:
    from parrots.autograd import gradcheck
except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False

try:
    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
    # would be imported and used; we should test if our modules support it.
    from torch.cuda.amp import autocast
except ImportError:
    _IS_AUTOCAST_AVAILABLE = False


@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param(
        'cuda:0',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_multiscale_deformable_attention(device):
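    """Smoke-test the MultiScaleDeformableAttention module on each device:
    an ``embed_dims`` that is not divisible by ``num_heads`` must raise, and
    the forward pass must run with and without ``value_proj_ratio``."""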
    with pytest.raises(ValueError):
        # embed_dims must be divisible by num_heads,
        MultiScaleDeformableAttention(
            embed_dims=256,
            num_heads=7,
        )
    device = torch.device(device)
    msda = MultiScaleDeformableAttention(
        embed_dims=3, num_levels=2, num_heads=3)
    msda.init_weights()
    num_query = 5
    bs = 1
    embed_dims = 3
    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    spatial_shapes = torch.Tensor([[2, 2], [1, 1]]).long().to(device)
    level_start_index = torch.Tensor([0, 4]).long().to(device)
    reference_points = torch.rand(bs, num_query, 2, 2).to(device)
    msda.to(device)
    msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)

    # test with value_proj_ratio
    embed_dims = 6
    value_proj_ratio = 0.5
    query = torch.rand(num_query, bs, embed_dims).to(device)
    key = torch.rand(num_query, bs, embed_dims).to(device)
    msda = MultiScaleDeformableAttention(
        embed_dims=embed_dims,
        num_levels=2,
        num_heads=3,
        value_proj_ratio=value_proj_ratio)
    msda.init_weights()
    msda.to(device)
    msda(
        query,
        key,
        key,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)


def test_forward_multi_scale_deformable_attn_pytorch():
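    """Run the pure-PyTorch reference implementation on a small random input
    as a CPU smoke test."""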
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)

    multi_scale_deformable_attn_pytorch(value.double(), shapes,
                                        sampling_locations.double(),
                                        attention_weights.double()).detach()


@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
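    """Check that the CUDA kernel matches the PyTorch reference in float64
    within tight absolute and relative tolerances."""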
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.double(), shapes, sampling_locations.double(),
        attention_weights.double()).detach().cpu()

    output_cuda = MultiScaleDeformableAttnFunction.apply(
        value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
        sampling_locations.cuda().double(),
        attention_weights.cuda().double(), im2col_step).detach().cpu()
    assert torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_forward_equal_with_pytorch_npu():
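    """Check that the NPU kernel matches the PyTorch reference in float32 on
    a larger multi-level input."""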
    N, M, D = 6, 4, 8
    Lq, L, P = 10000, 4, 8
    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                             dtype=torch.int32)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.float(), shapes, sampling_locations.float(),
        attention_weights.float()).detach().cpu()

    output_npu = MultiScaleDeformableAttnFunction.apply(
        value.npu().float(), shapes.npu(), level_start_index.npu(),
        sampling_locations.npu().float(),
        attention_weights.npu().float(), im2col_step).detach().cpu()
    assert torch.allclose(output_npu, output_pytorch)
    max_abs_err = (output_npu - output_pytorch).abs().max()
    max_rel_err = ((output_npu - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_forward_equal_with_pytorch_float(device):
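    """Check that the CUDA/MLU kernel matches the PyTorch reference in
    float32."""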
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    output_device = MultiScaleDeformableAttnFunction.apply(
        value.to(device), shapes.to(device), level_start_index.to(device),
        sampling_locations.to(device), attention_weights.to(device),
        im2col_step).detach().cpu()
    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_device - output_pytorch).abs().max()
    max_rel_err = ((output_device - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6


@pytest.mark.skipif(
    not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_autocast():
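    """Run the CUDA kernel under ``torch.cuda.amp.autocast`` in float32 and
    float16 and compare against the PyTorch reference."""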
    N, M, D = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value, shapes, sampling_locations, attention_weights).detach().cpu()

    # float test
    dtype = torch.float
    with autocast(enabled=True):
        output_device = MultiScaleDeformableAttnFunction.apply(
            value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
            sampling_locations.cuda(), attention_weights.cuda(),
            im2col_step).detach().cpu()
    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_device - output_pytorch).abs().max()
    max_rel_err = ((output_device - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-9
    assert max_rel_err < 1e-6

    # half test
    dtype = torch.half
    with autocast(enabled=True):
        output_device = MultiScaleDeformableAttnFunction.apply(
            value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
            sampling_locations.cuda(), attention_weights.cuda(),
            im2col_step).detach().cpu()
    assert torch.allclose(
        output_device, output_pytorch.half(), rtol=1e-2, atol=1e-3)
    max_abs_err = (output_device - output_pytorch).abs().max()
    max_rel_err = ((output_device - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-5
    assert max_rel_err < 1e-2


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
    torch.float,
    pytest.param(
        torch.double,
        marks=pytest.mark.skipif(
            IS_MLU_AVAILABLE,
            reason='MLU does not support 64-bit floating point')),
    torch.half
])
@pytest.mark.parametrize('channels', [
    4,
    30,
    32,
    64,
    71,
    1025,
])
def test_gradient_numerical(channels,
                            device,
                            dtype,
                            grad_value=True,
                            grad_sampling_loc=True,
                            grad_attn_weight=True):
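    """Check the backward pass of the fused kernel with ``gradcheck`` for
    several channel sizes on CUDA/MLU."""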
    N, M, _ = 1, 2, 2
    Lq, L, P = 2, 2, 2
    shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    value = torch.rand(N, S, M, channels).to(device) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
    attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    func = MultiScaleDeformableAttnFunction.apply

    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    if device == 'cuda':
        dtype = torch.double
        eps = 1e-6
    elif device == 'mlu':
        dtype = torch.float
        eps = 1e-4
    if _USING_PARROTS:
        assert gradcheck(
            func, (value.to(dtype), shapes, level_start_index,
                   sampling_locations.to(dtype), attention_weights.to(dtype),
                   im2col_step),
            no_grads=[shapes, level_start_index],
            eps=eps)
    else:
        assert gradcheck(
            func, (value.to(dtype), shapes, level_start_index,
                   sampling_locations.to(dtype), attention_weights.to(dtype),
                   im2col_step),
            eps=eps,
            atol=1e-2)


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_backward_equal_with_pytorch_npu():
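    """Compare the gradients produced by the NPU kernel with those of the
    PyTorch reference implementation."""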
    N, M, D = 6, 4, 8
    Lq, L, P = 10000, 4, 8
    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                             dtype=torch.int32)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2

    value.requires_grad = True
    sampling_locations.requires_grad = True
    attention_weights.requires_grad = True
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.float(), shapes, sampling_locations.float(),
        attention_weights.float())
    grad_output_pytorch = torch.ones_like(output_pytorch)
    output_pytorch.backward(grad_output_pytorch)
    grad_value = value.grad.detach().cpu()
    grad_location = sampling_locations.grad.detach().cpu()
    grad_attn_weight = attention_weights.grad.detach().cpu()

    # Detach before moving to the NPU so the device tensors are graph leaves
    # and accumulate their own gradients.
    value_npu = value.detach().npu()
    value_npu.requires_grad = True
    shapes_npu = shapes.npu()
    level_start_index_npu = level_start_index.npu()
    sampling_locations_npu = sampling_locations.detach().npu()
    sampling_locations_npu.requires_grad = True
    attention_weights_npu = attention_weights.detach().npu()
    attention_weights_npu.requires_grad = True
    output_npu = MultiScaleDeformableAttnFunction.apply(
        value_npu.float(), shapes_npu, level_start_index_npu,
        sampling_locations_npu.float(), attention_weights_npu.float(),
        im2col_step)
    grad_output_npu = torch.ones_like(output_npu)
    output_npu.backward(grad_output_npu)
    grad_value_npu = value_npu.grad.detach().cpu()
    grad_location_npu = sampling_locations_npu.grad.detach().cpu()
    grad_attn_weight_npu = attention_weights_npu.grad.detach().cpu()

    assert torch.allclose(grad_value_npu, grad_value)
    max_abs_err_1 = (grad_value_npu - grad_value).abs().max()
    max_rel_err_1 = ((grad_value_npu - grad_value).abs() /
                     grad_value.abs()).max()
    assert max_abs_err_1 < 1e-5
    assert max_rel_err_1 < 1e-4

    assert torch.allclose(grad_location_npu, grad_location)
    max_abs_err_2 = (grad_location_npu - grad_location).abs().max()
    max_rel_err_2 = ((grad_location_npu - grad_location).abs() /
                     grad_location.abs()).max()
    assert max_abs_err_2 < 1e-5
    assert max_rel_err_2 < 1e-4

    assert torch.allclose(grad_attn_weight_npu, grad_attn_weight)
    max_abs_err_3 = (grad_attn_weight_npu - grad_attn_weight).abs().max()
    max_rel_err_3 = ((grad_attn_weight_npu - grad_attn_weight).abs() /
                     grad_attn_weight.abs()).max()
    assert max_abs_err_3 < 1e-5
    assert max_rel_err_3 < 1e-4