mirror of https://github.com/open-mmlab/mmcv.git
parent
b91cfded58
commit
8f23a0b8f2
|
@ -0,0 +1,21 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
|
||||
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
|
||||
const Tensor new_xyz, Tensor idx, Tensor dist2) {
|
||||
// transpose known from [B, N, 3] to [B, 3, N]
|
||||
at::Tensor source = xyz.transpose(1, 2).contiguous();
|
||||
at::Tensor target = new_xyz.contiguous();
|
||||
|
||||
bool is_from_knn = true;
|
||||
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
|
||||
}
|
||||
|
||||
// Declaration only: knn_forward_impl has no body in this file. The parameter
// list mirrors knn_forward_npu exactly.
void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
|
||||
const Tensor new_xyz, Tensor idx, Tensor dist2);
|
||||
|
||||
// Pairs knn_forward_impl with knn_forward_npu via REGISTER_NPU_IMPL
// (presumably registering the NPU backend for dispatch -- the macro is
// defined outside this file; confirm in pytorch_npu_helper.hpp).
REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
|
|
@ -0,0 +1,30 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
|
||||
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
|
||||
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2, Tensor idx) {
|
||||
// transpose known [B, N, 3] -> [B, 3, N]
|
||||
at::Tensor source = known.transpose(1, 2).contiguous();
|
||||
at::Tensor target = unknown.contiguous();
|
||||
auto originDtype = source.scalar_type();
|
||||
if (originDtype == at::kHalf) {
|
||||
source = source.to(at::kFloat);
|
||||
target = target.to(at::kFloat);
|
||||
}
|
||||
|
||||
bool is_from_knn = false;
|
||||
uint32_t nsample = 3;
|
||||
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
|
||||
if (originDtype == at::kHalf) {
|
||||
dist2 = dist2.to(at::kHalf);
|
||||
}
|
||||
}
|
||||
|
||||
// Declaration only: three_nn_forward_impl has no body in this file. The
// parameter list mirrors three_nn_forward_npu exactly.
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
|
||||
const Tensor known, Tensor dist2, Tensor idx);
|
||||
|
||||
// Pairs three_nn_forward_impl with three_nn_forward_npu via REGISTER_NPU_IMPL
// (presumably registering the NPU backend for dispatch -- the macro is
// defined outside this file; confirm in pytorch_npu_helper.hpp).
REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
|
|
@ -55,8 +55,9 @@ class KNN(Function):
|
|||
center_xyz_device = center_xyz.get_device()
|
||||
assert center_xyz_device == xyz.get_device(), \
|
||||
'center_xyz and xyz should be put on the same device'
|
||||
if torch.cuda.current_device() != center_xyz_device:
|
||||
torch.cuda.set_device(center_xyz_device)
|
||||
if xyz.device.type != 'npu':
|
||||
if torch.cuda.current_device() != center_xyz_device:
|
||||
torch.cuda.set_device(center_xyz_device)
|
||||
|
||||
B, npoint, _ = center_xyz.shape
|
||||
N = xyz.shape[1]
|
||||
|
|
Loading…
Reference in New Issue