【Feature】knn/tnn npu added (#3125)

Co-authored-by: lizekai <lizekai3@hisilicon.com>
pull/3129/head
Tsekai Lee 2024-06-11 17:05:24 +08:00 committed by GitHub
parent b91cfded58
commit 8f23a0b8f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 2 deletions

View File

@ -0,0 +1,21 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
using namespace NPU_NAME_SPACE;
using namespace std;
void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
// transpose known from [B, N, 3] to [B, 3, N]
at::Tensor source = xyz.transpose(1, 2).contiguous();
at::Tensor target = new_xyz.contiguous();
bool is_from_knn = true;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
}
void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2);
REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);

View File

@ -0,0 +1,30 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
using namespace NPU_NAME_SPACE;
using namespace std;
void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
// transpose known [B, N, 3] -> [B, 3, N]
at::Tensor source = known.transpose(1, 2).contiguous();
at::Tensor target = unknown.contiguous();
auto originDtype = source.scalar_type();
if (originDtype == at::kHalf) {
source = source.to(at::kFloat);
target = target.to(at::kFloat);
}
bool is_from_knn = false;
uint32_t nsample = 3;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
if (originDtype == at::kHalf) {
dist2 = dist2.to(at::kHalf);
}
}
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx);
REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);

View File

@ -55,8 +55,9 @@ class KNN(Function):
center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
'center_xyz and xyz should be put on the same device'
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
if xyz.device.type != 'npu':
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
B, npoint, _ = center_xyz.shape
N = xyz.shape[1]