mirror of https://github.com/open-mmlab/mmcv.git
[NPU] npu msda supports aclnn (#3149)
parent
8f23a0b8f2
commit
44eab261b9
|
@ -57,15 +57,9 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
|
|||
value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
|
||||
at::Tensor output = at::zeros(output_size, value_fp32.options());
|
||||
|
||||
OpCommand cmd;
|
||||
cmd.Name("MultiScaleDeformableAttnFunction")
|
||||
.Input(value_fp32)
|
||||
.Input(value_spatial_shapes_int32)
|
||||
.Input(value_level_start_index_int32)
|
||||
.Input(sampling_locations_fp32)
|
||||
.Input(attention_weights_fp32)
|
||||
.Output(output)
|
||||
.Run();
|
||||
EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunction, value_fp32,
|
||||
value_spatial_shapes_int32, value_level_start_index_int32,
|
||||
sampling_locations_fp32, attention_weights_fp32, output)
|
||||
|
||||
at::Tensor real_output = output;
|
||||
if (value.scalar_type() != at::kFloat) {
|
||||
|
|
Loading…
Reference in New Issue