diff --git a/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu deleted file mode 100644 index 6c765e3ea..000000000 --- a/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu +++ /dev/null @@ -1,712 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include - -#include "common_mlu_helper.hpp" - -#define ROI_OFFSET 5 -#define FOURSPLIT 4 -#define FIVESPLIT 5 -#define NINESPLIT 9 -#define THIRTEENSPLIT 13 - -__nram__ char nram_buffer[MAX_NRAM_SIZE]; - -template -static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x, - T *w1, T *w2, T *w3, T *w4, - int *x_low, int *x_high, - const int y_low, bool *is_empty) { - if (x < -1.0 || x > input_width) { - *is_empty = true; - return; - } - - if (x <= 0) x = 0; - - *x_low = int(x); - - if (*x_low >= input_width - 1) { - *x_high = *x_low = input_width - 1; - x = T(*x_low); - } else { - *x_high = *x_low + 1; - } - - T ly = y - y_low; - T lx = x - *x_low; - T hy = 1.0 - ly; - T hx = 1.0 - lx; - *w1 = hy * hx; - *w2 = hy * lx; - *w3 = ly * hx; - *w4 = ly * lx; -} - -template -__mlu_func__ void MLUUnion1DeformRoIPoolForward( - const T *input, const T *rois, const T *offset, T *output, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const T spatial_scale, - const int sampling_ratio, const T gamma) { - for (int bin_index = taskId; - bin_index < num_rois * pooled_width * pooled_height; - bin_index += taskDim) { - int out_batch = bin_index / pooled_width / pooled_height; - int out_height = bin_index / pooled_width % pooled_height; - int out_width = bin_index % pooled_width; - const T *cur_roi = rois + out_batch * ROI_OFFSET; - T *nram_rois = (T *)nram_buffer; - __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T), - GDRAM2NRAM); - const int roi_batch = nram_rois[0]; - T roi_x_min = nram_rois[1] * spatial_scale - 0.5; - T roi_y_min = nram_rois[2] * spatial_scale - 0.5; - const T roi_x_max = nram_rois[3] * spatial_scale - 0.5; - const T roi_y_max = nram_rois[4] * spatial_scale - 0.5; - const T roi_width = roi_x_max - roi_x_min; - const T roi_height = roi_y_max - roi_y_min; - const T bin_width = roi_width / static_cast(pooled_width); - const T bin_height = roi_height / static_cast(pooled_height); - const T *offset_input = input + roi_batch * height * width * channels; - int roi_bin_grid_height = - (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceilf(roi_height / pooled_height)); - int roi_bin_grid_width = - (sampling_ratio > 0) - ? 
sampling_ratio - : static_cast(ceilf(roi_width / pooled_width)); - if (offset != NULL) { - const T *offset_cur = offset + - out_batch * pooled_width * pooled_height * 2 + - out_height * pooled_width + out_width; - roi_x_min += gamma * roi_width * offset_cur[0]; - roi_y_min += - gamma * roi_height * offset_cur[pooled_width * pooled_height]; - } - int type_align = NFU_ALIGN_SIZE / sizeof(T); - int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); - int channels_nram_split = - channels_max_num_nram / NINESPLIT / type_align * type_align; - int channel_rem = channels % channels_nram_split; - int channel_loops = - channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); - for (int channel_loop_index = 0; channel_loop_index < channel_loops; - ++channel_loop_index) { - int channels_num = - channels_nram_split >= channels ? channels : channels_nram_split; - const int channel_offset = channel_loop_index * channels_num; - if (channel_loop_index + 1 == channel_loops && channel_rem != 0) { - channels_num = channel_rem; - } - int channels_align = CEIL_ALIGN(channels_num, type_align); - int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1; - int c_slice = nram_limit / FOURSPLIT / type_align * type_align; - int c_slice_align = 0; - - /* NRAM partition - * - * | | ping | pong | - * |----------|-------------------|-------------------| - * | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 | - * - */ - - T *nram_out = (T *)nram_buffer; - T *nram_ping = nram_out + channels_align; - T *nram_pong = nram_ping + nram_limit; - __bang_write_value((T *)nram_out, channels_align, (T)0); - __bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0); - __bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0); - const T num_bins = - static_cast(max(roi_bin_grid_height * roi_bin_grid_width, 1)); - const T value_div = 1.0f / num_bins; - bool is_ping_empty = true; - for (int iy = 0; iy < roi_bin_grid_height; ++iy) { - T y = roi_y_min + out_height * bin_height + - static_cast(iy + .5f) * bin_height / - static_cast(roi_bin_grid_height); - if (y < -1.0 || y > height) { - is_ping_empty = true; - continue; - } - if (y <= 0) { - y = 0; - } - int y_low = 0, y_high = 0; - y_low = int(y); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = T(y_low); - } else { - y_high = y_low + 1; - } - for (int ix = 0; ix < roi_bin_grid_width; ++ix) { - T x = roi_x_min + out_width * bin_width + - static_cast(ix + .5f) * bin_width / - static_cast(roi_bin_grid_width); - const int sample_index = iy * roi_bin_grid_width + ix; - int c_rem = channels_num; - c_slice = nram_limit / FOURSPLIT / type_align * type_align; - c_slice_align = 0; - bool is_empty = false; - T w1, w2, w3, w4; - int x_low = 0, x_high = 0; - bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high, - y_low, &is_empty); - if (is_empty) { - is_ping_empty = true; - continue; - } - if (is_ping_empty) { - c_slice = c_slice > c_rem ? 
c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - __bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0); - __asm__ volatile("sync;"); - __memcpy(nram_ping, - offset_input + y_low * width * channels + - x_low * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping + c_slice_align, - offset_input + y_low * width * channels + - x_high * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping + 2 * c_slice_align, - offset_input + y_high * width * channels + - x_low * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping + 3 * c_slice_align, - offset_input + y_high * width * channels + - x_high * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - is_ping_empty = false; - } - int c_offset = 0; - int pongc_slice = 0; - int pongc_slice_align = 0; - while (c_rem > 0) { - c_slice = c_slice > c_rem ? c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) { - int iy_tmp = (sample_index + 1) / roi_bin_grid_width; - int ix_tmp = (sample_index + 1) % roi_bin_grid_width; - y = roi_y_min + out_height * bin_height + - static_cast(iy_tmp + .5f) * bin_height / - static_cast(roi_bin_grid_height); - x = roi_x_min + out_width * bin_width + - static_cast(ix_tmp + .5f) * bin_width / - static_cast(roi_bin_grid_width); - if (y < -1.0 || y > height) { - is_empty = true; - } else { - T w1_tmp, w2_tmp, w3_tmp, w4_tmp; - if (y <= 0) { - y = 0; - } - y_low = int(y); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = T(y_low); - } else { - y_high = y_low + 1; - } - bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp, - &w4_tmp, &x_low, &x_high, y_low, &is_empty); - } - pongc_slice = nram_limit / FOURSPLIT / type_align * type_align; - pongc_slice = - pongc_slice > channels_num ? 
channels_num : pongc_slice; - pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align); - __bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align, - (T)0); - __asm__ volatile("sync;"); - if (!is_empty) { - __memcpy_async(nram_pong, - offset_input + y_low * width * channels + - x_low * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong + pongc_slice_align, - offset_input + y_low * width * channels + - x_high * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong + 2 * pongc_slice_align, - offset_input + y_high * width * channels + - x_low * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong + 3 * pongc_slice_align, - offset_input + y_high * width * channels + - x_high * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - } - } - __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align); - __bang_mul_scalar(nram_ping + c_slice_align, - nram_ping + c_slice_align, w2, c_slice_align); - __bang_add(nram_ping, nram_ping, nram_ping + c_slice_align, - c_slice_align); - __bang_mul_scalar(nram_ping + 2 * c_slice_align, - nram_ping + 2 * c_slice_align, w3, c_slice_align); - __bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align, - c_slice_align); - __bang_mul_scalar(nram_ping + 3 * c_slice_align, - nram_ping + 3 * c_slice_align, w4, c_slice_align); - __bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align, - c_slice_align); - __bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping, - c_slice_align); - T *nram_tmp = nram_ping; - nram_ping = nram_pong; - nram_pong = nram_tmp; - c_rem -= c_slice; - c_offset += c_slice; - __asm__ volatile("sync;"); - } - } - } - __bang_mul_scalar(nram_out, nram_out, value_div, channels_align); - __memcpy(output + channels * bin_index + channel_offset, nram_out, - channels_num * sizeof(T), NRAM2GDRAM); - } - } -} - -__mlu_global__ void MLUKernelDeformRoIPoolForward( - cnrtDataType_t data_type, const void *input, const void *rois, - const void *offset, void *output, const int channels, const int height, - const int width, const int num_rois, const int pooled_height, - const int pooled_width, const float spatial_scale, const int sampling_ratio, - const float gamma) { - switch (data_type) { - case CNRT_FLOAT16: { - MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset, - (half *)output, channels, height, width, - num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), - sampling_ratio, static_cast(gamma)); - }; break; - case CNRT_FLOAT32: { - MLUUnion1DeformRoIPoolForward( - (float *)input, (float *)rois, (float *)offset, (float *)output, - channels, height, width, num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), sampling_ratio, - static_cast(gamma)); - }; break; - default: { - break; - } - } -} - -void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t data_type, - const void *input, const void *rois, - const void *offset, void *output, - const int channels, const int height, - const int width, const int num_rois, - const int pooled_height, const int pooled_width, - const float spatial_scale, - const int sampling_ratio, const float gamma) { - MLUKernelDeformRoIPoolForward<<>>( - data_type, input, rois, offset, output, channels, height, width, num_rois, - pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma); -} - -template -__mlu_func__ void MLUUnion1DeformRoIPoolBackward( - const T 
*grad_output, const T *input, const T *rois, const T *offset, - T *grad_input, T *grad_offset, const int channels, const int height, - const int width, const int num_rois, const int pooled_height, - const int pooled_width, const T spatial_scale, const int sampling_ratio, - const T gamma) { - for (int bin_index = taskId; - bin_index < num_rois * pooled_width * pooled_height; - bin_index += taskDim) { - int out_batch = bin_index / pooled_width / pooled_height; - int out_height = bin_index / pooled_width % pooled_height; - int out_width = bin_index % pooled_width; - const T *cur_roi = rois + out_batch * ROI_OFFSET; - T *nram_rois = (T *)nram_buffer; - __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T), - GDRAM2NRAM); - const int roi_batch = nram_rois[0]; - T roi_x_min = nram_rois[1] * spatial_scale - 0.5; - T roi_y_min = nram_rois[2] * spatial_scale - 0.5; - const T roi_x_max = nram_rois[3] * spatial_scale - 0.5; - const T roi_y_max = nram_rois[4] * spatial_scale - 0.5; - const T roi_width = roi_x_max - roi_x_min; - const T roi_height = roi_y_max - roi_y_min; - const T bin_width = roi_width / static_cast(pooled_width); - const T bin_height = roi_height / static_cast(pooled_height); - const T *offset_input = input + roi_batch * height * width * channels; - T *offset_grad_input = grad_input + roi_batch * height * width * channels; - int roi_bin_grid_height = - (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceilf(roi_height / pooled_height)); - int roi_bin_grid_width = - (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceilf(roi_width / pooled_width)); - if (offset != NULL) { - const T *offset_cur = offset + - out_batch * pooled_width * pooled_height * 2 + - out_height * pooled_width + out_width; - roi_x_min += gamma * roi_width * offset_cur[0]; - roi_y_min += - gamma * roi_height * offset_cur[pooled_width * pooled_height]; - } - - /* NRAM partition - * - * If offset != NULL, NRAM partition belows. - * | | - * ping | pong | - * |---------------------------------------------------------------------|-----------|-----------| - * |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4| - * - * If offset == NULL, ping and pang will not be needed. - * | | - * |----------------------------------------------------------------------------------| - * | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output | - * - */ - - int type_align = NFU_ALIGN_SIZE / sizeof(T); - int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); - int channels_nram_split = - channels_max_num_nram / FIVESPLIT / type_align * type_align; - int channel_rem = channels % channels_nram_split; - int channel_loops = - channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); - if (offset != NULL) { - channels_nram_split = - channels_max_num_nram / THIRTEENSPLIT / type_align * type_align; - channel_rem = channels % channels_nram_split; - channel_loops = - channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); - } - - for (int channel_loop_index = 0; channel_loop_index < channel_loops; - ++channel_loop_index) { - int channels_num = - channels_nram_split >= channels ? 
channels : channels_nram_split; - const int channel_offset = channel_loop_index * channels_num; - if (channel_loop_index + 1 == channel_loops && channel_rem != 0) { - channels_num = channel_rem; - } - int channels_align = CEIL_ALIGN(channels_num, type_align); - const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T); - int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align - - nram_sum_tmp_channel) >> - 1; - int c_slice = 0; - int c_slice_align = 0; - T *nram_tmp1 = (T *)nram_buffer; - T *nram_tmp2 = (T *)nram_buffer + channels_align; - T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align; - T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align; - T *nram_grad_output = nram_tmp4 + channels_align; - T *nram_sum_tmp = NULL; - T *nram_ping_input = NULL; - T *nram_pong_input = NULL; - __bang_write_value((T *)nram_grad_output, channels_align, (T)0); - __asm__ volatile("sync;"); - - if (offset != NULL) { - c_slice = nram_limit / FOURSPLIT / type_align * type_align; - nram_sum_tmp = nram_grad_output + channels_align; - nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel; - nram_pong_input = nram_ping_input + FOURSPLIT * c_slice; - __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0); - __bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0); - __bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0); - __asm__ volatile("sync;"); - } - const T num_bins = - static_cast(max(roi_bin_grid_height * roi_bin_grid_width, 1)); - const T value_div = 1.0f / num_bins; - bool is_ping_empty = true; - __memcpy(nram_grad_output, - grad_output + channels * bin_index + channel_offset, - channels_num * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(nram_grad_output, nram_grad_output, value_div, - channels_align); - for (int iy = 0; iy < roi_bin_grid_height; ++iy) { - T y = roi_y_min + out_height * bin_height + - static_cast(iy + .5f) * bin_height / - static_cast(roi_bin_grid_height); - T y_tmp = y; - if (y_tmp < -1.0 || y_tmp > height) { - is_ping_empty = true; - continue; - } - if (y_tmp <= 0) { - y_tmp = 0; - } - int y_low = 0, y_high = 0; - y_low = int(y_tmp); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y_tmp = T(y_low); - } else { - y_high = y_low + 1; - } - for (int ix = 0; ix < roi_bin_grid_width; ++ix) { - T x = roi_x_min + out_width * bin_width + - static_cast(ix + .5f) * bin_width / - static_cast(roi_bin_grid_width); - const int sample_index = iy * roi_bin_grid_width + ix; - int c_rem = channels_num; - bool is_empty = false; - T w1, w2, w3, w4; - int x_low = 0, x_high = 0; - bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low, - &x_high, y_low, &is_empty); - if (is_empty) { - is_ping_empty = true; - continue; - } - __bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1, - channels_align); - __bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2, - channels_align); - __bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3, - channels_align); - __bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4, - channels_align); - __asm__ volatile("sync;"); - __bang_atomic_add( - (T *)nram_tmp1, - (T *)(offset_grad_input + (y_low * width + x_low) * channels + - channel_offset), - (T *)nram_tmp1, channels_num); - __bang_atomic_add( - (T *)nram_tmp2, - (T *)(offset_grad_input + (y_low * width + x_high) * channels + - channel_offset), - (T *)nram_tmp2, channels_num); - __bang_atomic_add( - (T *)nram_tmp3, - (T *)(offset_grad_input + (y_high * width + x_low) * channels + - channel_offset), - (T *)nram_tmp3, channels_num); - 
__bang_atomic_add( - (T *)nram_tmp4, - (T *)(offset_grad_input + (y_high * width + x_high) * channels + - channel_offset), - (T *)nram_tmp4, channels_num); - if (offset != NULL) { - c_slice = nram_limit / FOURSPLIT / type_align * type_align; - c_slice_align = 0; - if (is_ping_empty) { - c_slice = c_slice > c_rem ? c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - __bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align, - (T)0); - __asm__ volatile("sync;"); - const T *src_offset1 = offset_input + y_low * width * channels + - x_low * channels + channel_offset; - const T *src_offset2 = offset_input + y_low * width * channels + - x_high * channels + channel_offset; - const T *src_offset3 = offset_input + y_high * width * channels + - x_low * channels + channel_offset; - const T *src_offset4 = offset_input + y_high * width * channels + - x_high * channels + channel_offset; - __memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T), - GDRAM2NRAM); - __memcpy(nram_ping_input + c_slice_align, src_offset2, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping_input + 2 * c_slice_align, src_offset3, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping_input + 3 * c_slice_align, src_offset4, - c_slice * sizeof(T), GDRAM2NRAM); - is_ping_empty = false; - } - int c_offset = 0; - int pongc_slice = 0; - int pongc_slice_align = 0; - while (c_rem > 0) { - c_slice = c_slice > c_rem ? c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) { - int iy_tmp = (sample_index + 1) / roi_bin_grid_width; - int ix_tmp = (sample_index + 1) % roi_bin_grid_width; - T y_tmp = roi_y_min + out_height * bin_height + - static_cast(iy_tmp + .5f) * bin_height / - static_cast(roi_bin_grid_height); - T x_tmp = roi_x_min + out_width * bin_width + - static_cast(ix_tmp + .5f) * bin_width / - static_cast(roi_bin_grid_width); - int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0, - y_high_tmp = 0; - if (y_tmp < -1.0 || y_tmp > height) { - is_empty = true; - } else { - T w1_tmp, w2_tmp, w3_tmp, w4_tmp; - if (y_tmp <= 0) { - y_tmp = 0; - } - y_low_tmp = int(y_tmp); - if (y_low_tmp >= height - 1) { - y_high_tmp = y_low_tmp = height - 1; - y_tmp = T(y_low_tmp); - } else { - y_high_tmp = y_low_tmp + 1; - } - bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp, - &w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp, - y_low_tmp, &is_empty); - } - pongc_slice = nram_limit / FOURSPLIT / type_align * type_align; - pongc_slice = - pongc_slice > channels_num ? 
channels_num : pongc_slice; - pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align); - __bang_write_value(nram_pong_input, - FOURSPLIT * pongc_slice_align, (T)0); - __asm__ volatile("sync;"); - if (!is_empty) { - const T *src_offset1 = offset_input + - y_low_tmp * width * channels + - x_low_tmp * channels + channel_offset; - const T *src_offset2 = offset_input + - y_low_tmp * width * channels + - x_high_tmp * channels + channel_offset; - const T *src_offset3 = offset_input + - y_high_tmp * width * channels + - x_low_tmp * channels + channel_offset; - const T *src_offset4 = offset_input + - y_high_tmp * width * channels + - x_high_tmp * channels + channel_offset; - __memcpy_async(nram_pong_input, src_offset1, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong_input + pongc_slice_align, - src_offset2, pongc_slice * sizeof(T), - GDRAM2NRAM); - __memcpy_async(nram_pong_input + 2 * pongc_slice_align, - src_offset3, pongc_slice * sizeof(T), - GDRAM2NRAM); - __memcpy_async(nram_pong_input + 3 * pongc_slice_align, - src_offset4, pongc_slice * sizeof(T), - GDRAM2NRAM); - } - } - - __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align, - y - y_low, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align, - y_high - y, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align, - y_low - y, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high, - c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width, - c_slice_align); - __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align); - const int32_t kernel_width = - c_slice_align / nram_sum_tmp_channel + - (int32_t)(c_slice_align % nram_sum_tmp_channel > 0); - __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1, - kernel_width, 1, kernel_width, kernel_width, 1); - __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp, - nram_sum_tmp_channel); - __bang_atomic_add( - (T *)nram_sum_tmp, - (T *)(grad_offset + - out_batch * pooled_width * pooled_height * 2 + - out_height * pooled_width + out_width), - (T *)nram_sum_tmp, 1); - __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0); - __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align, - x - x_low, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align, - x_high - x, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align, - x_low - x, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high, - c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height, - c_slice_align); - __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align); - __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1, - kernel_width, 1, kernel_width, kernel_width, 1); - __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp, - NFU_ALIGN_SIZE / sizeof(T)); - __bang_atomic_add( - (T *)nram_sum_tmp, - (T *)(grad_offset + - out_batch * pooled_width * pooled_height * 2 + - pooled_width * pooled_height + - out_height * pooled_width + out_width), - (T *)nram_sum_tmp, 1); - - T *nram_tmp = nram_ping_input; - nram_ping_input = nram_pong_input; 
- nram_pong_input = nram_tmp; - c_rem -= c_slice; - c_offset += c_slice; - __asm__ volatile("sync;"); - } - } - } - } - } - } -} - -__mlu_global__ void MLUKernelDeformRoIPoolBackward( - cnrtDataType_t data_type, const void *grad_output, const void *input, - const void *rois, const void *offset, void *grad_input, void *grad_offset, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const float spatial_scale, - const int sampling_ratio, const float gamma) { - switch (data_type) { - case CNRT_FLOAT16: { - MLUUnion1DeformRoIPoolBackward( - (half *)grad_output, (half *)input, (half *)rois, (half *)offset, - (half *)grad_input, (half *)grad_offset, channels, height, width, - num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), sampling_ratio, - static_cast(gamma)); - }; break; - case CNRT_FLOAT32: { - MLUUnion1DeformRoIPoolBackward( - (float *)grad_output, (float *)input, (float *)rois, (float *)offset, - (float *)grad_input, (float *)grad_offset, channels, height, width, - num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), sampling_ratio, - static_cast(gamma)); - }; break; - default: { - break; - } - } -} - -void KernelDeformRoIPoolBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - cnrtDataType_t data_type, const void *grad_output, const void *input, - const void *rois, const void *offset, void *grad_input, void *grad_offset, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const float spatial_scale, - const int sampling_ratio, const float gamma) { - MLUKernelDeformRoIPoolBackward<<>>( - data_type, grad_output, input, rois, offset, grad_input, grad_offset, - channels, height, width, num_rois, pooled_height, pooled_width, - spatial_scale, sampling_ratio, gamma); -} diff --git a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp index 4d73cbbe5..90a625c4a 100644 --- a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp @@ -9,254 +9,59 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t data_type, - const void *input, const void *rois, - const void *offset, void *output, - const int channels, const int height, - const int width, const int num_rois, - const int pooled_height, const int pooled_width, - const float spatial_scale, - const int sampling_ratio, const float gamma); - -void KernelDeformRoIPoolBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - cnrtDataType_t data_type, const void *grad_output, const void *input, - const void *rois, const void *offset, void *grad_input, void *grad_offset, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const float spatial_scale, - const int sampling_ratio, const float gamma); - -// policy function for forward and backward -static void policyFunc(const int bin_num, cnrtDim3_t *k_dim, - cnrtFunctionType_t *k_type) { - const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - ; - const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit); - k_dim->x = core_limit; - k_dim->y = (bin_num_align / core_limit) > cluster_limit - ? cluster_limit - : (bin_num_align / core_limit); - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; -} +#include "mlu_common_helper.h" void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor offset, Tensor output, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, float gamma) { - // Check dtype. - TORCH_CHECK( - input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, - "input type should be Float or Half, got ", input.scalar_type()); - TORCH_CHECK(input.scalar_type() == rois.scalar_type(), - "rois should have the same type as input"); - - // Check shape. - TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(), - "D."); - TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(), - "D."); - if (offset.defined() && offset.numel() > 0) { - TORCH_CHECK(input.scalar_type() == offset.scalar_type(), - "offset should have the same type as input"); - TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ", - offset.dim(), "D."); - TORCH_CHECK( - (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0), - "while rois.size(0)) = ", rois.size(0), ". They should be the same."); - TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ", - "but now offset.size(1) = ", offset.size(1), "."); - TORCH_CHECK((offset.size(2) == output.size(2)), - "offset.size(2) = ", offset.size(2), - "while output.size(2)) = ", output.size(2), - ". They should be the same."); - TORCH_CHECK((offset.size(3) == output.size(3)), - "offset.size(3) = ", offset.size(3), - "while output.size(3)) = ", output.size(3), - ". 
They should be the same."); - } - - TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1, - "spatial_scale should be within (0, 1], got ", spatial_scale, - "."); - - // compute kernel params - auto height = input.size(2); - auto width = input.size(3); - auto channels = input.size(1); - auto num_rois = output.size(0); - - if (output.numel() == 0) { - output = at::zeros({num_rois, channels, pooled_height, pooled_width}, - input.options()); - return; - } - - // zero element check - TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ", - input.size(0)); - TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", - rois.numel()); - if (input.numel() == 0 || output.numel() == 0) { - return; - } - - // large tensor check - const size_t max_input_num = 2147483648; // 2^31, 2G num - TORCH_CHECK(input.numel() < max_input_num, - "input.numel() should be less than 2147483648, got ", - input.numel()); - TORCH_CHECK(rois.numel() < max_input_num, - "rois.numel() should be less than 2147483648, got ", - rois.numel()); - TORCH_CHECK(output.numel() < max_input_num, - "output.numel() should be less than 2147483648, got ", - output.numel()); - TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num, - "offset.numel() should be less than 2147483648, got ", - offset.numel()); - auto memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto output_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format); - at::Tensor output_ = - at::empty({num_rois, channels, pooled_height, pooled_width}, - input.options(), memory_format); + MluOpTensorDescriptor input_desc, rois_desc, offset_desc, output_desc; + input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC); + rois_desc.set(rois_contiguous); + output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC); - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); + mluOpTensorDescriptor_t offset_real_desc = NULL; + void *offset_ptr = NULL; + if (offset.defined() && offset.numel() > 0) { + auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + offset, offset.suggest_memory_format()); + offset_desc.set(offset_contiguous); + offset_real_desc = offset_desc.desc(); + auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous); + offset_ptr = offset_impl->cnnlMalloc(); + } // get ptr of tensors auto input_impl = torch_mlu::getMluTensorImpl(input_); auto input_ptr = input_impl->cnnlMalloc(); - auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); auto rois_ptr = rois_impl->cnnlMalloc(); - auto offset_impl = torch_mlu::getMluTensorImpl(offset); - auto offset_ptr = offset_impl->cnnlMalloc(); - auto output_impl = torch_mlu::getMluTensorImpl(output_); + auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous); auto output_ptr = output_impl->cnnlMalloc(); - // get comput dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype()); + // get compute handle + auto handle = mluOpGetCurrentHandle(); + mluOpDeformRoiPoolForward( + handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr, + offset_real_desc, offset_ptr, pooled_height, 
pooled_width, spatial_scale, + sampling_ratio, gamma, output_desc.desc(), output_ptr); - // launch kernel - CNLOG(INFO) << "Launch Kernel MLUKernelDeformRoIPoolForward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - - KernelDeformRoIPoolForward(k_dim, k_type, queue, data_type, input_ptr, - rois_ptr, offset_ptr, output_ptr, channels, height, - width, num_rois, pooled_height, pooled_width, - spatial_scale, sampling_ratio, gamma); - - output.copy_(output_); + output.copy_(output_contiguous); } void DeformRoIPoolBackwardMLUKernelLauncher( Tensor grad_output, Tensor input, Tensor rois, Tensor offset, Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, float gamma) { - // Check dtype. - TORCH_CHECK( - input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, - "input type should be Float or Half, got ", input.scalar_type()); - TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(), - "grad_output should have the same type as input"); - TORCH_CHECK(input.scalar_type() == rois.scalar_type(), - "rois should have the same type as input"); - TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(), - "grad_input should have the same type as input"); - - // Check shape. - TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ", - grad_output.dim(), "D."); - TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(), - "D."); - TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(), - "D."); - if (offset.defined() && offset.numel() > 0) { - TORCH_CHECK(input.scalar_type() == offset.scalar_type(), - "offset should have the same type as input"); - TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ", - offset.dim(), "D."); - TORCH_CHECK( - (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0), - "while rois.size(0)) = ", rois.size(0), ". They should be the same."); - TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ", - "but now offset.size(1) = ", offset.size(1), "."); - TORCH_CHECK((offset.size(2) == grad_output.size(2)), - "offset.size(2) = ", offset.size(2), - "while grad_output.size(2)) = ", grad_output.size(2), - ". They should be the same."); - TORCH_CHECK((offset.size(3) == grad_output.size(3)), - "offset.size(3) = ", offset.size(3), - "while grad_output.size(3)) = ", grad_output.size(3), - ". They should be the same."); - } - - TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1, - "spatial_scale should be within (0, 1], got ", spatial_scale); - - // Check relationship between tensor. - TORCH_CHECK((grad_output.size(0) == rois.size(0)), - "grad_output.size(0) = ", grad_output.size(0), - "while rois.size(0)) = ", rois.size(0), - ". They should be the same."); - TORCH_CHECK((grad_output.size(1) == input.size(1)), - "grad_output.size(1) = ", grad_output.size(1), - "while input.size(1)) = ", input.size(1), - ". They should be the same."); - TORCH_CHECK((grad_output.size(2) == pooled_height), - "grad_output.size(2) = ", grad_output.size(2), - "while pooled_height = ", pooled_height, - ". They should be the same."); - TORCH_CHECK((grad_output.size(3) == pooled_width), - "grad_output.size(3) = ", grad_output.size(3), - "while pooled_width = ", pooled_width, - ". 
They should be the same."); - - // compute kernel params - auto batch = input.size(0); - auto channels = input.size(1); - auto height = input.size(2); - auto width = input.size(3); - auto num_rois = grad_output.size(0); - - // zero element check - TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ", - input.size(0)); - TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", - rois.numel()); - if (input.numel() == 0 || grad_output.numel() == 0) { - return; - } - - // large tensor check - const size_t max_input_num = 2147483648; // 2^31, 2G num - TORCH_CHECK(input.numel() < max_input_num, - "input.numel() should be less than 2147483648, got ", - input.numel()); - TORCH_CHECK(rois.numel() < max_input_num, - "rois.numel() should be less than 2147483648, got ", - rois.numel()); - TORCH_CHECK(grad_output.numel() < max_input_num, - "grad_output.numel() should be less than 2147483648, got ", - grad_output.numel()); - TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num, - "offset.numel() should be less than 2147483648, got ", - offset.numel()); - auto memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim()); auto grad_output_ = @@ -264,45 +69,56 @@ void DeformRoIPoolBackwardMLUKernelLauncher( memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); - at::Tensor grad_input_ = at::empty({batch, channels, height, width}, - input.options(), memory_format) - .zero_(); - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto grad_input_ = + torch_mlu::cnnl::ops::cnnl_contiguous(grad_input, memory_format); // get ptr of tensors auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_); auto grad_output_ptr = grad_output_impl->cnnlMalloc(); auto input_impl = torch_mlu::getMluTensorImpl(input_); auto input_ptr = input_impl->cnnlMalloc(); - auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); auto rois_ptr = rois_impl->cnnlMalloc(); - auto offset_impl = torch_mlu::getMluTensorImpl(offset); - auto offset_ptr = offset_impl->cnnlMalloc(); auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_); auto grad_input_ptr = grad_input_impl->cnnlMalloc(); - auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset); - auto grad_offset_ptr = grad_offset_impl->cnnlMalloc(); - // get comput dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype()); - - // launch kernel - CNLOG(INFO) << "Launch Kernel KernelDeformRoIPoolBackward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - - KernelDeformRoIPoolBackward(k_dim, k_type, queue, data_type, grad_output_ptr, - input_ptr, rois_ptr, offset_ptr, grad_input_ptr, - grad_offset_ptr, channels, height, width, - num_rois, pooled_height, pooled_width, - spatial_scale, sampling_ratio, gamma); + MluOpTensorDescriptor grad_output_desc, input_desc, rois_desc, offset_desc, + grad_input_desc, grad_offset_desc; + grad_output_desc.set_with_layout(grad_output_, MLUOP_LAYOUT_NHWC); + input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC); + rois_desc.set(rois_contiguous); + grad_input_desc.set_with_layout(grad_input_, 
MLUOP_LAYOUT_NHWC); + mluOpTensorDescriptor_t offset_real_desc = NULL; + void *offset_ptr = NULL; + if (offset.defined() && offset.numel() > 0) { + auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + offset, offset.suggest_memory_format()); + offset_desc.set(offset_contiguous); + offset_real_desc = offset_desc.desc(); + auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous); + offset_ptr = offset_impl->cnnlMalloc(); + } + mluOpTensorDescriptor_t grad_offset_real_desc = NULL; + void *grad_offset_ptr = NULL; + if (grad_offset.defined() && grad_offset.numel() > 0) { + auto grad_offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_offset, grad_offset.suggest_memory_format()); + grad_offset_desc.set(grad_offset_contiguous); + grad_offset_real_desc = grad_offset_desc.desc(); + auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset_contiguous); + grad_offset_ptr = grad_offset_impl->cnnlMalloc(); + } + // get compute handle + auto handle = mluOpGetCurrentHandle(); + mluOpDeformRoiPoolBackward( + handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(), + input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr, + pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma, + grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc, + grad_offset_ptr); grad_input.copy_(grad_input_); } diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp index 442c55dbc..3a76b4971 100644 --- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp @@ -72,6 +72,39 @@ void MluOpTensorDescriptor::set(Tensor t) { set_desc(t, layout, data_type, dim_array); } +void MluOpTensorDescriptor::set_with_layout(Tensor t, + mluOpTensorLayout_t layout) { + mluOpDataType_t data_type = getMluOpDataType(t.dtype()); + int t_dim = t.dim(); + std::vector shape_info = checkUpperBoundAndCastTo(t.sizes().vec()); + std::vector stride_info = + checkUpperBoundAndCastTo(t.strides().vec()); + if (layout == MLUOP_LAYOUT_NHWC || layout == MLUOP_LAYOUT_NDHWC || + layout == MLUOP_LAYOUT_NLC) { + convertShapeAndStride(shape_info, stride_info); + } else if (layout == MLUOP_LAYOUT_HWCN) { + auto convertDepthWiseConvShapeStride = [](const std::vector& vec, + std::vector& target_vec, + std::vector& stride_vec) { + // NCHW --> HWCN + target_vec[0] = static_cast(vec[2]); + target_vec[1] = static_cast(vec[3]); + target_vec[2] = static_cast(vec[1]); + target_vec[3] = static_cast(vec[0]); + // Calculate Stride just like contiguous of HWCN. 
+ stride_vec[3] = 1; + stride_vec[2] = target_vec[3] * stride_vec[3]; + stride_vec[1] = target_vec[2] * stride_vec[2]; + stride_vec[0] = target_vec[1] * stride_vec[1]; + }; + convertDepthWiseConvShapeStride(t.sizes().vec(), shape_info, stride_info); + } + TORCH_CHECK(mluOpSetTensorDescriptorEx( + desc_, layout, data_type, t_dim, shape_info.data(), + stride_info.data()) == MLUOP_STATUS_SUCCESS, + "mluOpSetTensorDescriptorEx execution failed."); +} + void MluOpTensorDescriptor::set_desc(const at::Tensor& t, mluOpTensorLayout_t layout, mluOpDataType_t dtype, diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h index 38805c0de..436f055f0 100644 --- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h @@ -30,6 +30,7 @@ class MluOpTensorDescriptor { ~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); } void set(at::Tensor); + void set_with_layout(at::Tensor, mluOpTensorLayout_t layout); mluOpTensorDescriptor_t desc() { return desc_; } private: @@ -52,3 +53,47 @@ class MluOpHandle { void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); } mluOpHandle_t handle; }; + +// modify tensor size and stride order based on +// channels_first to channels_last or channels_last_3d. +// which this is not same with pytorch original layout, +// this real layout is based on data storage real order. +// example: modify channels_last tensor dim to nhwc tensor desc. +// N C H W --> N H W C +// C*H*W 1 W C --> C*H*W W C 1 +template +void convertShapeAndStride(std::vector& shape_info, + std::vector& stride_info) { + TORCH_MLU_CHECK(shape_info.size() == stride_info.size(), + "shape size need equal to stride size."); + const int dim = shape_info.size(); + std::vector temp_shape_info(dim); + std::vector temp_stride_info(dim); + temp_shape_info[0] = shape_info[0]; + temp_stride_info[0] = stride_info[0]; + for (size_t i = 0; i < dim - 1; ++i) { + const int index = (i + 1) % (dim - 1) + 1; + temp_shape_info[i + 1] = shape_info[index]; + temp_stride_info[i + 1] = stride_info[index]; + } + shape_info.assign(temp_shape_info.begin(), temp_shape_info.end()); + stride_info.assign(temp_stride_info.begin(), temp_stride_info.end()); +} + +// torch tensor provides int64_t type of shape and stride, +// but mluops descriptor requires type int32. +// use this function to ensure safe CAST, or report an error. +template +std::vector checkUpperBoundAndCastTo(const std::vector& input) { + std::vector output; + output.reserve(input.size()); + for (const auto& val : input) { + if (val > std::numeric_limits::max()) { + TORCH_MLU_CHECK(false, "Requires dim size not greater than ", + std::numeric_limits::max(), ". But got ", val, + "."); + } + output.push_back(static_cast(val)); + } + return output; +}
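
The new `set_with_layout()` helper calls `convertShapeAndStride()` to re-order a channels-last tensor's sizes and strides from PyTorch's logical NCHW order into storage (NHWC) order before building the mluOps descriptor. The following is a minimal standalone sketch of that rotation, using plain `std::vector<int>` and made-up sizes for illustration; the real helper in `mlu_common_helper.h` is templated and also checks that the size and stride vectors have equal length.

```cpp
#include <cstdio>
#include <vector>

// Same rotation as the convertShapeAndStride() helper added in
// mlu_common_helper.h: keep dim 0, rotate the remaining dims left by one,
// so logical NCHW sizes/strides come out in NHWC (storage) order.
void convert_shape_and_stride(std::vector<int>& shape, std::vector<int>& stride) {
  const int dim = static_cast<int>(shape.size());
  std::vector<int> s(dim), st(dim);
  s[0] = shape[0];
  st[0] = stride[0];
  for (int i = 0; i < dim - 1; ++i) {
    const int index = (i + 1) % (dim - 1) + 1;
    s[i + 1] = shape[index];
    st[i + 1] = stride[index];
  }
  shape = s;
  stride = st;
}

int main() {
  // A channels_last tensor with logical sizes N=2, C=3, H=4, W=5:
  // PyTorch reports NCHW sizes with channels_last strides (H*W*C, 1, W*C, C).
  std::vector<int> shape = {2, 3, 4, 5};
  std::vector<int> stride = {4 * 5 * 3, 1, 5 * 3, 3};
  convert_shape_and_stride(shape, stride);
  // Result: sizes {2, 4, 5, 3}, strides {60, 15, 3, 1} -- a contiguous NHWC
  // description, which is what a MLUOP_LAYOUT_NHWC descriptor expects.
  for (int v : shape) std::printf("%d ", v);
  std::printf("| ");
  for (int v : stride) std::printf("%d ", v);
  std::printf("\n");
}
```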
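
For reference, the deleted hand-written kernel computed standard bilinear-interpolation weights per sample point before blending the four neighbouring feature vectors on NRAM. Below is a host-side C++ sketch of that weight computation; the struct and function names are chosen here for illustration, and the clamping and the `is_empty` flag mirror what the removed `bilinearInterpolate()` device function did for samples outside the feature map.

```cpp
#include <cstdio>

// Weights for blending the four neighbours of a fractional sample (y, x).
struct BilinearWeights {
  float w1, w2, w3, w4;            // top-left, top-right, bottom-left, bottom-right
  int x_low, x_high, y_low, y_high;
  bool is_empty;                   // sample falls outside [-1, width] x [-1, height]
};

BilinearWeights bilinear_weights(int height, int width, float y, float x) {
  BilinearWeights r{};
  if (y < -1.0f || y > height || x < -1.0f || x > width) {
    r.is_empty = true;
    return r;
  }
  y = y <= 0.f ? 0.f : y;
  x = x <= 0.f ? 0.f : x;
  r.y_low = static_cast<int>(y);
  r.x_low = static_cast<int>(x);
  if (r.y_low >= height - 1) { r.y_high = r.y_low = height - 1; y = float(r.y_low); }
  else                       { r.y_high = r.y_low + 1; }
  if (r.x_low >= width - 1)  { r.x_high = r.x_low = width - 1; x = float(r.x_low); }
  else                       { r.x_high = r.x_low + 1; }
  const float ly = y - r.y_low, lx = x - r.x_low;
  const float hy = 1.f - ly,    hx = 1.f - lx;
  r.w1 = hy * hx; r.w2 = hy * lx; r.w3 = ly * hx; r.w4 = ly * lx;
  return r;
}

int main() {
  BilinearWeights w = bilinear_weights(8, 8, 2.25f, 5.75f);
  // The four weights always sum to 1 for an in-range sample.
  std::printf("%.4f %.4f %.4f %.4f\n", w.w1, w.w2, w.w3, w.w4);
}
```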
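
The launch policy used by the removed host wrappers (now superseded by the mluOps handle) can likewise be reproduced standalone. The sketch below restates the deleted `policyFunc()` with plain integers standing in for `cnrtDim3_t` and for the `cnrtAttrMcorePerCluster` / `cnrtAttrClusterCount` device attributes it queried: one output bin per core, with the cluster dimension capped at the device's cluster count.

```cpp
#include <cstdio>

struct KernelDim { int x, y, z; };

// Mirrors the deleted policyFunc() used for both forward and backward launches.
KernelDim deform_roi_pool_policy(int bin_num, int core_limit, int cluster_limit) {
  // CEIL_ALIGN(bin_num, core_limit)
  const int bin_num_align = (bin_num + core_limit - 1) / core_limit * core_limit;
  KernelDim k_dim;
  k_dim.x = core_limit;
  k_dim.y = (bin_num_align / core_limit) > cluster_limit
                ? cluster_limit
                : (bin_num_align / core_limit);
  k_dim.z = 1;  // launched as a UNION1 task
  return k_dim;
}

int main() {
  // e.g. 100 RoIs with 7x7 bins on a device with 4 cores per cluster, 8 clusters.
  KernelDim d = deform_roi_pool_policy(100 * 7 * 7, 4, 8);
  std::printf("task dim = %d x %d x %d\n", d.x, d.y, d.z);  // 4 x 8 x 1
}
```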