Add multi_scale_deform_attn op adapter for NPU (#3034)

pull/3048/head
DaGaiBa 2024-03-01 15:02:58 +08:00 committed by GitHub
parent 81f38f4d9b
commit e5562f8a45
5 changed files with 118 additions and 6 deletions

View File

@@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
| MergeCells | | √ | | | |
| MinAreaPolygon | | √ | | | |
| ModulatedDeformConv2d | √ | √ | √ | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| MultiScaleDeformableAttn | | √ | √ | | √ |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | √ | | √ |
| NMSQuadri | √ | √ | | | |

View File

@@ -33,7 +33,7 @@ MMCV provides operators commonly used in detection, segmentation and other tasks
| MergeCells | | √ | | | |
| MinAreaPolygon | | √ | | | |
| ModulatedDeformConv2d | √ | √ | √ | | √ |
| MultiScaleDeformableAttn | | √ | √ | | |
| MultiScaleDeformableAttn | | √ | √ | | √ |
| NMS | √ | √ | √ | | √ |
| NMSRotated | √ | √ | √ | | √ |
| NMSQuadri | √ | √ | | | |
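
Both the English and Chinese support tables gain an Ascend checkmark for MultiScaleDeformableAttn in this commit. As a quick runtime sanity check of which backends a given mmcv installation can actually dispatch to, the availability flags used later in this diff can be queried directly (a minimal sketch, not part of the commit; it assumes an mmcv build that exposes IS_NPU_AVAILABLE):

# Hypothetical runtime check of backend availability (not part of this commit).
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

# Prints which device backends this mmcv build can dispatch ops to.
print(f'CUDA: {IS_CUDA_AVAILABLE}, MLU: {IS_MLU_AVAILABLE}, NPU: {IS_NPU_AVAILABLE}')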

View File

@@ -0,0 +1,77 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;
Tensor ms_deform_attn_impl_forward(const Tensor &value,
const Tensor &value_spatial_shapes,
const Tensor &value_level_start_index,
const Tensor &sampling_locations,
const Tensor &attention_weights,
const int im2col_step);
void check_support(const Tensor &value, const Tensor &attention_weights) {
TORCH_CHECK(
(value.scalar_type() == at::kFloat || value.scalar_type() == at::kHalf),
"Dtype of value should be float32 or float16.");
int64_t num_heads = value.size(2);
int64_t embed_dims = value.size(3);
int64_t num_points = attention_weights.size(4);
TORCH_CHECK((num_heads >= 4 && num_heads <= 8),
"num_heads should be in the range of [4, 8]");
TORCH_CHECK((embed_dims >= 32 && embed_dims <= 256),
"embed_dims should be in the range of [32, 256]");
TORCH_CHECK((num_points >= 4 && num_points <= 8),
"num_points should be in the range of [4, 8]");
}
Tensor ms_deform_attn_forward_npu(const Tensor &value,
const Tensor &value_spatial_shapes,
const Tensor &value_level_start_index,
const Tensor &sampling_locations,
const Tensor &attention_weights,
const int im2col_step) {
check_support(value, attention_weights);
at::Tensor value_fp32 = value;
at::Tensor value_spatial_shapes_int32 = value_spatial_shapes;
at::Tensor value_level_start_index_int32 = value_level_start_index;
at::Tensor sampling_locations_fp32 = sampling_locations;
at::Tensor attention_weights_fp32 = attention_weights;
if (value.scalar_type() != at::kFloat) {
value_fp32 = value.to(at::kFloat);
}
if (value_spatial_shapes.scalar_type() != at::kInt) {
value_spatial_shapes_int32 = value_spatial_shapes.to(at::kInt);
}
if (value_level_start_index.scalar_type() != at::kInt) {
value_level_start_index_int32 = value_level_start_index.to(at::kInt);
}
if (sampling_locations.scalar_type() != at::kFloat) {
sampling_locations_fp32 = sampling_locations.to(at::kFloat);
}
if (attention_weights.scalar_type() != at::kFloat) {
attention_weights_fp32 = attention_weights.to(at::kFloat);
}
c10::SmallVector<int64_t, 3> output_size = {
value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
at::Tensor output = at::empty(output_size, value_fp32.options());
OpCommand cmd;
cmd.Name("MultiScaleDeformableAttnFunction")
.Input(value_fp32)
.Input(value_spatial_shapes_int32)
.Input(value_level_start_index_int32)
.Input(sampling_locations_fp32)
.Input(attention_weights_fp32)
.Output(output)
.Run();
at::Tensor real_output = output;
if (value.scalar_type() != at::kFloat) {
real_output = output.to(value.scalar_type());
}
return real_output;
}
REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);
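
A few notes on the adapter above: check_support only admits float32/float16 values with 4-8 heads, 32-256 channels per head and 4-8 sampling points, and half-precision inputs are cast to fp32 before the OpCommand is issued, then cast back on return. Below is a minimal sketch of a forward call that satisfies those checks; the shapes are illustrative only and it assumes an Ascend build of mmcv with torch_npu installed (the call pattern follows the new test further down):

# Hypothetical fp16 forward pass through the new NPU adapter (illustrative shapes only).
import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction

bs, num_heads, head_dims = 2, 8, 32            # num_heads in [4, 8], head_dims in [32, 256]
num_query, num_levels, num_points = 50, 4, 4   # num_points in [4, 8]
shapes = torch.as_tensor([(16, 16), (8, 8), (4, 4), (2, 2)], dtype=torch.int32)
level_start_index = torch.cat((shapes.new_zeros((1, )),
                               shapes.prod(1).cumsum(0)[:-1]))
num_value = int(shapes.prod(1).sum())

value = torch.rand(bs, num_value, num_heads, head_dims).half().npu()
sampling_locations = torch.rand(bs, num_query, num_heads, num_levels,
                                num_points, 2).half().npu()
attention_weights = torch.rand(bs, num_query, num_heads, num_levels,
                               num_points).half().npu()

output = MultiScaleDeformableAttnFunction.apply(
    value, shapes.npu(), level_start_index.npu(), sampling_locations,
    attention_weights, 2)
# Per ms_deform_attn_forward_npu above, the result comes back in the input
# dtype with shape (bs, num_query, num_heads * head_dims).
assert output.dtype == torch.half
assert output.shape == (bs, num_query, num_heads * head_dims)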

View File

@@ -13,7 +13,7 @@ from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
@@ -361,7 +361,8 @@ class MultiScaleDeformableAttention(BaseModule):
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if ((IS_CUDA_AVAILABLE and value.is_cuda)
                or (IS_MLU_AVAILABLE and value.is_mlu)):
                or (IS_MLU_AVAILABLE and value.is_mlu)
                or (IS_NPU_AVAILABLE and value.device.type == 'npu')):
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
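
With the new branch, any value tensor whose device type is 'npu' is routed to the compiled extension instead of the pure-PyTorch multi_scale_deformable_attn_pytorch fallback. A hypothetical end-to-end sketch through the MultiScaleDeformableAttention module follows; the shapes and device calls are illustrative, and the constructor/forward arguments are mmcv's existing module API rather than something introduced by this commit:

# Hypothetical module-level usage that exercises the new NPU dispatch branch.
import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

bs, num_query, embed_dims = 2, 100, 256       # 256 dims / 8 heads = 32 channels per head
num_levels, num_heads, num_points = 4, 8, 4

spatial_shapes = torch.as_tensor([(32, 32), (16, 16), (8, 8), (4, 4)],
                                 dtype=torch.long).npu()
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
num_value = int(spatial_shapes.prod(1).sum())

attn = MultiScaleDeformableAttention(
    embed_dims=embed_dims, num_heads=num_heads, num_levels=num_levels,
    num_points=num_points, batch_first=True).to('npu')

query = torch.rand(bs, num_query, embed_dims).npu()
value = torch.rand(bs, num_value, embed_dims).npu()
reference_points = torch.rand(bs, num_query, num_levels, 2).npu()

# value.device.type == 'npu', so the condition above picks the NPU kernel.
output = attn(query, value=value, reference_points=reference_points,
              spatial_shapes=spatial_shapes, level_start_index=level_start_index)
print(output.shape)  # (bs, num_query, embed_dims)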

View File

@@ -5,7 +5,7 @@ import torch
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
_USING_PARROTS = True
_IS_AUTOCAST_AVAILABLE = True
@@ -116,6 +116,40 @@ def test_forward_equal_with_pytorch_double():
    assert max_rel_err < 1e-15


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_forward_equal_with_pytorch_npu():
    N, M, D = 6, 4, 8
    Lq, L, P = 10000, 4, 8
    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                             dtype=torch.int32)
    level_start_index = torch.cat((shapes.new_zeros(
        (1, )), shapes.prod(1).cumsum(0)[:-1]))
    S = sum((H * W).item() for H, W in shapes)

    torch.manual_seed(3)
    value = torch.rand(N, S, M, D) * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
    attention_weights /= attention_weights.sum(
        -1, keepdim=True).sum(
            -2, keepdim=True)
    im2col_step = 2
    output_pytorch = multi_scale_deformable_attn_pytorch(
        value.float(), shapes, sampling_locations.float(),
        attention_weights.float()).detach().cpu()

    output_npu = MultiScaleDeformableAttnFunction.apply(
        value.npu().float(), shapes.npu(), level_start_index.npu(),
        sampling_locations.npu().float(),
        attention_weights.npu().float(), im2col_step).detach().cpu()
    assert torch.allclose(output_npu, output_pytorch)
    max_abs_err = (output_npu - output_pytorch).abs().max()
    max_rel_err = ((output_npu - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    assert max_abs_err < 1e-18
    assert max_rel_err < 1e-15


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',