mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Support MultiScaleDeformableAttention with AMP (#2541)
* [Enhance] Support FP16 for MSDeformAttn * [Fix] Data type mismatch * Update mmcv/ops/multi_scale_deform_attn.py * Add UT Author: nijkah <nijkah@gmail.com> * Add cuda available condition --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/2618/head
parent
cb94ffb672
commit
8e8ab22686
mmcv/ops
csrc/pytorch/cuda
tests/test_ops
|
@ -255,7 +255,7 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
|
|||
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
||||
for (int n = 0; n < batch / im2col_step_; ++n) {
|
||||
auto columns = output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||
ms_deformable_im2col_cuda(
|
||||
at::cuda::getCurrentCUDAStream(),
|
||||
|
@ -326,7 +326,7 @@ void ms_deform_attn_cuda_backward(
|
|||
|
||||
for (int n = 0; n < batch / im2col_step_; ++n) {
|
||||
auto grad_output_g = grad_output_n.select(0, n);
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||
ms_deformable_col2im_cuda(
|
||||
at::cuda::getCurrentCUDAStream(),
|
||||
|
|
|
@ -50,6 +50,18 @@ class MultiScaleDeformableAttnFunction(Function):
|
|||
"""
|
||||
|
||||
ctx.im2col_step = im2col_step
|
||||
|
||||
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
|
||||
# amp won't cast the type of sampling_locations, attention_weights
|
||||
# (float32), but "value" is cast to float16, leading to the type
|
||||
# mismatch with input (when it is float32) or weight.
|
||||
# The flag for whether to use fp16 or amp is the type of "value",
|
||||
# we cast sampling_locations and attention_weights to
|
||||
# temporarily support fp16 and amp whatever the
|
||||
# pytorch version is.
|
||||
sampling_locations = sampling_locations.type_as(value)
|
||||
attention_weights = attention_weights.type_as(value)
|
||||
|
||||
output = ext_module.ms_deform_attn_forward(
|
||||
value,
|
||||
value_spatial_shapes,
|
||||
|
|
|
@ -8,12 +8,21 @@ from mmcv.ops.multi_scale_deform_attn import (
|
|||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
_USING_PARROTS = True
|
||||
_IS_AUTOCAST_AVAILABLE = True
|
||||
try:
|
||||
from parrots.autograd import gradcheck
|
||||
except ImportError:
|
||||
from torch.autograd import gradcheck
|
||||
_USING_PARROTS = False
|
||||
|
||||
try:
|
||||
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
|
||||
# would be imported and used; we should test if our modules support it.
|
||||
from torch.cuda.amp import autocast
|
||||
except ImportError:
|
||||
_IS_AUTOCAST_AVAILABLE = False
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize('device', [
|
||||
'cpu',
|
||||
|
@ -148,6 +157,58 @@ def test_forward_equal_with_pytorch_float(device):
|
|||
assert max_rel_err < 1e-6
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _IS_AUTOCAST_AVAILABLE, reason='requires autocast support')
|
||||
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
|
||||
def test_forward_equal_with_autocast():
|
||||
N, M, D = 1, 2, 2
|
||||
Lq, L, P = 2, 2, 2
|
||||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
|
||||
level_start_index = torch.cat((shapes.new_zeros(
|
||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||
S = sum((H * W).item() for H, W in shapes)
|
||||
|
||||
torch.manual_seed(3)
|
||||
value = torch.rand(N, S, M, D) * 0.01
|
||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
|
||||
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
|
||||
attention_weights /= attention_weights.sum(
|
||||
-1, keepdim=True).sum(
|
||||
-2, keepdim=True)
|
||||
im2col_step = 2
|
||||
output_pytorch = multi_scale_deformable_attn_pytorch(
|
||||
value, shapes, sampling_locations, attention_weights).detach().cpu()
|
||||
|
||||
# float test
|
||||
dtype = torch.float
|
||||
with autocast(enabled=True):
|
||||
output_device = MultiScaleDeformableAttnFunction.apply(
|
||||
value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
|
||||
sampling_locations.cuda(), attention_weights.cuda(),
|
||||
im2col_step).detach().cpu()
|
||||
assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
|
||||
max_abs_err = (output_device - output_pytorch).abs().max()
|
||||
max_rel_err = ((output_device - output_pytorch).abs() /
|
||||
output_pytorch.abs()).max()
|
||||
assert max_abs_err < 1e-9
|
||||
assert max_rel_err < 1e-6
|
||||
|
||||
# half test
|
||||
dtype = torch.half
|
||||
with autocast(enabled=True):
|
||||
output_device = MultiScaleDeformableAttnFunction.apply(
|
||||
value.cuda().type(dtype), shapes.cuda(), level_start_index.cuda(),
|
||||
sampling_locations.cuda(), attention_weights.cuda(),
|
||||
im2col_step).detach().cpu()
|
||||
assert torch.allclose(
|
||||
output_device, output_pytorch.half(), rtol=1e-2, atol=1e-3)
|
||||
max_abs_err = (output_device - output_pytorch).abs().max()
|
||||
max_rel_err = ((output_device - output_pytorch).abs() /
|
||||
output_pytorch.abs()).max()
|
||||
assert max_abs_err < 1e-5
|
||||
assert max_rel_err < 1e-2
|
||||
|
||||
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
|
|
Loading…
Reference in New Issue