[NPU] npu msda supports aclnn (#3149)

pull/3153/head
Pr0Wh1teGivee 2024-07-17 17:04:21 +08:00 committed by GitHub
parent 8f23a0b8f2
commit 44eab261b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 9 deletions

View File

@ -57,15 +57,9 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
at::Tensor output = at::zeros(output_size, value_fp32.options());
OpCommand cmd;
cmd.Name("MultiScaleDeformableAttnFunction")
.Input(value_fp32)
.Input(value_spatial_shapes_int32)
.Input(value_level_start_index_int32)
.Input(sampling_locations_fp32)
.Input(attention_weights_fp32)
.Output(output)
.Run();
EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnFunction, value_fp32,
value_spatial_shapes_int32, value_level_start_index_int32,
sampling_locations_fp32, attention_weights_fp32, output)
at::Tensor real_output = output;
if (value.scalar_type() != at::kFloat) {