mirror of https://github.com/open-mmlab/mmcv.git
Add multi_scale_deform_attn op adapter for NPU (#3034)
parent
81f38f4d9b
commit
e5562f8a45
|
@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| MergeCells | | √ | | | |
|
||||
| MinAreaPolygon | | √ | | | |
|
||||
| ModulatedDeformConv2d | √ | √ | √ | | √ |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | √ |
|
||||
| NMS | √ | √ | √ | | √ |
|
||||
| NMSRotated | √ | √ | √ | | √ |
|
||||
| NMSQuadri | √ | √ | | | |
|
||||
|
|
|
@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| MergeCells | | √ | | | |
|
||||
| MinAreaPolygon | | √ | | | |
|
||||
| ModulatedDeformConv2d | √ | √ | √ | | √ |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | |
|
||||
| MultiScaleDeformableAttn | | √ | √ | | √ |
|
||||
| NMS | √ | √ | √ | | √ |
|
||||
| NMSRotated | √ | √ | √ | | √ |
|
||||
| NMSQuadri | √ | √ | | | |
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
Tensor ms_deform_attn_impl_forward(const Tensor &value,
|
||||
const Tensor &value_spatial_shapes,
|
||||
const Tensor &value_level_start_index,
|
||||
const Tensor &sampling_locations,
|
||||
const Tensor &attention_weights,
|
||||
const int im2col_step);
|
||||
|
||||
void check_support(const Tensor &value, const Tensor &attention_weights) {
|
||||
TORCH_CHECK(
|
||||
(value.scalar_type() == at::kFloat || value.scalar_type() == at::kHalf),
|
||||
"Dtype of value should be float32 or float16.");
|
||||
int64_t num_heads = value.size(2);
|
||||
int64_t embed_dims = value.size(3);
|
||||
int64_t num_points = attention_weights.size(4);
|
||||
TORCH_CHECK((num_heads >= 4 && num_heads <= 8),
|
||||
"num_heads should be in the range of [4, 8]");
|
||||
TORCH_CHECK((embed_dims >= 32 && embed_dims <= 256),
|
||||
"embed_dims should be in the range of [32, 256]");
|
||||
TORCH_CHECK((num_points >= 4 && num_points <= 8),
|
||||
"num_points should be in the range of [4, 8]");
|
||||
}
|
||||
|
||||
Tensor ms_deform_attn_forward_npu(const Tensor &value,
|
||||
const Tensor &value_spatial_shapes,
|
||||
const Tensor &value_level_start_index,
|
||||
const Tensor &sampling_locations,
|
||||
const Tensor &attention_weights,
|
||||
const int im2col_step) {
|
||||
check_support(value, attention_weights);
|
||||
at::Tensor value_fp32 = value;
|
||||
at::Tensor value_spatial_shapes_int32 = value_spatial_shapes;
|
||||
at::Tensor value_level_start_index_int32 = value_level_start_index;
|
||||
at::Tensor sampling_locations_fp32 = sampling_locations;
|
||||
at::Tensor attention_weights_fp32 = attention_weights;
|
||||
if (value.scalar_type() != at::kFloat) {
|
||||
value_fp32 = value.to(at::kFloat);
|
||||
}
|
||||
if (value_spatial_shapes.scalar_type() != at::kInt) {
|
||||
value_spatial_shapes_int32 = value_spatial_shapes.to(at::kInt);
|
||||
}
|
||||
if (value_level_start_index.scalar_type() != at::kInt) {
|
||||
value_level_start_index_int32 = value_level_start_index.to(at::kInt);
|
||||
}
|
||||
if (sampling_locations.scalar_type() != at::kFloat) {
|
||||
sampling_locations_fp32 = sampling_locations.to(at::kFloat);
|
||||
}
|
||||
if (attention_weights.scalar_type() != at::kFloat) {
|
||||
attention_weights_fp32 = attention_weights.to(at::kFloat);
|
||||
}
|
||||
|
||||
c10::SmallVector<int64_t, 3> output_size = {
|
||||
value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
|
||||
at::Tensor output = at::empty(output_size, value_fp32.options());
|
||||
|
||||
OpCommand cmd;
|
||||
cmd.Name("MultiScaleDeformableAttnFunction")
|
||||
.Input(value_fp32)
|
||||
.Input(value_spatial_shapes_int32)
|
||||
.Input(value_level_start_index_int32)
|
||||
.Input(sampling_locations_fp32)
|
||||
.Input(attention_weights_fp32)
|
||||
.Output(output)
|
||||
.Run();
|
||||
|
||||
at::Tensor real_output = output;
|
||||
if (value.scalar_type() != at::kFloat) {
|
||||
real_output = output.to(value.scalar_type());
|
||||
}
|
||||
return real_output;
|
||||
}
|
||||
|
||||
REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);
|
|
@ -13,7 +13,7 @@ from mmcv import deprecated_api_warning
|
|||
from mmcv.cnn import constant_init, xavier_init
|
||||
from mmcv.cnn.bricks.registry import ATTENTION
|
||||
from mmcv.runner import BaseModule
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
|
@ -361,7 +361,8 @@ class MultiScaleDeformableAttention(BaseModule):
|
|||
f'Last dim of reference_points must be'
|
||||
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
|
||||
if ((IS_CUDA_AVAILABLE and value.is_cuda)
|
||||
or (IS_MLU_AVAILABLE and value.is_mlu)):
|
||||
or (IS_MLU_AVAILABLE and value.is_mlu)
|
||||
or (IS_NPU_AVAILABLE and value.device.type == 'npu')):
|
||||
output = MultiScaleDeformableAttnFunction.apply(
|
||||
value, spatial_shapes, level_start_index, sampling_locations,
|
||||
attention_weights, self.im2col_step)
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
from mmcv.ops.multi_scale_deform_attn import (
|
||||
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
|
||||
multi_scale_deformable_attn_pytorch)
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
|
||||
|
||||
_USING_PARROTS = True
|
||||
_IS_AUTOCAST_AVAILABLE = True
|
||||
|
@ -116,6 +116,40 @@ def test_forward_equal_with_pytorch_double():
|
|||
assert max_rel_err < 1e-15
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
|
||||
def test_forward_equal_with_pytorch_npu():
|
||||
N, M, D = 6, 4, 8
|
||||
Lq, L, P = 10000, 4, 8
|
||||
shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
|
||||
dtype=torch.int32)
|
||||
level_start_index = torch.cat((shapes.new_zeros(
|
||||
(1, )), shapes.prod(1).cumsum(0)[:-1]))
|
||||
S = sum((H * W).item() for H, W in shapes)
|
||||
|
||||
torch.manual_seed(3)
|
||||
value = torch.rand(N, S, M, D) * 0.01
|
||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
|
||||
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
|
||||
attention_weights /= attention_weights.sum(
|
||||
-1, keepdim=True).sum(
|
||||
-2, keepdim=True)
|
||||
im2col_step = 2
|
||||
output_pytorch = multi_scale_deformable_attn_pytorch(
|
||||
value.float(), shapes, sampling_locations.float(),
|
||||
attention_weights.float()).detach().cpu()
|
||||
|
||||
output_npu = MultiScaleDeformableAttnFunction.apply(
|
||||
value.npu().float(), shapes.npu(), level_start_index.npu(),
|
||||
sampling_locations.npu().float(),
|
||||
attention_weights.npu().float(), im2col_step).detach().cpu()
|
||||
assert torch.allclose(output_npu, output_pytorch)
|
||||
max_abs_err = (output_npu - output_pytorch).abs().max()
|
||||
max_rel_err = ((output_npu - output_pytorch).abs() /
|
||||
output_pytorch.abs()).max()
|
||||
assert max_abs_err < 1e-18
|
||||
assert max_rel_err < 1e-15
|
||||
|
||||
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
|
|
Loading…
Reference in New Issue