From a55f4b7f40e1ab407b62251e03d1c5fbe89a92ca Mon Sep 17 00:00:00 2001 From: liuduanhui <103939338+DanieeelLiu@users.noreply.github.com> Date: Mon, 3 Apr 2023 23:30:40 +0800 Subject: [PATCH] [Enhancement] Replace the implementation of three_nn_forward with mlu-ops (#2719) --- .../csrc/common/mlu/three_nn_mlu_kernel.mlu | 466 ------------------ mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp | 97 ++-- 2 files changed, 30 insertions(+), 533 deletions(-) delete mode 100644 mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu diff --git a/mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu deleted file mode 100644 index 792738510..000000000 --- a/mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu +++ /dev/null @@ -1,466 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include "common_mlu_helper.hpp" -#include - -__nram__ char nram_buffer[MAX_NRAM_SIZE]; - -#if __BANG_ARCH__ >= 322 -/** - * returns the index of ret, which is stored at the 1st position of the `ret`, - * used after bang_min - */ -__mlu_func__ uint32_t getIndice(half *ret) { - uint32_t indice = *((uint32_t *)((uint16_t *)ret + 1)); - return indice; -} - -/** - * returns the index of ret, which is stored at the 1st position of the `ret`, - * used after bang_min - */ -__mlu_func__ uint32_t getIndice(float *ret) { - uint32_t indice = ((uint32_t *)ret)[1]; - return indice; -} -#endif - -template -__mlu_func__ void auxArgmin(T *nram_dst, T *nram_src, const int num_deal, - T *value, int *index) { - __bang_min(nram_dst, nram_src, num_deal); - *value = nram_dst[0]; - __bang_write_value(nram_dst, num_deal, *value); - __bang_eq(nram_dst, nram_src, nram_dst, num_deal); - __bang_findfirst1((uint32_t *)nram_dst, nram_dst, num_deal); - *index = *((int *)nram_dst); -} - -template -__mlu_func__ void auxFuncFind3Min(T *nram_aux_a, const int auxa_offset, - int *nram_aux_b, const int auxb_offset, - T *nram_dest, T *nram_aux_sort_a, - int *nram_aux_sort_b, const int deal_offset) { - __bang_write_value(nram_aux_sort_a, auxa_offset, (T)(INFINITY)); - __bang_write_value(nram_aux_sort_b, auxb_offset, (int)0); - int index = 0; - for (int i = 0; i < 3; i++) { -#if __BANG_ARCH__ >= 322 - __bang_argmin(nram_dest, nram_aux_a, auxa_offset); - nram_aux_sort_a[i] = nram_dest[0]; - index = getIndice(nram_dest); -#else - T value = 0; - auxArgmin(nram_dest, nram_aux_a, auxa_offset, &value, &index); - nram_aux_sort_a[i] = value; -#endif - nram_aux_sort_b[i] = nram_aux_b[index]; - __memset_nram(nram_aux_a + index, 1, (T)(INFINITY)); - } - __memcpy((char *)nram_aux_a, (char *)nram_aux_sort_a, auxa_offset * sizeof(T), - NRAM2NRAM); - __memcpy((char *)nram_aux_b, (char *)nram_aux_sort_b, - auxb_offset * sizeof(int), NRAM2NRAM); -} - -template -__mlu_func__ void auxFuncSort(T *nram_aux_a, const int auxa_offset, - int *nram_aux_b, const int auxb_offset, - T *nram_dest, T *nram_help_value, - int *nram_help_idx, const int num_deal, - const int deal_offset) { - for (int k = 0; k < num_deal; ++k) { - auxFuncFind3Min(nram_aux_a + k * auxa_offset, auxa_offset, - nram_aux_b + k * auxb_offset, auxb_offset, nram_dest, - nram_help_value, nram_help_idx, deal_offset); - } -} - -template -__mlu_func__ void auxFuncNN( - size_t *output_aux_sort_a_gap, size_t *output_aux_sort_b_gap, - size_t *output_aux_dest_gap, size_t *output_unknown_gap, - size_t *output_known_gap, size_t *output_dist_gap, size_t *auxillary_a_gap, - size_t *auxillary_b_gap, size_t *known_num_deal, size_t *unknown_num_deal, - size_t *align_num, size_t *auxa_offset, size_t *auxb_offset) { - /* - * nram partition: - * |-NFU_ALIGN_SIZE-|-2*NFU_ALIGN_SIZE-|-X*3*sizeof(T)-| - * space: | aux_sort_a | aux_sort_b | nram_unknown | - * - * | ------ (Y * 7 *sizeof(T)) ---------------- | - * | nram_known | nram_dist | nram_dest | - * - * | -X * NFU_ALIGN_SIZE ---|---X * 2 * NFU_ALIGN_SIZE-| - * | output_dist(aux_a) | output_dist(aux_b) | - * 200 series - * X = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (2/3) / (3 * sizeof(T) + 3 * - * NFU_ALIGN_SIZE) - * Y = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (1/3) / (7 * sizeof(T)) - * 300 series - * X = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * (4/5) / (3 * - * sizeof(T) + 3 * NFU_ALIGN_SIZE) - * Y = (MAX_NRAM - 3 * NFU_ALIGN_SIZE) * - * (1/5) / (7 * sizeof(T)) - * - */ - - *align_num = NFU_ALIGN_SIZE / sizeof(T); - *auxa_offset = NFU_ALIGN_SIZE / sizeof(T); - *auxb_offset = 2 * NFU_ALIGN_SIZE / sizeof(int); -#if __BANG_ARCH__ >= 322 - *known_num_deal = PAD_DOWN( - (MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 5 / (7 * sizeof(T)), *align_num); - *unknown_num_deal = PAD_DOWN((MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 5 * 4 / - (3 * sizeof(T) + 3 * NFU_ALIGN_SIZE), - *align_num); -#else - *known_num_deal = PAD_DOWN( - (MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 3 / (7 * sizeof(T)), *align_num); - *unknown_num_deal = PAD_DOWN((MAX_NRAM_SIZE - 3 * NFU_ALIGN_SIZE) / 3 * 2 / - (3 * sizeof(T) + 3 * NFU_ALIGN_SIZE), - *align_num); -#endif - - *output_aux_sort_a_gap = 0; - *output_aux_sort_b_gap = *output_aux_sort_a_gap + NFU_ALIGN_SIZE; - *output_aux_dest_gap = *output_aux_sort_b_gap + 2 * NFU_ALIGN_SIZE; - - *output_unknown_gap = *output_aux_dest_gap + *known_num_deal * sizeof(T); - *output_known_gap = *output_unknown_gap + *unknown_num_deal * 3 * sizeof(T); - *output_dist_gap = *output_known_gap + *known_num_deal * 3 * sizeof(T); - *auxillary_a_gap = *output_dist_gap + *known_num_deal * 3 * sizeof(T); - *auxillary_b_gap = *auxillary_a_gap + *unknown_num_deal * NFU_ALIGN_SIZE; -} - -#if __BANG_ARCH__ >= 322 -template -__mlu_func__ bool containNanInf(T *nram_unknown) { - if (std::isnan(nram_unknown[0]) || std::isnan(nram_unknown[1]) || - std::isnan(nram_unknown[2]) || std::isinf(nram_unknown[0]) || - std::isinf(nram_unknown[1]) || std::isinf(nram_unknown[2])) - return true; - else - return false; -} -#endif - -template -__mlu_func__ void computeThreeNN(T *nram_unknown, T *nram_known, T *nram_dist, - T *nram_dest, T *nram_aux_a, - T *nram_aux_sort_a, int *nram_aux_b, - int *nram_aux_sort_b, const int known_num_deal, - const int known_seg_num, const int deal_offset, - const int known_count, - const int known_count_align) { - __bang_write_value(nram_dist, 3 * known_num_deal, (T)(INFINITY)); -#if __BANG_ARCH__ >= 322 - if (!containNanInf(nram_unknown)) { -#endif - // x1 - x2 - __bang_sub_scalar(nram_dist, nram_known, nram_unknown[0], - known_count_align); - // y1 - y2 - __bang_sub_scalar(nram_dist + known_count_align, - nram_known + known_count_align, nram_unknown[1], - known_count_align); - // z1 - z2 - __bang_sub_scalar(nram_dist + 2 * known_count_align, - nram_known + 2 * known_count_align, nram_unknown[2], - known_count_align); - __bang_square(nram_dist, nram_dist, 3 * known_count_align); - __bang_add(nram_dist, nram_dist, nram_dist + known_count_align, - known_count_align); - __bang_add(nram_dist, nram_dist, nram_dist + 2 * known_count_align, - known_count_align); -#if __BANG_ARCH__ >= 322 - } -#endif - - int index = 0; - for (int i = 0; i < 3; i++) { -#if __BANG_ARCH__ >= 322 - __bang_argmin(nram_dest, nram_dist, known_count_align); - nram_aux_a[i + deal_offset] = nram_dest[0]; - index = getIndice(nram_dest); -#else - T value = 0; - auxArgmin(nram_dest, nram_dist, known_count_align, &value, &index); - nram_aux_a[i + deal_offset] = value; -#endif - nram_aux_b[i + deal_offset] = index + known_seg_num * known_num_deal; - __memset_nram(nram_dist + index, 1, (T)(INFINITY)); - } -} - -template -__mlu_func__ void loadTransposedKnownTensor( - char *nram_known, char *nram_dist, const char *known_gdram, - const int known_num_deal, const int batch_id, const int m, - const int known_seg_num, const int count, const int count_align_num) { - __bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY)); -#if __BANG_ARCH__ >= 322 - __bang_write_value(nram_dist, 3 * known_num_deal, (T)(INFINITY)); - __memcpy(nram_dist, - known_gdram + - (batch_id * m * 3 + known_seg_num * known_num_deal) * sizeof(T), - count * sizeof(T), GDRAM2NRAM, count_align_num * sizeof(T), - m * sizeof(T), 2); - __bang_minequal((T *)nram_known, (T *)nram_known, (T *)nram_dist, - 3 * count_align_num); -#else - __memcpy(nram_known, - known_gdram + - (batch_id * m * 3 + known_seg_num * known_num_deal) * sizeof(T), - count * sizeof(T), GDRAM2NRAM, count_align_num * sizeof(T), - m * sizeof(T), 2); -#endif -} - -template -__mlu_func__ void loadUnknownTensor(char *nram_unknown, - const char *unknown_gdram, - const int unknown_num_deal, - const int unknown_seg_num, const int count, - const int count_align_num) { - __memcpy(nram_unknown, - unknown_gdram + unknown_seg_num * unknown_num_deal * 3 * sizeof(T), - count * 3 * sizeof(T), GDRAM2NRAM); -} - -template -__mlu_func__ void auxProcessSegment( - const int m, const int n, T *nram_unknown, T *nram_known, T *nram_dist, - T *nram_dest, T *known_gdram, T *nram_aux_a, const int auxa_offset, - int *nram_aux_b, const int auxb_offset, T *nram_aux_sort_a, - int *nram_aux_sort_b, const int unknown_num_deal, const int known_num_deal, - const int known_seg_num, const int unknown_seg_num, const int unknown_count, - const int known_count, const int known_count_align, const int start_idx, - int *deal_offset) { - int pre_batch_id = -1; - int cur_batch_id = -1; - pre_batch_id = start_idx / n; - - // if aux_a space is not enough, get the first 3 min among aux_a and clear. - if (*deal_offset >= PAD_DOWN(auxa_offset, 3)) { - auxFuncSort(nram_aux_a, auxa_offset, nram_aux_b, auxb_offset, nram_dest, - nram_aux_sort_a, nram_aux_sort_b, unknown_count, *deal_offset); - *deal_offset = 3; - } - - // load i'th segment of known batch data. - loadTransposedKnownTensor((char *)nram_known, (char *)nram_dist, - (char *)known_gdram, known_num_deal, - pre_batch_id, m, known_seg_num, known_count, - known_count_align); - - for (int k = 0; k < unknown_count; ++k) { - cur_batch_id = (start_idx + k) / n; - if (cur_batch_id != pre_batch_id) { // if batch id of unknown data changed, - // load corresponding known batch data - pre_batch_id = cur_batch_id; - loadTransposedKnownTensor((char *)nram_known, (char *)nram_dist, - (char *)known_gdram, known_num_deal, - pre_batch_id, m, known_seg_num, known_count, - known_count_align); - } - computeThreeNN(nram_unknown + 3 * k, nram_known, nram_dist, nram_dest, - nram_aux_a + k * auxa_offset, nram_aux_sort_a, - nram_aux_b + k * auxb_offset, nram_aux_sort_b, - known_num_deal, known_seg_num, *deal_offset, known_count, - known_count_align); - } -} - -template -__mlu_global__ void MLUUnion1KernelThreeNN(const int b, const int n, - const int m, char *unknown_gdram, - char *known_gdram, char *dist2_gdram, - int *idx_gdram) { - if (coreId == 0x80) { - return; - } - - size_t output_aux_sort_a_gap = 0, output_aux_sort_b_gap = 0, - output_dest_gap = 0, output_unknown_gap = 0, output_known_gap = 0, - output_dist_gap = 0, auxillary_a_gap = 0, auxillary_b_gap = 0, - known_num_deal = 0, unknown_num_deal = 0, align_num = 0, - auxa_offset = 0, auxb_offset = 0; - auxFuncNN(&output_aux_sort_a_gap, &output_aux_sort_b_gap, &output_dest_gap, - &output_unknown_gap, &output_known_gap, &output_dist_gap, - &auxillary_a_gap, &auxillary_b_gap, &known_num_deal, - &unknown_num_deal, &align_num, &auxa_offset, &auxb_offset); - - int num_per_core = b * n / taskDim; - const int core_offset = num_per_core; - - char *unknown_gdram_start = - unknown_gdram + taskId * 3 * core_offset * sizeof(T); - char *known_gdram_start = known_gdram; - char *output_dist_start = dist2_gdram + taskId * 3 * core_offset * sizeof(T); - int *output_idx_start = idx_gdram + taskId * 3 * core_offset; - - const int rem = (b * n) % taskDim; - if (taskId == taskDim - 1) { - num_per_core += rem; - } - - const int unknown_repeat = - num_per_core / unknown_num_deal; // if unknown number is big, process it - // by unknown_repeat times. - const int unknown_rem = num_per_core % unknown_num_deal; // unknown reminder - const int unknown_rem_align = PAD_UP(unknown_rem, align_num); - - const int known_repeat = - m / known_num_deal; // if known number is big, process it by - // unknown_repeat times. - const int known_rem = m % known_num_deal; // known reminder - const int known_rem_align = PAD_UP(known_rem, align_num); - - char *nram_aux_sort_a = nram_buffer; - int *nram_aux_sort_b = (int *)(nram_buffer + output_aux_sort_b_gap); - char *nram_dest = nram_buffer + output_dest_gap; - char *nram_unknown = nram_buffer + output_unknown_gap; - char *nram_known = nram_buffer + output_known_gap; - char *nram_dist = nram_buffer + output_dist_gap; - char *nram_aux_a = nram_buffer + auxillary_a_gap; - int *nram_aux_b = (int *)(nram_buffer + auxillary_b_gap); - int deal_offset = 0; - int start_idx = -1; - - for (int j = 0; j < unknown_repeat; - ++j) { // process data within a unknown_repeat - // if unknown need to be process segmentally, use a aux_a and aux_b - // space to find first 3 minimum dist. - __bang_write_value(nram_aux_a, unknown_num_deal * auxa_offset, - (T)(INFINITY)); - __bang_write_value(nram_aux_b, unknown_num_deal * auxb_offset, (int)0); - loadUnknownTensor(nram_unknown, unknown_gdram_start, unknown_num_deal, j, - unknown_num_deal, unknown_num_deal); - - deal_offset = 0; - start_idx = taskId * core_offset + j * unknown_num_deal; - - for (int i = 0; i < known_repeat; - ++i) { // process known data in segmentally. - auxProcessSegment( - m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist, - (T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset, - nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b, - unknown_num_deal, known_num_deal, i, j, unknown_num_deal, - known_num_deal, known_num_deal, start_idx, &deal_offset); - deal_offset += 3; - } - - if (known_rem > 0) { // process known rem - __bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY)); - auxProcessSegment( - m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist, - (T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset, - nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b, - unknown_num_deal, known_num_deal, known_repeat, j, unknown_num_deal, - known_rem, known_rem_align, start_idx, &deal_offset); - } - - deal_offset += 3; - - if (deal_offset > 3) { - auxFuncSort((T *)nram_aux_a, auxa_offset, nram_aux_b, auxb_offset, - (T *)nram_dest, (T *)nram_aux_sort_a, nram_aux_sort_b, - unknown_num_deal, deal_offset); - deal_offset = 0; - } - - __memcpy((char *)output_dist_start + j * unknown_num_deal * 3 * sizeof(T), - (char *)nram_aux_a, 3 * sizeof(T), NRAM2GDRAM, 3 * sizeof(T), - auxa_offset * sizeof(T), unknown_num_deal - 1); - __memcpy((char *)output_idx_start + j * unknown_num_deal * 3 * sizeof(int), - (char *)nram_aux_b, 3 * sizeof(int), NRAM2GDRAM, 3 * sizeof(int), - auxb_offset * sizeof(int), unknown_num_deal - 1); - } - - if (unknown_rem > 0) { // process unknown rem - deal_offset = 0; - __bang_write_value(nram_aux_a, unknown_num_deal * auxa_offset, - (T)(INFINITY)); - __bang_write_value(nram_aux_b, unknown_num_deal * auxb_offset, (int)0); - loadUnknownTensor(nram_unknown, unknown_gdram_start, unknown_num_deal, - unknown_repeat, unknown_rem, unknown_rem_align); - start_idx = taskId * core_offset + unknown_repeat * unknown_num_deal; - - for (int i = 0; i < known_repeat; ++i) { - auxProcessSegment( - m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist, - (T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset, - nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b, - unknown_num_deal, known_num_deal, i, unknown_repeat, unknown_rem, - known_num_deal, known_num_deal, start_idx, &deal_offset); - deal_offset += 3; - } - - if (known_rem > 0) { - __bang_write_value(nram_known, 3 * known_num_deal, (T)(INFINITY)); - start_idx = taskId * core_offset + unknown_repeat * unknown_num_deal; - - auxProcessSegment( - m, n, (T *)nram_unknown, (T *)nram_known, (T *)nram_dist, - (T *)nram_dest, (T *)known_gdram_start, (T *)nram_aux_a, auxa_offset, - nram_aux_b, auxb_offset, (T *)nram_aux_sort_a, nram_aux_sort_b, - unknown_num_deal, known_num_deal, known_repeat, unknown_repeat, - unknown_rem, known_rem, known_rem_align, start_idx, &deal_offset); - - deal_offset += 3; - } - if (deal_offset > 3) { - auxFuncSort((T *)nram_aux_a, auxa_offset, nram_aux_b, auxb_offset, - (T *)nram_dest, (T *)nram_aux_sort_a, nram_aux_sort_b, - unknown_rem, deal_offset); - deal_offset = 0; - } - - __memcpy((char *)output_dist_start + - unknown_repeat * unknown_num_deal * 3 * sizeof(T), - (char *)nram_aux_a, 3 * sizeof(T), NRAM2GDRAM, 3 * sizeof(T), - auxa_offset * sizeof(T), unknown_rem - 1); - __memcpy((char *)output_idx_start + - unknown_repeat * unknown_num_deal * 3 * sizeof(int), - (char *)nram_aux_b, 3 * sizeof(int), NRAM2GDRAM, 3 * sizeof(int), - auxb_offset * sizeof(int), unknown_rem - 1); - } -} - -template __mlu_global__ void MLUUnion1KernelThreeNN( - const int b, const int n, const int m, char *unknown_gdram, - char *known_gdram, char *dist2_gdram, int *idx_gdram); - -template __mlu_global__ void MLUUnion1KernelThreeNN( - const int b, const int n, const int m, char *unknown_gdram, - char *known_gdram, char *dist2_gdram, int *idx_gdram); - -void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t data_type, - const void *unknown, const void *known, void *dist2, - int *idx, const int b, const int n, const int m) { - switch (data_type) { - case CNRT_FLOAT16: { - MLUUnion1KernelThreeNN<<>>( - b, n, m, (char *)unknown, (char *)known, (char *)dist2, idx); - }; break; - case CNRT_FLOAT32: { - MLUUnion1KernelThreeNN<<>>( - b, n, m, (char *)unknown, (char *)known, (char *)dist2, idx); - }; break; - default: { - break; - } - } -} diff --git a/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp index f407e3f63..d46480269 100644 --- a/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp @@ -9,84 +9,47 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelThreeNNForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t data_type, - const void *unknown, const void *known, void *dist2, - int *idx, const int b, const int n, const int m); +#include "mlu_common_helper.h" void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown, const Tensor known, Tensor dist2, Tensor idx) { - // Check dtype. - TORCH_CHECK( - unknown.scalar_type() == at::kFloat || unknown.scalar_type() == at::kHalf, - "unknown type should be Float or Half, got ", unknown.scalar_type(), "."); - TORCH_CHECK(unknown.scalar_type() == known.scalar_type(), - "known should have the same type as unknown."); - TORCH_CHECK(unknown.scalar_type() == dist2.scalar_type(), - "dist2 should have the same type as unknown."); - TORCH_CHECK(idx.scalar_type() == at::kInt, "idx type should be Int."); + auto unknown_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + unknown, unknown.suggest_memory_format()); + auto known_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + known, known.suggest_memory_format()); + auto dist2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + dist2, dist2.suggest_memory_format()); + auto idx_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(idx, idx.suggest_memory_format()); - // Check shape. - TORCH_CHECK(unknown.dim() == 3, "unknown should be 3d tensor, got ", - unknown.dim(), "D."); - TORCH_CHECK(known.dim() == 3, "known should be 3d tensor, got ", known.dim(), - "D."); - TORCH_CHECK(unknown.size(0) == known.size(0), - "known.dim0 should be equal to unknown.dim0, got ", known.size(0), - "."); - TORCH_CHECK(unknown.size(2) == 3, "unknown dim2 should be 3, got ", - unknown.size(2), "."); - TORCH_CHECK(known.size(2) == 3, "known dim2 should be 3, got ", known.size(2), - "."); + MluOpTensorDescriptor unknown_desc, known_desc, dist2_desc, idx_desc; + unknown_desc.set(unknown_contiguous); + known_desc.set(known_contiguous); + dist2_desc.set(dist2_contiguous); + idx_desc.set(idx_contiguous); - // zero element check - TORCH_CHECK(unknown.numel() > 0, - "unknown.numel should greater than zero, got ", unknown.numel(), - "."); - if (known.numel() == 0) { - // return if known zero element - return; - } + auto handle = mluOpGetCurrentHandle(); + size_t workspace_size = 0; + mluOpGetThreeNNForwardWorkspaceSize(handle, known_desc.desc(), + &workspace_size); + auto known_workspace = + at::empty(workspace_size, known.options().dtype(at::kByte)); - // large tensor check - const size_t max_input_num = 2147483648; // 2^31, 2G num - TORCH_CHECK(unknown.numel() < max_input_num, - "unknown.numel() should be less than 2147483648, got ", - unknown.numel(), "."); - TORCH_CHECK(known.numel() < max_input_num, - "known.numel() should be less than 2147483648, got ", - known.numel(), "."); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); - - // get ptr of tensors - auto unknown_impl = torch_mlu::getMluTensorImpl(unknown); + auto unknown_impl = torch_mlu::getMluTensorImpl(unknown_contiguous); + auto known_impl = torch_mlu::getMluTensorImpl(known_contiguous); + auto dist2_impl = torch_mlu::getMluTensorImpl(dist2_contiguous); + auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous); + auto workspace_impl = torch_mlu::getMluTensorImpl(known_workspace); auto unknown_ptr = unknown_impl->cnnlMalloc(); - auto known_t = known.permute({0, 2, 1}).contiguous(); - auto known_impl = torch_mlu::getMluTensorImpl(known_t); auto known_ptr = known_impl->cnnlMalloc(); - auto dist2_impl = torch_mlu::getMluTensorImpl(dist2); auto dist2_ptr = dist2_impl->cnnlMalloc(); - auto idx_impl = torch_mlu::getMluTensorImpl(idx); auto idx_ptr = idx_impl->cnnlMalloc(); + auto workspace_ptr = workspace_impl->cnnlMalloc(); - cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1; - cnrtDim3_t k_dim; - k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - k_dim.z = 1; - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(unknown.dtype()); - - // launch kernel - CNLOG(INFO) << "Launch Kernel MLUKernelThreeNNForward<<<" << k_dim.x << ", " - << k_dim.y << ", " << k_dim.z << ">>>."; - - KernelThreeNNForward(k_dim, k_type, queue, data_type, unknown_ptr, known_ptr, - dist2_ptr, (int *)idx_ptr, b, n, m); + mluOpThreeNNForward(handle, unknown_desc.desc(), unknown_ptr, + known_desc.desc(), known_ptr, workspace_ptr, + workspace_size, dist2_desc.desc(), dist2_ptr, + idx_desc.desc(), idx_ptr); } void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,