mirror of https://github.com/open-mmlab/mmcv.git
更新knn和three_nn的NPU适配代码 (#3194)
parent
71437a361c
commit
e1aab12f9b
|
@ -12,7 +12,7 @@ void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
|
|||
at::Tensor target = new_xyz.contiguous();
|
||||
|
||||
bool is_from_knn = true;
|
||||
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
|
||||
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
|
||||
}
|
||||
|
||||
void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
|
||||
|
|
|
@ -18,7 +18,7 @@ void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
|
|||
|
||||
bool is_from_knn = false;
|
||||
uint32_t nsample = 3;
|
||||
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
|
||||
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
|
||||
if (originDtype == at::kHalf) {
|
||||
dist2 = dist2.to(at::kHalf);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue