[Enhancement] Support MultiScaleDeformableAttention with AMP ()

* [Enhance] Support FP16 for MSDeformAttn

* [Fix] Data type mismatch

* Update mmcv/ops/multi_scale_deform_attn.py

* Add UT

Author:    nijkah <nijkah@gmail.com>

* Add CUDA availability condition

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/2618/head
Hakjin Lee 2023-02-17 20:27:14 +09:00 committed by GitHub
parent cb94ffb672
commit 8e8ab22686
3 changed files with 75 additions and 2 deletions


@@ -255,7 +255,7 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
   auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
   for (int n = 0; n < batch / im2col_step_; ++n) {
     auto columns = output_n.select(0, n);
-    AT_DISPATCH_FLOATING_TYPES(
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
           ms_deformable_im2col_cuda(
               at::cuda::getCurrentCUDAStream(),
@@ -326,7 +326,7 @@ void ms_deform_attn_cuda_backward(
   for (int n = 0; n < batch / im2col_step_; ++n) {
     auto grad_output_g = grad_output_n.select(0, n);
-    AT_DISPATCH_FLOATING_TYPES(
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
           ms_deformable_col2im_cuda(
               at::cuda::getCurrentCUDAStream(),
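Because the dispatch now goes through `AT_DISPATCH_FLOATING_TYPES_AND_HALF`, the forward and backward kernels are also instantiated for `at::Half`, so the op accepts pure fp16 tensors rather than only float32. A rough sketch of a direct fp16 call (assumes a CUDA device and an mmcv build that includes this patch; the shapes mirror the unit test added below):

```python
import torch

from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

N, M, D = 1, 2, 2      # batch, num_heads, channels per head
Lq, L, P = 2, 2, 2     # num_query, num_levels, num_points
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long, device='cuda')
level_start_index = torch.cat(
    (shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)

# All inputs are created directly in fp16 (no autocast involved).
value = torch.rand(N, S, M, D, device='cuda').half()
sampling_locations = torch.rand(N, Lq, M, L, P, 2, device='cuda').half()
attention_weights = torch.rand(N, Lq, M, L, P, device='cuda').half()
attention_weights /= attention_weights.sum(
    -1, keepdim=True).sum(-2, keepdim=True)

# The float-only dispatch used to reject half tensors; with the _AND_HALF
# variant the fp16 kernel is selected instead.
output = MultiScaleDeformableAttnFunction.apply(
    value, shapes, level_start_index, sampling_locations, attention_weights,
    2)  # im2col_step
print(output.dtype)  # torch.float16
```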


@@ -50,6 +50,18 @@ class MultiScaleDeformableAttnFunction(Function):
         """
         ctx.im2col_step = im2col_step
+        # When the PyTorch version is >= 1.6.0, amp is adopted for fp16 mode.
+        # amp does not cast sampling_locations and attention_weights (they
+        # stay float32), while ``value`` is cast to float16, which leads to
+        # a dtype mismatch inside the extension op.
+        # Since the dtype of ``value`` signals whether plain fp16 or amp is
+        # in use, we cast sampling_locations and attention_weights to
+        # value's dtype so that both fp16 and amp are supported whatever
+        # the PyTorch version is.
+        sampling_locations = sampling_locations.type_as(value)
+        attention_weights = attention_weights.type_as(value)
         output = ext_module.ms_deform_attn_forward(
             value,
             value_spatial_shapes,
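To make the comment above concrete: under autocast, linear/matmul outputs are cast to float16 while ops such as softmax stay in float32, so a custom autograd Function can receive tensors of mixed dtypes. A minimal standalone illustration (assumes a CUDA device; the shapes are arbitrary):

```python
import torch

with torch.cuda.amp.autocast(enabled=True):
    x = torch.rand(2, 8, device='cuda')
    # Linear layers are autocast to fp16...
    value = torch.nn.Linear(8, 8).cuda()(x)
    # ...while softmax keeps float32.
    attention_weights = torch.rand(2, 8, device='cuda').softmax(-1)

print(value.dtype, attention_weights.dtype)  # torch.float16 torch.float32

# Aligning the auxiliary tensors with ``value`` before calling the extension
# avoids the mismatch, regardless of whether fp16 or amp is in use.
attention_weights = attention_weights.type_as(value)
```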


@@ -8,12 +8,21 @@ from mmcv.ops.multi_scale_deform_attn import (
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

 _USING_PARROTS = True
+_IS_AUTOCAST_AVAILABLE = True
 try:
     from parrots.autograd import gradcheck
 except ImportError:
     from torch.autograd import gradcheck
     _USING_PARROTS = False

+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # would be imported and used; we should test if our modules support it.
+    from torch.cuda.amp import autocast
+except ImportError:
+    _IS_AUTOCAST_AVAILABLE = False
+    pass
+

 @pytest.mark.parametrize('device', [
     'cpu',
@@ -148,6 +157,58 @@ def test_forward_equal_with_pytorch_float(device):
     assert max_rel_err < 1e-6


+@pytest.mark.skipif(
+    not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
+@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
+def test_forward_equal_with_autocast():
+    N, M, D = 1, 2, 2
+    Lq, L, P = 2, 2, 2
+    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
+    level_start_index = torch.cat((shapes.new_zeros(
+        (1, )), shapes.prod(1).cumsum(0)[:-1]))
+    S = sum((H * W).item() for H, W in shapes)
+
+    torch.manual_seed(3)
+    value = torch.rand(N, S, M, D) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+    attention_weights /= attention_weights.sum(
+        -1, keepdim=True).sum(
+            -2, keepdim=True)
+    im2col_step = 2
+    output_pytorch = multi_scale_deformable_attn_pytorch(
+        value, shapes, sampling_locations, attention_weights).detach().cpu()
+
+    # float test
+    dtype = torch.float
+    with autocast(enabled=True):
+        output_device = MultiScaleDeformableAttnFunction.apply(
+            value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
+            sampling_locations.cuda(), attention_weights.cuda(),
+            im2col_step).detach().cpu()
+    assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_device - output_pytorch).abs().max()
+    max_rel_err = ((output_device - output_pytorch).abs() /
+                   output_pytorch.abs()).max()
+    assert max_abs_err < 1e-9
+    assert max_rel_err < 1e-6
+
+    # half test
+    dtype = torch.half
+    with autocast(enabled=True):
+        output_device = MultiScaleDeformableAttnFunction.apply(
+            value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
+            sampling_locations.cuda(), attention_weights.cuda(),
+            im2col_step).detach().cpu()
+    assert torch.allclose(
+        output_device, output_pytorch.half(), rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_device - output_pytorch).abs().max()
+    max_rel_err = ((output_device - output_pytorch).abs() /
+                   output_pytorch.abs()).max()
+    assert max_abs_err < 1e-5
+    assert max_rel_err < 1e-2
+
+
 @pytest.mark.parametrize('device', [
     pytest.param(
         'cuda',
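Taken together, the `_AND_HALF` dispatch and the `type_as` casts let the full `MultiScaleDeformableAttention` module run inside an autocast region. A minimal end-to-end sketch (the module hyperparameters, `batch_first` layout, and shapes below are illustrative assumptions, not part of this patch; requires a CUDA build of mmcv):

```python
import torch
from mmcv.ops import MultiScaleDeformableAttention

embed_dims, num_heads, num_levels, num_points = 32, 4, 2, 2
attn = MultiScaleDeformableAttention(
    embed_dims=embed_dims, num_heads=num_heads, num_levels=num_levels,
    num_points=num_points, batch_first=True).cuda()

bs, num_query = 2, 5
spatial_shapes = torch.tensor([[6, 4], [3, 2]], device='cuda')
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
num_value = int(spatial_shapes.prod(1).sum())

query = torch.rand(bs, num_query, embed_dims, device='cuda')
value = torch.rand(bs, num_value, embed_dims, device='cuda')
reference_points = torch.rand(bs, num_query, num_levels, 2, device='cuda')

# Inside autocast the value projection runs in fp16 while the attention
# weights stay fp32; the casts added in the Function keep the extension
# call dtype-consistent.
with torch.cuda.amp.autocast(enabled=True):
    out = attn(
        query,
        value=value,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)
print(out.shape)  # torch.Size([2, 5, 32])
```

In a training loop this would typically be combined with `torch.cuda.amp.GradScaler`, the same as for any other autocast-covered module.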