[Enhancement] Replace the implementation of deform_roi_pool with mlu-ops (#2598)

* [Feature] Replace the implementation of deform_roi_pool with mlu-ops

* [Feature] Modify code

---------

Co-authored-by: budefei <budefei@cambricon.com>
pull/2697/head
bdf 2023-03-24 10:27:45 +08:00 committed by GitHub
parent 0d1b224fb1
commit 9fcf48c4c8
4 changed files with 145 additions and 963 deletions


@@ -1,712 +0,0 @@
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <iostream>
#include "common_mlu_helper.hpp"
#define ROI_OFFSET 5
#define FOURSPLIT 4
#define FIVESPLIT 5
#define NINESPLIT 9
#define THIRTEENSPLIT 13
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x,
T *w1, T *w2, T *w3, T *w4,
int *x_low, int *x_high,
const int y_low, bool *is_empty) {
if (x < -1.0 || x > input_width) {
*is_empty = true;
return;
}
if (x <= 0) x = 0;
*x_low = int(x);
if (*x_low >= input_width - 1) {
*x_high = *x_low = input_width - 1;
x = T(*x_low);
} else {
*x_high = *x_low + 1;
}
T ly = y - y_low;
T lx = x - *x_low;
T hy = 1.0 - ly;
T hx = 1.0 - lx;
*w1 = hy * hx;
*w2 = hy * lx;
*w3 = ly * hx;
*w4 = ly * lx;
}
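// Illustrative check (not part of the kernel): with input_width = 8,
// y = 2.3, x = 4.6 and y_low = 2 passed in by the caller, the routine
// yields x_low = 4, x_high = 5, ly = 0.3, lx = 0.6, hence
// w1 = 0.7 * 0.4 = 0.28, w2 = 0.7 * 0.6 = 0.42,
// w3 = 0.3 * 0.4 = 0.12, w4 = 0.3 * 0.6 = 0.18; the four weights sum to 1.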
template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolForward(
const T *input, const T *rois, const T *offset, T *output,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int sampling_ratio, const T gamma) {
for (int bin_index = taskId;
bin_index < num_rois * pooled_width * pooled_height;
bin_index += taskDim) {
int out_batch = bin_index / pooled_width / pooled_height;
int out_height = bin_index / pooled_width % pooled_height;
int out_width = bin_index % pooled_width;
const T *cur_roi = rois + out_batch * ROI_OFFSET;
T *nram_rois = (T *)nram_buffer;
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
GDRAM2NRAM);
const int roi_batch = nram_rois[0];
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
const T roi_width = roi_x_max - roi_x_min;
const T roi_height = roi_y_max - roi_y_min;
const T bin_width = roi_width / static_cast<T>(pooled_width);
const T bin_height = roi_height / static_cast<T>(pooled_height);
const T *offset_input = input + roi_batch * height * width * channels;
int roi_bin_grid_height =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_width =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (offset != NULL) {
const T *offset_cur = offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width;
roi_x_min += gamma * roi_width * offset_cur[0];
roi_y_min +=
gamma * roi_height * offset_cur[pooled_width * pooled_height];
}
int type_align = NFU_ALIGN_SIZE / sizeof(T);
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
int channels_nram_split =
channels_max_num_nram / NINESPLIT / type_align * type_align;
int channel_rem = channels % channels_nram_split;
int channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
++channel_loop_index) {
int channels_num =
channels_nram_split >= channels ? channels : channels_nram_split;
const int channel_offset = channel_loop_index * channels_num;
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
channels_num = channel_rem;
}
int channels_align = CEIL_ALIGN(channels_num, type_align);
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1;
int c_slice = nram_limit / FOURSPLIT / type_align * type_align;
int c_slice_align = 0;
/* NRAM partition
*
* | | ping | pong |
* |----------|-------------------|-------------------|
* | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
*
*/
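// Ping/pong double buffering: the four neighbor channel slices of the
// current sample are weighted and accumulated from the ping half while the
// next sample is prefetched into the pong half via __memcpy_async, after
// which the two pointers are swapped.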
T *nram_out = (T *)nram_buffer;
T *nram_ping = nram_out + channels_align;
T *nram_pong = nram_ping + nram_limit;
__bang_write_value((T *)nram_out, channels_align, (T)0);
__bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0);
__bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0);
const T num_bins =
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
const T value_div = 1.0f / num_bins;
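// value_div carries the 1/num_bins factor: each output bin is the average
// of up to roi_bin_grid_height * roi_bin_grid_width bilinear samples.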
bool is_ping_empty = true;
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
T y = roi_y_min + out_height * bin_height +
static_cast<T>(iy + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
if (y < -1.0 || y > height) {
is_ping_empty = true;
continue;
}
if (y <= 0) {
y = 0;
}
int y_low = 0, y_high = 0;
y_low = int(y);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = T(y_low);
} else {
y_high = y_low + 1;
}
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
T x = roi_x_min + out_width * bin_width +
static_cast<T>(ix + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
const int sample_index = iy * roi_bin_grid_width + ix;
int c_rem = channels_num;
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
c_slice_align = 0;
bool is_empty = false;
T w1, w2, w3, w4;
int x_low = 0, x_high = 0;
bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high,
y_low, &is_empty);
if (is_empty) {
is_ping_empty = true;
continue;
}
if (is_ping_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
__bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0);
__asm__ volatile("sync;");
__memcpy(nram_ping,
offset_input + y_low * width * channels +
x_low * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + c_slice_align,
offset_input + y_low * width * channels +
x_high * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + 2 * c_slice_align,
offset_input + y_high * width * channels +
x_low * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping + 3 * c_slice_align,
offset_input + y_high * width * channels +
x_high * channels + channel_offset,
c_slice * sizeof(T), GDRAM2NRAM);
is_ping_empty = false;
}
int c_offset = 0;
int pongc_slice = 0;
int pongc_slice_align = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
y = roi_y_min + out_height * bin_height +
static_cast<T>(iy_tmp + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
x = roi_x_min + out_width * bin_width +
static_cast<T>(ix_tmp + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
if (y < -1.0 || y > height) {
is_empty = true;
} else {
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
if (y <= 0) {
y = 0;
}
y_low = int(y);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = T(y_low);
} else {
y_high = y_low + 1;
}
bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp,
&w4_tmp, &x_low, &x_high, y_low, &is_empty);
}
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
pongc_slice =
pongc_slice > channels_num ? channels_num : pongc_slice;
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
__bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align,
(T)0);
__asm__ volatile("sync;");
if (!is_empty) {
__memcpy_async(nram_pong,
offset_input + y_low * width * channels +
x_low * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + pongc_slice_align,
offset_input + y_low * width * channels +
x_high * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + 2 * pongc_slice_align,
offset_input + y_high * width * channels +
x_low * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong + 3 * pongc_slice_align,
offset_input + y_high * width * channels +
x_high * channels + channel_offset,
pongc_slice * sizeof(T), GDRAM2NRAM);
}
}
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
__bang_mul_scalar(nram_ping + c_slice_align,
nram_ping + c_slice_align, w2, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + c_slice_align,
c_slice_align);
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
nram_ping + 2 * c_slice_align, w3, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align,
c_slice_align);
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
nram_ping + 3 * c_slice_align, w4, c_slice_align);
__bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align,
c_slice_align);
__bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping,
c_slice_align);
T *nram_tmp = nram_ping;
nram_ping = nram_pong;
nram_pong = nram_tmp;
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
}
}
__bang_mul_scalar(nram_out, nram_out, value_div, channels_align);
__memcpy(output + channels * bin_index + channel_offset, nram_out,
channels_num * sizeof(T), NRAM2GDRAM);
}
}
}
__mlu_global__ void MLUKernelDeformRoIPoolForward(
cnrtDataType_t data_type, const void *input, const void *rois,
const void *offset, void *output, const int channels, const int height,
const int width, const int num_rois, const int pooled_height,
const int pooled_width, const float spatial_scale, const int sampling_ratio,
const float gamma) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset,
(half *)output, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<half>(spatial_scale),
sampling_ratio, static_cast<half>(gamma));
}; break;
case CNRT_FLOAT32: {
MLUUnion1DeformRoIPoolForward(
(float *)input, (float *)rois, (float *)offset, (float *)output,
channels, height, width, num_rois, pooled_height, pooled_width,
static_cast<float>(spatial_scale), sampling_ratio,
static_cast<float>(gamma));
}; break;
default: {
break;
}
}
}
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input, const void *rois,
const void *offset, void *output,
const int channels, const int height,
const int width, const int num_rois,
const int pooled_height, const int pooled_width,
const float spatial_scale,
const int sampling_ratio, const float gamma) {
MLUKernelDeformRoIPoolForward<<<k_dim, k_type, queue>>>(
data_type, input, rois, offset, output, channels, height, width, num_rois,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma);
}
template <typename T>
__mlu_func__ void MLUUnion1DeformRoIPoolBackward(
const T *grad_output, const T *input, const T *rois, const T *offset,
T *grad_input, T *grad_offset, const int channels, const int height,
const int width, const int num_rois, const int pooled_height,
const int pooled_width, const T spatial_scale, const int sampling_ratio,
const T gamma) {
for (int bin_index = taskId;
bin_index < num_rois * pooled_width * pooled_height;
bin_index += taskDim) {
int out_batch = bin_index / pooled_width / pooled_height;
int out_height = bin_index / pooled_width % pooled_height;
int out_width = bin_index % pooled_width;
const T *cur_roi = rois + out_batch * ROI_OFFSET;
T *nram_rois = (T *)nram_buffer;
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
GDRAM2NRAM);
const int roi_batch = nram_rois[0];
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
const T roi_width = roi_x_max - roi_x_min;
const T roi_height = roi_y_max - roi_y_min;
const T bin_width = roi_width / static_cast<T>(pooled_width);
const T bin_height = roi_height / static_cast<T>(pooled_height);
const T *offset_input = input + roi_batch * height * width * channels;
T *offset_grad_input = grad_input + roi_batch * height * width * channels;
int roi_bin_grid_height =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_height / pooled_height));
int roi_bin_grid_width =
(sampling_ratio > 0)
? sampling_ratio
: static_cast<int>(ceilf(roi_width / pooled_width));
if (offset != NULL) {
const T *offset_cur = offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width;
roi_x_min += gamma * roi_width * offset_cur[0];
roi_y_min +=
gamma * roi_height * offset_cur[pooled_width * pooled_height];
}
/* NRAM partition
*
* If offset != NULL, the NRAM partition is as below.
* |                                                                     |   ping    |   pong    |
* |---------------------------------------------------------------------|-----------|-----------|
* |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4|
*
* If offset == NULL, ping and pong are not needed.
* |----------------------------------------------------------------------------------|
* | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output |
*
*/
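// nram_tmp1..nram_tmp4 hold the upstream gradient scaled by the four
// bilinear weights and are atomically accumulated into grad_input; the
// ping/pong buffers are used only when an offset tensor exists, caching the
// four neighbor slices of the input so the gradient w.r.t. the offsets can
// be reduced with sumpool and atomically added to grad_offset.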
int type_align = NFU_ALIGN_SIZE / sizeof(T);
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
int channels_nram_split =
channels_max_num_nram / FIVESPLIT / type_align * type_align;
int channel_rem = channels % channels_nram_split;
int channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
if (offset != NULL) {
channels_nram_split =
channels_max_num_nram / THIRTEENSPLIT / type_align * type_align;
channel_rem = channels % channels_nram_split;
channel_loops =
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
}
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
++channel_loop_index) {
int channels_num =
channels_nram_split >= channels ? channels : channels_nram_split;
const int channel_offset = channel_loop_index * channels_num;
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
channels_num = channel_rem;
}
int channels_align = CEIL_ALIGN(channels_num, type_align);
const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T);
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align -
nram_sum_tmp_channel) >>
1;
int c_slice = 0;
int c_slice_align = 0;
T *nram_tmp1 = (T *)nram_buffer;
T *nram_tmp2 = (T *)nram_buffer + channels_align;
T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align;
T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align;
T *nram_grad_output = nram_tmp4 + channels_align;
T *nram_sum_tmp = NULL;
T *nram_ping_input = NULL;
T *nram_pong_input = NULL;
__bang_write_value((T *)nram_grad_output, channels_align, (T)0);
__asm__ volatile("sync;");
if (offset != NULL) {
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
nram_sum_tmp = nram_grad_output + channels_align;
nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel;
nram_pong_input = nram_ping_input + FOURSPLIT * c_slice;
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
__bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0);
__bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0);
__asm__ volatile("sync;");
}
const T num_bins =
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
const T value_div = 1.0f / num_bins;
bool is_ping_empty = true;
__memcpy(nram_grad_output,
grad_output + channels * bin_index + channel_offset,
channels_num * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(nram_grad_output, nram_grad_output, value_div,
channels_align);
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
T y = roi_y_min + out_height * bin_height +
static_cast<T>(iy + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
T y_tmp = y;
if (y_tmp < -1.0 || y_tmp > height) {
is_ping_empty = true;
continue;
}
if (y_tmp <= 0) {
y_tmp = 0;
}
int y_low = 0, y_high = 0;
y_low = int(y_tmp);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y_tmp = T(y_low);
} else {
y_high = y_low + 1;
}
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
T x = roi_x_min + out_width * bin_width +
static_cast<T>(ix + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
const int sample_index = iy * roi_bin_grid_width + ix;
int c_rem = channels_num;
bool is_empty = false;
T w1, w2, w3, w4;
int x_low = 0, x_high = 0;
bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low,
&x_high, y_low, &is_empty);
if (is_empty) {
is_ping_empty = true;
continue;
}
__bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1,
channels_align);
__bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2,
channels_align);
__bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3,
channels_align);
__bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4,
channels_align);
__asm__ volatile("sync;");
__bang_atomic_add(
(T *)nram_tmp1,
(T *)(offset_grad_input + (y_low * width + x_low) * channels +
channel_offset),
(T *)nram_tmp1, channels_num);
__bang_atomic_add(
(T *)nram_tmp2,
(T *)(offset_grad_input + (y_low * width + x_high) * channels +
channel_offset),
(T *)nram_tmp2, channels_num);
__bang_atomic_add(
(T *)nram_tmp3,
(T *)(offset_grad_input + (y_high * width + x_low) * channels +
channel_offset),
(T *)nram_tmp3, channels_num);
__bang_atomic_add(
(T *)nram_tmp4,
(T *)(offset_grad_input + (y_high * width + x_high) * channels +
channel_offset),
(T *)nram_tmp4, channels_num);
if (offset != NULL) {
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
c_slice_align = 0;
if (is_ping_empty) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
__bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align,
(T)0);
__asm__ volatile("sync;");
const T *src_offset1 = offset_input + y_low * width * channels +
x_low * channels + channel_offset;
const T *src_offset2 = offset_input + y_low * width * channels +
x_high * channels + channel_offset;
const T *src_offset3 = offset_input + y_high * width * channels +
x_low * channels + channel_offset;
const T *src_offset4 = offset_input + y_high * width * channels +
x_high * channels + channel_offset;
__memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T),
GDRAM2NRAM);
__memcpy(nram_ping_input + c_slice_align, src_offset2,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping_input + 2 * c_slice_align, src_offset3,
c_slice * sizeof(T), GDRAM2NRAM);
__memcpy(nram_ping_input + 3 * c_slice_align, src_offset4,
c_slice * sizeof(T), GDRAM2NRAM);
is_ping_empty = false;
}
int c_offset = 0;
int pongc_slice = 0;
int pongc_slice_align = 0;
while (c_rem > 0) {
c_slice = c_slice > c_rem ? c_rem : c_slice;
c_slice_align = CEIL_ALIGN(c_slice, type_align);
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
T y_tmp = roi_y_min + out_height * bin_height +
static_cast<T>(iy_tmp + .5f) * bin_height /
static_cast<T>(roi_bin_grid_height);
T x_tmp = roi_x_min + out_width * bin_width +
static_cast<T>(ix_tmp + .5f) * bin_width /
static_cast<T>(roi_bin_grid_width);
int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0,
y_high_tmp = 0;
if (y_tmp < -1.0 || y_tmp > height) {
is_empty = true;
} else {
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
if (y_tmp <= 0) {
y_tmp = 0;
}
y_low_tmp = int(y_tmp);
if (y_low_tmp >= height - 1) {
y_high_tmp = y_low_tmp = height - 1;
y_tmp = T(y_low_tmp);
} else {
y_high_tmp = y_low_tmp + 1;
}
bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp,
&w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp,
y_low_tmp, &is_empty);
}
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
pongc_slice =
pongc_slice > channels_num ? channels_num : pongc_slice;
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
__bang_write_value(nram_pong_input,
FOURSPLIT * pongc_slice_align, (T)0);
__asm__ volatile("sync;");
if (!is_empty) {
const T *src_offset1 = offset_input +
y_low_tmp * width * channels +
x_low_tmp * channels + channel_offset;
const T *src_offset2 = offset_input +
y_low_tmp * width * channels +
x_high_tmp * channels + channel_offset;
const T *src_offset3 = offset_input +
y_high_tmp * width * channels +
x_low_tmp * channels + channel_offset;
const T *src_offset4 = offset_input +
y_high_tmp * width * channels +
x_high_tmp * channels + channel_offset;
__memcpy_async(nram_pong_input, src_offset1,
pongc_slice * sizeof(T), GDRAM2NRAM);
__memcpy_async(nram_pong_input + pongc_slice_align,
src_offset2, pongc_slice * sizeof(T),
GDRAM2NRAM);
__memcpy_async(nram_pong_input + 2 * pongc_slice_align,
src_offset3, pongc_slice * sizeof(T),
GDRAM2NRAM);
__memcpy_async(nram_pong_input + 3 * pongc_slice_align,
src_offset4, pongc_slice * sizeof(T),
GDRAM2NRAM);
}
}
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
y - y_low, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
y_high - y, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
y_low - y, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high,
c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width,
c_slice_align);
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
const int32_t kernel_width =
c_slice_align / nram_sum_tmp_channel +
(int32_t)(c_slice_align % nram_sum_tmp_channel > 0);
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
kernel_width, 1, kernel_width, kernel_width, 1);
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
nram_sum_tmp_channel);
__bang_atomic_add(
(T *)nram_sum_tmp,
(T *)(grad_offset +
out_batch * pooled_width * pooled_height * 2 +
out_height * pooled_width + out_width),
(T *)nram_sum_tmp, 1);
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
x - x_low, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
x_high - x, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
x_low - x, c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high,
c_slice_align);
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height,
c_slice_align);
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
kernel_width, 1, kernel_width, kernel_width, 1);
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
NFU_ALIGN_SIZE / sizeof(T));
__bang_atomic_add(
(T *)nram_sum_tmp,
(T *)(grad_offset +
out_batch * pooled_width * pooled_height * 2 +
pooled_width * pooled_height +
out_height * pooled_width + out_width),
(T *)nram_sum_tmp, 1);
T *nram_tmp = nram_ping_input;
nram_ping_input = nram_pong_input;
nram_pong_input = nram_tmp;
c_rem -= c_slice;
c_offset += c_slice;
__asm__ volatile("sync;");
}
}
}
}
}
}
}
__mlu_global__ void MLUKernelDeformRoIPoolBackward(
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma) {
switch (data_type) {
case CNRT_FLOAT16: {
MLUUnion1DeformRoIPoolBackward(
(half *)grad_output, (half *)input, (half *)rois, (half *)offset,
(half *)grad_input, (half *)grad_offset, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<half>(spatial_scale), sampling_ratio,
static_cast<half>(gamma));
}; break;
case CNRT_FLOAT32: {
MLUUnion1DeformRoIPoolBackward(
(float *)grad_output, (float *)input, (float *)rois, (float *)offset,
(float *)grad_input, (float *)grad_offset, channels, height, width,
num_rois, pooled_height, pooled_width,
static_cast<float>(spatial_scale), sampling_ratio,
static_cast<float>(gamma));
}; break;
default: {
break;
}
}
}
void KernelDeformRoIPoolBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma) {
MLUKernelDeformRoIPoolBackward<<<k_dim, k_type, queue>>>(
data_type, grad_output, input, rois, offset, grad_input, grad_offset,
channels, height, width, num_rois, pooled_height, pooled_width,
spatial_scale, sampling_ratio, gamma);
}


@@ -9,254 +9,59 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
cnrtQueue_t queue, cnrtDataType_t data_type,
const void *input, const void *rois,
const void *offset, void *output,
const int channels, const int height,
const int width, const int num_rois,
const int pooled_height, const int pooled_width,
const float spatial_scale,
const int sampling_ratio, const float gamma);
void KernelDeformRoIPoolBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
cnrtDataType_t data_type, const void *grad_output, const void *input,
const void *rois, const void *offset, void *grad_input, void *grad_offset,
const int channels, const int height, const int width, const int num_rois,
const int pooled_height, const int pooled_width, const float spatial_scale,
const int sampling_ratio, const float gamma);
// policy function for forward and backward
static void policyFunc(const int bin_num, cnrtDim3_t *k_dim,
cnrtFunctionType_t *k_type) {
const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit);
k_dim->x = core_limit;
k_dim->y = (bin_num_align / core_limit) > cluster_limit
? cluster_limit
: (bin_num_align / core_limit);
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_UNION1;
}
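// Illustrative example with assumed device attributes (4 cores per cluster,
// 8 clusters): for bin_num = 100, bin_num_align = CEIL_ALIGN(100, 4) = 100,
// so k_dim = {4, min(100 / 4, 8), 1} = {4, 8, 1} and the kernel is launched
// as a UNION1 task.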
#include "mlu_common_helper.h"
void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois,
Tensor offset, Tensor output,
int pooled_height, int pooled_width,
float spatial_scale,
int sampling_ratio, float gamma) {
// Check dtype.
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
"input type should be Float or Half, got ", input.scalar_type());
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
"rois should have the same type as input");
// Check shape.
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
"D.");
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
"D.");
if (offset.defined() && offset.numel() > 0) {
TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
"offset should have the same type as input");
TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
offset.dim(), "D.");
TORCH_CHECK(
(offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
"while rois.size(0)) = ", rois.size(0), ". They should be the same.");
TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
"but now offset.size(1) = ", offset.size(1), ".");
TORCH_CHECK((offset.size(2) == output.size(2)),
"offset.size(2) = ", offset.size(2),
"while output.size(2)) = ", output.size(2),
". They should be the same.");
TORCH_CHECK((offset.size(3) == output.size(3)),
"offset.size(3) = ", offset.size(3),
"while output.size(3)) = ", output.size(3),
". They should be the same.");
}
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
"spatial_scale should be within (0, 1], got ", spatial_scale,
".");
// compute kernel params
auto height = input.size(2);
auto width = input.size(3);
auto channels = input.size(1);
auto num_rois = output.size(0);
if (output.numel() == 0) {
output = at::zeros({num_rois, channels, pooled_height, pooled_width},
input.options());
return;
}
// zero element check
TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
input.size(0));
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
rois.numel());
if (input.numel() == 0 || output.numel() == 0) {
return;
}
// large tensor check
const size_t max_input_num = 2147483648; // 2^31, 2G num
TORCH_CHECK(input.numel() < max_input_num,
"input.numel() should be less than 2147483648, got ",
input.numel());
TORCH_CHECK(rois.numel() < max_input_num,
"rois.numel() should be less than 2147483648, got ",
rois.numel());
TORCH_CHECK(output.numel() < max_input_num,
"output.numel() should be less than 2147483648, got ",
output.numel());
TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
"offset.numel() should be less than 2147483648, got ",
offset.numel());
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
auto rois_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
auto output_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format);
at::Tensor output_ =
at::empty({num_rois, channels, pooled_height, pooled_width},
input.options(), memory_format);
MluOpTensorDescriptor input_desc, rois_desc, offset_desc, output_desc;
input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
rois_desc.set(rois_contiguous);
output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type);
// get compute queue
auto queue = torch_mlu::getCurQueue();
mluOpTensorDescriptor_t offset_real_desc = NULL;
void *offset_ptr = NULL;
if (offset.defined() && offset.numel() > 0) {
auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
offset, offset.suggest_memory_format());
offset_desc.set(offset_contiguous);
offset_real_desc = offset_desc.desc();
auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
offset_ptr = offset_impl->cnnlMalloc();
}
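// If no offset tensor is supplied, offset_real_desc and offset_ptr stay NULL
// and are forwarded as-is to mluOpDeformRoiPoolForward (the mlu-ops op is
// assumed to accept this no-offset case).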
// get ptr of tensors
auto input_impl = torch_mlu::getMluTensorImpl(input_);
auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc();
auto offset_impl = torch_mlu::getMluTensorImpl(offset);
auto offset_ptr = offset_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output_);
auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
auto output_ptr = output_impl->cnnlMalloc();
// get compute dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype());
// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpDeformRoiPoolForward(
handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
sampling_ratio, gamma, output_desc.desc(), output_ptr);
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelDeformRoIPoolForward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelDeformRoIPoolForward(k_dim, k_type, queue, data_type, input_ptr,
rois_ptr, offset_ptr, output_ptr, channels, height,
width, num_rois, pooled_height, pooled_width,
spatial_scale, sampling_ratio, gamma);
output.copy_(output_);
output.copy_(output_contiguous);
}
void DeformRoIPoolBackwardMLUKernelLauncher(
Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
float spatial_scale, int sampling_ratio, float gamma) {
// Check dtype.
TORCH_CHECK(
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
"input type should be Float or Half, got ", input.scalar_type());
TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(),
"grad_output should have the same type as input");
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
"rois should have the same type as input");
TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(),
"grad_input should have the same type as input");
// Check shape.
TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
grad_output.dim(), "D.");
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
"D.");
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
"D.");
if (offset.defined() && offset.numel() > 0) {
TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
"offset should have the same type as input");
TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
offset.dim(), "D.");
TORCH_CHECK(
(offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
"while rois.size(0)) = ", rois.size(0), ". They should be the same.");
TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
"but now offset.size(1) = ", offset.size(1), ".");
TORCH_CHECK((offset.size(2) == grad_output.size(2)),
"offset.size(2) = ", offset.size(2),
"while grad_output.size(2)) = ", grad_output.size(2),
". They should be the same.");
TORCH_CHECK((offset.size(3) == grad_output.size(3)),
"offset.size(3) = ", offset.size(3),
"while grad_output.size(3)) = ", grad_output.size(3),
". They should be the same.");
}
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
"spatial_scale should be within (0, 1], got ", spatial_scale);
// Check relationship between tensor.
TORCH_CHECK((grad_output.size(0) == rois.size(0)),
"grad_output.size(0) = ", grad_output.size(0),
"while rois.size(0)) = ", rois.size(0),
". They should be the same.");
TORCH_CHECK((grad_output.size(1) == input.size(1)),
"grad_output.size(1) = ", grad_output.size(1),
"while input.size(1)) = ", input.size(1),
". They should be the same.");
TORCH_CHECK((grad_output.size(2) == pooled_height),
"grad_output.size(2) = ", grad_output.size(2),
"while pooled_height = ", pooled_height,
". They should be the same.");
TORCH_CHECK((grad_output.size(3) == pooled_width),
"grad_output.size(3) = ", grad_output.size(3),
"while pooled_width = ", pooled_width,
". They should be the same.");
// compute kernel params
auto batch = input.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
auto num_rois = grad_output.size(0);
// zero element check
TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
input.size(0));
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
rois.numel());
if (input.numel() == 0 || grad_output.numel() == 0) {
return;
}
// large tensor check
const size_t max_input_num = 2147483648; // 2^31, 2G num
TORCH_CHECK(input.numel() < max_input_num,
"input.numel() should be less than 2147483648, got ",
input.numel());
TORCH_CHECK(rois.numel() < max_input_num,
"rois.numel() should be less than 2147483648, got ",
rois.numel());
TORCH_CHECK(grad_output.numel() < max_input_num,
"grad_output.numel() should be less than 2147483648, got ",
grad_output.numel());
TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
"offset.numel() should be less than 2147483648, got ",
offset.numel());
auto memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
auto grad_output_ =
@@ -264,45 +69,56 @@ void DeformRoIPoolBackwardMLUKernelLauncher(
memory_format =
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
at::Tensor grad_input_ = at::empty({batch, channels, height, width},
input.options(), memory_format)
.zero_();
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type);
// get compute queue
auto queue = torch_mlu::getCurQueue();
auto rois_contiguous =
torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
auto grad_input_ =
torch_mlu::cnnl::ops::cnnl_contiguous(grad_input, memory_format);
// get ptr of tensors
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
auto input_impl = torch_mlu::getMluTensorImpl(input_);
auto input_ptr = input_impl->cnnlMalloc();
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
auto rois_ptr = rois_impl->cnnlMalloc();
auto offset_impl = torch_mlu::getMluTensorImpl(offset);
auto offset_ptr = offset_impl->cnnlMalloc();
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset);
auto grad_offset_ptr = grad_offset_impl->cnnlMalloc();
// get compute dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel KernelDeformRoIPoolBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelDeformRoIPoolBackward(k_dim, k_type, queue, data_type, grad_output_ptr,
input_ptr, rois_ptr, offset_ptr, grad_input_ptr,
grad_offset_ptr, channels, height, width,
num_rois, pooled_height, pooled_width,
spatial_scale, sampling_ratio, gamma);
MluOpTensorDescriptor grad_output_desc, input_desc, rois_desc, offset_desc,
grad_input_desc, grad_offset_desc;
grad_output_desc.set_with_layout(grad_output_, MLUOP_LAYOUT_NHWC);
input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
rois_desc.set(rois_contiguous);
grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
mluOpTensorDescriptor_t offset_real_desc = NULL;
void *offset_ptr = NULL;
if (offset.defined() && offset.numel() > 0) {
auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
offset, offset.suggest_memory_format());
offset_desc.set(offset_contiguous);
offset_real_desc = offset_desc.desc();
auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
offset_ptr = offset_impl->cnnlMalloc();
}
mluOpTensorDescriptor_t grad_offset_real_desc = NULL;
void *grad_offset_ptr = NULL;
if (grad_offset.defined() && grad_offset.numel() > 0) {
auto grad_offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
grad_offset, grad_offset.suggest_memory_format());
grad_offset_desc.set(grad_offset_contiguous);
grad_offset_real_desc = grad_offset_desc.desc();
auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset_contiguous);
grad_offset_ptr = grad_offset_impl->cnnlMalloc();
}
// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpDeformRoiPoolBackward(
handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
grad_offset_ptr);
grad_input.copy_(grad_input_);
}


@@ -72,6 +72,39 @@ void MluOpTensorDescriptor::set(Tensor t) {
set_desc(t, layout, data_type, dim_array);
}
void MluOpTensorDescriptor::set_with_layout(Tensor t,
mluOpTensorLayout_t layout) {
mluOpDataType_t data_type = getMluOpDataType(t.dtype());
int t_dim = t.dim();
std::vector<int> shape_info = checkUpperBoundAndCastTo<int>(t.sizes().vec());
std::vector<int> stride_info =
checkUpperBoundAndCastTo<int>(t.strides().vec());
if (layout == MLUOP_LAYOUT_NHWC || layout == MLUOP_LAYOUT_NDHWC ||
layout == MLUOP_LAYOUT_NLC) {
convertShapeAndStride(shape_info, stride_info);
} else if (layout == MLUOP_LAYOUT_HWCN) {
auto convertDepthWiseConvShapeStride = [](const std::vector<int64_t>& vec,
std::vector<int>& target_vec,
std::vector<int>& stride_vec) {
// NCHW --> HWCN
target_vec[0] = static_cast<int>(vec[2]);
target_vec[1] = static_cast<int>(vec[3]);
target_vec[2] = static_cast<int>(vec[1]);
target_vec[3] = static_cast<int>(vec[0]);
// Calculate strides as if the tensor were contiguous in HWCN.
stride_vec[3] = 1;
stride_vec[2] = target_vec[3] * stride_vec[3];
stride_vec[1] = target_vec[2] * stride_vec[2];
stride_vec[0] = target_vec[1] * stride_vec[1];
};
convertDepthWiseConvShapeStride(t.sizes().vec(), shape_info, stride_info);
}
TORCH_CHECK(mluOpSetTensorDescriptorEx(
desc_, layout, data_type, t_dim, shape_info.data(),
stride_info.data()) == MLUOP_STATUS_SUCCESS,
"mluOpSetTensorDescriptorEx execution failed.");
}
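// Hypothetical usage sketch, mirroring the deform_roi_pool launcher above:
// describe a channels-last tensor as an NHWC mlu-ops descriptor.
//   auto memory_format =
//       torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
//   auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
//   MluOpTensorDescriptor input_desc;
//   input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);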
void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
mluOpTensorLayout_t layout,
mluOpDataType_t dtype,


@@ -30,6 +30,7 @@ class MluOpTensorDescriptor {
~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); }
void set(at::Tensor);
void set_with_layout(at::Tensor, mluOpTensorLayout_t layout);
mluOpTensorDescriptor_t desc() { return desc_; }
private:
@@ -52,3 +53,47 @@ class MluOpHandle {
void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); }
mluOpHandle_t handle;
};
// Modify the tensor size and stride order when converting from
// channels_first to channels_last or channels_last_3d.
// This differs from the original PyTorch layout; the resulting layout
// reflects the actual data storage order.
// Example: describe a channels_last tensor as an NHWC tensor desc.
// shape:  N     C H W  -->  N     H W C
// stride: C*H*W 1 W C  -->  C*H*W W C 1
template <typename T>
void convertShapeAndStride(std::vector<T>& shape_info,
std::vector<T>& stride_info) {
TORCH_MLU_CHECK(shape_info.size() == stride_info.size(),
"shape size need equal to stride size.");
const int dim = shape_info.size();
std::vector<T> temp_shape_info(dim);
std::vector<T> temp_stride_info(dim);
temp_shape_info[0] = shape_info[0];
temp_stride_info[0] = stride_info[0];
for (size_t i = 0; i < dim - 1; ++i) {
const int index = (i + 1) % (dim - 1) + 1;
temp_shape_info[i + 1] = shape_info[index];
temp_stride_info[i + 1] = stride_info[index];
}
shape_info.assign(temp_shape_info.begin(), temp_shape_info.end());
stride_info.assign(temp_stride_info.begin(), temp_stride_info.end());
}
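// Worked example (illustration only): a 4-D NCHW tensor of shape {2, 3, 4, 5}
// stored channels-last has stride {60, 1, 15, 3}; after conversion the
// descriptor sees shape {2, 4, 5, 3} and stride {60, 15, 3, 1}, i.e. the
// contiguous NHWC view that mlu-ops expects.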
// torch tensors provide int64_t shapes and strides, but the mlu-ops
// descriptor requires int32. Use this function to cast safely, or report an
// error if a value is out of range.
template <typename DST_T, typename SRC_T>
std::vector<DST_T> checkUpperBoundAndCastTo(const std::vector<SRC_T>& input) {
std::vector<DST_T> output;
output.reserve(input.size());
for (const auto& val : input) {
if (val > std::numeric_limits<DST_T>::max()) {
TORCH_MLU_CHECK(false, "Requires dim size not greater than ",
std::numeric_limits<DST_T>::max(), ". But got ", val,
".");
}
output.push_back(static_cast<DST_T>(val));
}
return output;
}
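// Minimal usage sketch (mirrors set_with_layout above): cast a tensor's
// int64_t sizes and strides to the int32 vectors the mlu-ops descriptor
// takes, erroring out if any value exceeds INT32_MAX.
//   std::vector<int> shape_info =
//       checkUpperBoundAndCastTo<int>(t.sizes().vec());
//   std::vector<int> stride_info =
//       checkUpperBoundAndCastTo<int>(t.strides().vec());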