更新knn和three_nn的NPU适配代码 (#3194)

pull/3205/head
huangyuan64 2024-11-18 13:41:44 +08:00 committed by GitHub
parent 71437a361c
commit e1aab12f9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -12,7 +12,7 @@ void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
at::Tensor target = new_xyz.contiguous();
bool is_from_knn = true;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
}
void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,

View File

@ -18,7 +18,7 @@ void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
bool is_from_knn = false;
uint32_t nsample = 3;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
if (originDtype == at::kHalf) {
dist2 = dist2.to(at::kHalf);
}