mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Add knn/three_nn NPU ops (#3125)
Co-authored-by: lizekai <lizekai3@hisilicon.com>
This commit is contained in:
parent
b91cfded58
commit
8f23a0b8f2
21
mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
Normal file
21
mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#include "pytorch_npu_helper.hpp"
|
||||||
|
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
|
||||||
|
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
|
||||||
|
|
||||||
|
using namespace NPU_NAME_SPACE;
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
|
||||||
|
const Tensor new_xyz, Tensor idx, Tensor dist2) {
|
||||||
|
// transpose known from [B, N, 3] to [B, 3, N]
|
||||||
|
at::Tensor source = xyz.transpose(1, 2).contiguous();
|
||||||
|
at::Tensor target = new_xyz.contiguous();
|
||||||
|
|
||||||
|
bool is_from_knn = true;
|
||||||
|
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Declaration of the device-dispatch entry point for knn_forward; declared
// here so the registration macro below can bind the NPU kernel to it.
void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
                      const Tensor new_xyz, Tensor idx, Tensor dist2);

// Route knn_forward_impl to knn_forward_npu when tensors live on an NPU.
REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
|
30
mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
Normal file
30
mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#include "pytorch_npu_helper.hpp"
|
||||||
|
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
|
||||||
|
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
|
||||||
|
|
||||||
|
using namespace NPU_NAME_SPACE;
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
|
||||||
|
const Tensor known, Tensor dist2, Tensor idx) {
|
||||||
|
// transpose known [B, N, 3] -> [B, 3, N]
|
||||||
|
at::Tensor source = known.transpose(1, 2).contiguous();
|
||||||
|
at::Tensor target = unknown.contiguous();
|
||||||
|
auto originDtype = source.scalar_type();
|
||||||
|
if (originDtype == at::kHalf) {
|
||||||
|
source = source.to(at::kFloat);
|
||||||
|
target = target.to(at::kFloat);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_from_knn = false;
|
||||||
|
uint32_t nsample = 3;
|
||||||
|
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
|
||||||
|
if (originDtype == at::kHalf) {
|
||||||
|
dist2 = dist2.to(at::kHalf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Declaration of the device-dispatch entry point for three_nn_forward;
// declared here so the registration macro below can bind the NPU kernel to it.
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
                           const Tensor known, Tensor dist2, Tensor idx);

// Route three_nn_forward_impl to three_nn_forward_npu on NPU tensors.
REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
|
@@ -55,8 +55,9 @@ class KNN(Function):
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
             'center_xyz and xyz should be put on the same device'
-        if torch.cuda.current_device() != center_xyz_device:
-            torch.cuda.set_device(center_xyz_device)
+        if xyz.device.type != 'npu':
+            if torch.cuda.current_device() != center_xyz_device:
+                torch.cuda.set_device(center_xyz_device)

         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]
|
Loading…
x
Reference in New Issue
Block a user