mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support DeformRoiPool with cambricon MLU backend (#2137)
* [Feature] Support DeformRoiPool with cambricon MLU backend * [Fix] Remove use of std library * [Fix] Correct the error information * [Refactor] Refactor test deform_roi_pool code * [Fix] Fix judgment error * [Fix] Modify the large tensor check Co-authored-by: budefei <budefei@cambricon.com>pull/2349/head
parent
1c1964cbd5
commit
a364e6cad2
|
@ -19,7 +19,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| CornerPool | | √ | | |
|
||||
| Correlation | | √ | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | |
|
||||
| Deformable RoIPool | | √ | | |
|
||||
| Deformable RoIPool | | √ | √ | |
|
||||
| DiffIoURotated | | √ | | |
|
||||
| DynamicScatter | | √ | | |
|
||||
| FurthestPointSample | | √ | | |
|
||||
|
|
|
@ -19,7 +19,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| CornerPool | | √ | | |
|
||||
| Correlation | | √ | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | |
|
||||
| Deformable RoIPool | | √ | | |
|
||||
| Deformable RoIPool | | √ | √ | |
|
||||
| DiffIoURotated | | √ | | |
|
||||
| DynamicScatter | | √ | | |
|
||||
| FurthestPointSample | | √ | | |
|
||||
|
|
|
@ -0,0 +1,712 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include <iostream>
|
||||
|
||||
#include "common_mlu_helper.hpp"
|
||||
|
||||
#define ROI_OFFSET 5
|
||||
#define FOURSPLIT 4
|
||||
#define FIVESPLIT 5
|
||||
#define NINESPLIT 9
|
||||
#define THIRTEENSPLIT 13
|
||||
|
||||
__nram__ char nram_buffer[MAX_NRAM_SIZE];
|
||||
|
||||
template <typename T>
|
||||
static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x,
|
||||
T *w1, T *w2, T *w3, T *w4,
|
||||
int *x_low, int *x_high,
|
||||
const int y_low, bool *is_empty) {
|
||||
if (x < -1.0 || x > input_width) {
|
||||
*is_empty = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (x <= 0) x = 0;
|
||||
|
||||
*x_low = int(x);
|
||||
|
||||
if (*x_low >= input_width - 1) {
|
||||
*x_high = *x_low = input_width - 1;
|
||||
x = T(*x_low);
|
||||
} else {
|
||||
*x_high = *x_low + 1;
|
||||
}
|
||||
|
||||
T ly = y - y_low;
|
||||
T lx = x - *x_low;
|
||||
T hy = 1.0 - ly;
|
||||
T hx = 1.0 - lx;
|
||||
*w1 = hy * hx;
|
||||
*w2 = hy * lx;
|
||||
*w3 = ly * hx;
|
||||
*w4 = ly * lx;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void MLUUnion1DeformRoIPoolForward(
|
||||
const T *input, const T *rois, const T *offset, T *output,
|
||||
const int channels, const int height, const int width, const int num_rois,
|
||||
const int pooled_height, const int pooled_width, const T spatial_scale,
|
||||
const int sampling_ratio, const T gamma) {
|
||||
for (int bin_index = taskId;
|
||||
bin_index < num_rois * pooled_width * pooled_height;
|
||||
bin_index += taskDim) {
|
||||
int out_batch = bin_index / pooled_width / pooled_height;
|
||||
int out_height = bin_index / pooled_width % pooled_height;
|
||||
int out_width = bin_index % pooled_width;
|
||||
const T *cur_roi = rois + out_batch * ROI_OFFSET;
|
||||
T *nram_rois = (T *)nram_buffer;
|
||||
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
const int roi_batch = nram_rois[0];
|
||||
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
|
||||
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
|
||||
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
|
||||
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
|
||||
const T roi_width = roi_x_max - roi_x_min;
|
||||
const T roi_height = roi_y_max - roi_y_min;
|
||||
const T bin_width = roi_width / static_cast<T>(pooled_width);
|
||||
const T bin_height = roi_height / static_cast<T>(pooled_height);
|
||||
const T *offset_input = input + roi_batch * height * width * channels;
|
||||
int roi_bin_grid_height =
|
||||
(sampling_ratio > 0)
|
||||
? sampling_ratio
|
||||
: static_cast<int>(ceilf(roi_height / pooled_height));
|
||||
int roi_bin_grid_width =
|
||||
(sampling_ratio > 0)
|
||||
? sampling_ratio
|
||||
: static_cast<int>(ceilf(roi_width / pooled_width));
|
||||
if (offset != NULL) {
|
||||
const T *offset_cur = offset +
|
||||
out_batch * pooled_width * pooled_height * 2 +
|
||||
out_height * pooled_width + out_width;
|
||||
roi_x_min += gamma * roi_width * offset_cur[0];
|
||||
roi_y_min +=
|
||||
gamma * roi_height * offset_cur[pooled_width * pooled_height];
|
||||
}
|
||||
int type_align = NFU_ALIGN_SIZE / sizeof(T);
|
||||
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
|
||||
int channels_nram_split =
|
||||
channels_max_num_nram / NINESPLIT / type_align * type_align;
|
||||
int channel_rem = channels % channels_nram_split;
|
||||
int channel_loops =
|
||||
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
|
||||
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
|
||||
++channel_loop_index) {
|
||||
int channels_num =
|
||||
channels_nram_split >= channels ? channels : channels_nram_split;
|
||||
const int channel_offset = channel_loop_index * channels_num;
|
||||
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
|
||||
channels_num = channel_rem;
|
||||
}
|
||||
int channels_align = CEIL_ALIGN(channels_num, type_align);
|
||||
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1;
|
||||
int c_slice = nram_limit / FOURSPLIT / type_align * type_align;
|
||||
int c_slice_align = 0;
|
||||
|
||||
/* NRAM partition
|
||||
*
|
||||
* | | ping | pong |
|
||||
* |----------|-------------------|-------------------|
|
||||
* | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 |
|
||||
*
|
||||
*/
|
||||
|
||||
T *nram_out = (T *)nram_buffer;
|
||||
T *nram_ping = nram_out + channels_align;
|
||||
T *nram_pong = nram_ping + nram_limit;
|
||||
__bang_write_value((T *)nram_out, channels_align, (T)0);
|
||||
__bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0);
|
||||
__bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0);
|
||||
const T num_bins =
|
||||
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
|
||||
const T value_div = 1.0f / num_bins;
|
||||
bool is_ping_empty = true;
|
||||
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
|
||||
T y = roi_y_min + out_height * bin_height +
|
||||
static_cast<T>(iy + .5f) * bin_height /
|
||||
static_cast<T>(roi_bin_grid_height);
|
||||
if (y < -1.0 || y > height) {
|
||||
is_ping_empty = true;
|
||||
continue;
|
||||
}
|
||||
if (y <= 0) {
|
||||
y = 0;
|
||||
}
|
||||
int y_low = 0, y_high = 0;
|
||||
y_low = int(y);
|
||||
if (y_low >= height - 1) {
|
||||
y_high = y_low = height - 1;
|
||||
y = T(y_low);
|
||||
} else {
|
||||
y_high = y_low + 1;
|
||||
}
|
||||
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
|
||||
T x = roi_x_min + out_width * bin_width +
|
||||
static_cast<T>(ix + .5f) * bin_width /
|
||||
static_cast<T>(roi_bin_grid_width);
|
||||
const int sample_index = iy * roi_bin_grid_width + ix;
|
||||
int c_rem = channels_num;
|
||||
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
|
||||
c_slice_align = 0;
|
||||
bool is_empty = false;
|
||||
T w1, w2, w3, w4;
|
||||
int x_low = 0, x_high = 0;
|
||||
bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high,
|
||||
y_low, &is_empty);
|
||||
if (is_empty) {
|
||||
is_ping_empty = true;
|
||||
continue;
|
||||
}
|
||||
if (is_ping_empty) {
|
||||
c_slice = c_slice > c_rem ? c_rem : c_slice;
|
||||
c_slice_align = CEIL_ALIGN(c_slice, type_align);
|
||||
__bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0);
|
||||
__asm__ volatile("sync;");
|
||||
__memcpy(nram_ping,
|
||||
offset_input + y_low * width * channels +
|
||||
x_low * channels + channel_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy(nram_ping + c_slice_align,
|
||||
offset_input + y_low * width * channels +
|
||||
x_high * channels + channel_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy(nram_ping + 2 * c_slice_align,
|
||||
offset_input + y_high * width * channels +
|
||||
x_low * channels + channel_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy(nram_ping + 3 * c_slice_align,
|
||||
offset_input + y_high * width * channels +
|
||||
x_high * channels + channel_offset,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
is_ping_empty = false;
|
||||
}
|
||||
int c_offset = 0;
|
||||
int pongc_slice = 0;
|
||||
int pongc_slice_align = 0;
|
||||
while (c_rem > 0) {
|
||||
c_slice = c_slice > c_rem ? c_rem : c_slice;
|
||||
c_slice_align = CEIL_ALIGN(c_slice, type_align);
|
||||
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
|
||||
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
|
||||
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
|
||||
y = roi_y_min + out_height * bin_height +
|
||||
static_cast<T>(iy_tmp + .5f) * bin_height /
|
||||
static_cast<T>(roi_bin_grid_height);
|
||||
x = roi_x_min + out_width * bin_width +
|
||||
static_cast<T>(ix_tmp + .5f) * bin_width /
|
||||
static_cast<T>(roi_bin_grid_width);
|
||||
if (y < -1.0 || y > height) {
|
||||
is_empty = true;
|
||||
} else {
|
||||
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
|
||||
if (y <= 0) {
|
||||
y = 0;
|
||||
}
|
||||
y_low = int(y);
|
||||
if (y_low >= height - 1) {
|
||||
y_high = y_low = height - 1;
|
||||
y = T(y_low);
|
||||
} else {
|
||||
y_high = y_low + 1;
|
||||
}
|
||||
bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp,
|
||||
&w4_tmp, &x_low, &x_high, y_low, &is_empty);
|
||||
}
|
||||
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
|
||||
pongc_slice =
|
||||
pongc_slice > channels_num ? channels_num : pongc_slice;
|
||||
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
|
||||
__bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align,
|
||||
(T)0);
|
||||
__asm__ volatile("sync;");
|
||||
if (!is_empty) {
|
||||
__memcpy_async(nram_pong,
|
||||
offset_input + y_low * width * channels +
|
||||
x_low * channels + channel_offset,
|
||||
pongc_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy_async(nram_pong + pongc_slice_align,
|
||||
offset_input + y_low * width * channels +
|
||||
x_high * channels + channel_offset,
|
||||
pongc_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy_async(nram_pong + 2 * pongc_slice_align,
|
||||
offset_input + y_high * width * channels +
|
||||
x_low * channels + channel_offset,
|
||||
pongc_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy_async(nram_pong + 3 * pongc_slice_align,
|
||||
offset_input + y_high * width * channels +
|
||||
x_high * channels + channel_offset,
|
||||
pongc_slice * sizeof(T), GDRAM2NRAM);
|
||||
}
|
||||
}
|
||||
__bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align);
|
||||
__bang_mul_scalar(nram_ping + c_slice_align,
|
||||
nram_ping + c_slice_align, w2, c_slice_align);
|
||||
__bang_add(nram_ping, nram_ping, nram_ping + c_slice_align,
|
||||
c_slice_align);
|
||||
__bang_mul_scalar(nram_ping + 2 * c_slice_align,
|
||||
nram_ping + 2 * c_slice_align, w3, c_slice_align);
|
||||
__bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align,
|
||||
c_slice_align);
|
||||
__bang_mul_scalar(nram_ping + 3 * c_slice_align,
|
||||
nram_ping + 3 * c_slice_align, w4, c_slice_align);
|
||||
__bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align,
|
||||
c_slice_align);
|
||||
__bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping,
|
||||
c_slice_align);
|
||||
T *nram_tmp = nram_ping;
|
||||
nram_ping = nram_pong;
|
||||
nram_pong = nram_tmp;
|
||||
c_rem -= c_slice;
|
||||
c_offset += c_slice;
|
||||
__asm__ volatile("sync;");
|
||||
}
|
||||
}
|
||||
}
|
||||
__bang_mul_scalar(nram_out, nram_out, value_div, channels_align);
|
||||
__memcpy(output + channels * bin_index + channel_offset, nram_out,
|
||||
channels_num * sizeof(T), NRAM2GDRAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_global__ void MLUKernelDeformRoIPoolForward(
|
||||
cnrtDataType_t data_type, const void *input, const void *rois,
|
||||
const void *offset, void *output, const int channels, const int height,
|
||||
const int width, const int num_rois, const int pooled_height,
|
||||
const int pooled_width, const float spatial_scale, const int sampling_ratio,
|
||||
const float gamma) {
|
||||
switch (data_type) {
|
||||
case CNRT_FLOAT16: {
|
||||
MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset,
|
||||
(half *)output, channels, height, width,
|
||||
num_rois, pooled_height, pooled_width,
|
||||
static_cast<half>(spatial_scale),
|
||||
sampling_ratio, static_cast<half>(gamma));
|
||||
}; break;
|
||||
case CNRT_FLOAT32: {
|
||||
MLUUnion1DeformRoIPoolForward(
|
||||
(float *)input, (float *)rois, (float *)offset, (float *)output,
|
||||
channels, height, width, num_rois, pooled_height, pooled_width,
|
||||
static_cast<float>(spatial_scale), sampling_ratio,
|
||||
static_cast<float>(gamma));
|
||||
}; break;
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
|
||||
cnrtQueue_t queue, cnrtDataType_t data_type,
|
||||
const void *input, const void *rois,
|
||||
const void *offset, void *output,
|
||||
const int channels, const int height,
|
||||
const int width, const int num_rois,
|
||||
const int pooled_height, const int pooled_width,
|
||||
const float spatial_scale,
|
||||
const int sampling_ratio, const float gamma) {
|
||||
MLUKernelDeformRoIPoolForward<<<k_dim, k_type, queue>>>(
|
||||
data_type, input, rois, offset, output, channels, height, width, num_rois,
|
||||
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__mlu_func__ void MLUUnion1DeformRoIPoolBackward(
|
||||
const T *grad_output, const T *input, const T *rois, const T *offset,
|
||||
T *grad_input, T *grad_offset, const int channels, const int height,
|
||||
const int width, const int num_rois, const int pooled_height,
|
||||
const int pooled_width, const T spatial_scale, const int sampling_ratio,
|
||||
const T gamma) {
|
||||
for (int bin_index = taskId;
|
||||
bin_index < num_rois * pooled_width * pooled_height;
|
||||
bin_index += taskDim) {
|
||||
int out_batch = bin_index / pooled_width / pooled_height;
|
||||
int out_height = bin_index / pooled_width % pooled_height;
|
||||
int out_width = bin_index % pooled_width;
|
||||
const T *cur_roi = rois + out_batch * ROI_OFFSET;
|
||||
T *nram_rois = (T *)nram_buffer;
|
||||
__memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
const int roi_batch = nram_rois[0];
|
||||
T roi_x_min = nram_rois[1] * spatial_scale - 0.5;
|
||||
T roi_y_min = nram_rois[2] * spatial_scale - 0.5;
|
||||
const T roi_x_max = nram_rois[3] * spatial_scale - 0.5;
|
||||
const T roi_y_max = nram_rois[4] * spatial_scale - 0.5;
|
||||
const T roi_width = roi_x_max - roi_x_min;
|
||||
const T roi_height = roi_y_max - roi_y_min;
|
||||
const T bin_width = roi_width / static_cast<T>(pooled_width);
|
||||
const T bin_height = roi_height / static_cast<T>(pooled_height);
|
||||
const T *offset_input = input + roi_batch * height * width * channels;
|
||||
T *offset_grad_input = grad_input + roi_batch * height * width * channels;
|
||||
int roi_bin_grid_height =
|
||||
(sampling_ratio > 0)
|
||||
? sampling_ratio
|
||||
: static_cast<int>(ceilf(roi_height / pooled_height));
|
||||
int roi_bin_grid_width =
|
||||
(sampling_ratio > 0)
|
||||
? sampling_ratio
|
||||
: static_cast<int>(ceilf(roi_width / pooled_width));
|
||||
if (offset != NULL) {
|
||||
const T *offset_cur = offset +
|
||||
out_batch * pooled_width * pooled_height * 2 +
|
||||
out_height * pooled_width + out_width;
|
||||
roi_x_min += gamma * roi_width * offset_cur[0];
|
||||
roi_y_min +=
|
||||
gamma * roi_height * offset_cur[pooled_width * pooled_height];
|
||||
}
|
||||
|
||||
/* NRAM partition
|
||||
*
|
||||
* If offset != NULL, NRAM partition belows.
|
||||
* | |
|
||||
* ping | pong |
|
||||
* |---------------------------------------------------------------------|-----------|-----------|
|
||||
* |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4|
|
||||
*
|
||||
* If offset == NULL, ping and pang will not be needed.
|
||||
* | |
|
||||
* |----------------------------------------------------------------------------------|
|
||||
* | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output |
|
||||
*
|
||||
*/
|
||||
|
||||
int type_align = NFU_ALIGN_SIZE / sizeof(T);
|
||||
int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T);
|
||||
int channels_nram_split =
|
||||
channels_max_num_nram / FIVESPLIT / type_align * type_align;
|
||||
int channel_rem = channels % channels_nram_split;
|
||||
int channel_loops =
|
||||
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
|
||||
if (offset != NULL) {
|
||||
channels_nram_split =
|
||||
channels_max_num_nram / THIRTEENSPLIT / type_align * type_align;
|
||||
channel_rem = channels % channels_nram_split;
|
||||
channel_loops =
|
||||
channels / channels_nram_split + (channel_rem != 0 ? 1 : 0);
|
||||
}
|
||||
|
||||
for (int channel_loop_index = 0; channel_loop_index < channel_loops;
|
||||
++channel_loop_index) {
|
||||
int channels_num =
|
||||
channels_nram_split >= channels ? channels : channels_nram_split;
|
||||
const int channel_offset = channel_loop_index * channels_num;
|
||||
if (channel_loop_index + 1 == channel_loops && channel_rem != 0) {
|
||||
channels_num = channel_rem;
|
||||
}
|
||||
int channels_align = CEIL_ALIGN(channels_num, type_align);
|
||||
const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T);
|
||||
int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align -
|
||||
nram_sum_tmp_channel) >>
|
||||
1;
|
||||
int c_slice = 0;
|
||||
int c_slice_align = 0;
|
||||
T *nram_tmp1 = (T *)nram_buffer;
|
||||
T *nram_tmp2 = (T *)nram_buffer + channels_align;
|
||||
T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align;
|
||||
T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align;
|
||||
T *nram_grad_output = nram_tmp4 + channels_align;
|
||||
T *nram_sum_tmp = NULL;
|
||||
T *nram_ping_input = NULL;
|
||||
T *nram_pong_input = NULL;
|
||||
__bang_write_value((T *)nram_grad_output, channels_align, (T)0);
|
||||
__asm__ volatile("sync;");
|
||||
|
||||
if (offset != NULL) {
|
||||
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
|
||||
nram_sum_tmp = nram_grad_output + channels_align;
|
||||
nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel;
|
||||
nram_pong_input = nram_ping_input + FOURSPLIT * c_slice;
|
||||
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
|
||||
__bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0);
|
||||
__bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0);
|
||||
__asm__ volatile("sync;");
|
||||
}
|
||||
const T num_bins =
|
||||
static_cast<T>(max(roi_bin_grid_height * roi_bin_grid_width, 1));
|
||||
const T value_div = 1.0f / num_bins;
|
||||
bool is_ping_empty = true;
|
||||
__memcpy(nram_grad_output,
|
||||
grad_output + channels * bin_index + channel_offset,
|
||||
channels_num * sizeof(T), GDRAM2NRAM);
|
||||
__bang_mul_scalar(nram_grad_output, nram_grad_output, value_div,
|
||||
channels_align);
|
||||
for (int iy = 0; iy < roi_bin_grid_height; ++iy) {
|
||||
T y = roi_y_min + out_height * bin_height +
|
||||
static_cast<T>(iy + .5f) * bin_height /
|
||||
static_cast<T>(roi_bin_grid_height);
|
||||
T y_tmp = y;
|
||||
if (y_tmp < -1.0 || y_tmp > height) {
|
||||
is_ping_empty = true;
|
||||
continue;
|
||||
}
|
||||
if (y_tmp <= 0) {
|
||||
y_tmp = 0;
|
||||
}
|
||||
int y_low = 0, y_high = 0;
|
||||
y_low = int(y_tmp);
|
||||
if (y_low >= height - 1) {
|
||||
y_high = y_low = height - 1;
|
||||
y_tmp = T(y_low);
|
||||
} else {
|
||||
y_high = y_low + 1;
|
||||
}
|
||||
for (int ix = 0; ix < roi_bin_grid_width; ++ix) {
|
||||
T x = roi_x_min + out_width * bin_width +
|
||||
static_cast<T>(ix + .5f) * bin_width /
|
||||
static_cast<T>(roi_bin_grid_width);
|
||||
const int sample_index = iy * roi_bin_grid_width + ix;
|
||||
int c_rem = channels_num;
|
||||
bool is_empty = false;
|
||||
T w1, w2, w3, w4;
|
||||
int x_low = 0, x_high = 0;
|
||||
bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low,
|
||||
&x_high, y_low, &is_empty);
|
||||
if (is_empty) {
|
||||
is_ping_empty = true;
|
||||
continue;
|
||||
}
|
||||
__bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1,
|
||||
channels_align);
|
||||
__bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2,
|
||||
channels_align);
|
||||
__bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3,
|
||||
channels_align);
|
||||
__bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4,
|
||||
channels_align);
|
||||
__asm__ volatile("sync;");
|
||||
__bang_atomic_add(
|
||||
(T *)nram_tmp1,
|
||||
(T *)(offset_grad_input + (y_low * width + x_low) * channels +
|
||||
channel_offset),
|
||||
(T *)nram_tmp1, channels_num);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_tmp2,
|
||||
(T *)(offset_grad_input + (y_low * width + x_high) * channels +
|
||||
channel_offset),
|
||||
(T *)nram_tmp2, channels_num);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_tmp3,
|
||||
(T *)(offset_grad_input + (y_high * width + x_low) * channels +
|
||||
channel_offset),
|
||||
(T *)nram_tmp3, channels_num);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_tmp4,
|
||||
(T *)(offset_grad_input + (y_high * width + x_high) * channels +
|
||||
channel_offset),
|
||||
(T *)nram_tmp4, channels_num);
|
||||
if (offset != NULL) {
|
||||
c_slice = nram_limit / FOURSPLIT / type_align * type_align;
|
||||
c_slice_align = 0;
|
||||
if (is_ping_empty) {
|
||||
c_slice = c_slice > c_rem ? c_rem : c_slice;
|
||||
c_slice_align = CEIL_ALIGN(c_slice, type_align);
|
||||
__bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align,
|
||||
(T)0);
|
||||
__asm__ volatile("sync;");
|
||||
const T *src_offset1 = offset_input + y_low * width * channels +
|
||||
x_low * channels + channel_offset;
|
||||
const T *src_offset2 = offset_input + y_low * width * channels +
|
||||
x_high * channels + channel_offset;
|
||||
const T *src_offset3 = offset_input + y_high * width * channels +
|
||||
x_low * channels + channel_offset;
|
||||
const T *src_offset4 = offset_input + y_high * width * channels +
|
||||
x_high * channels + channel_offset;
|
||||
__memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
__memcpy(nram_ping_input + c_slice_align, src_offset2,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy(nram_ping_input + 2 * c_slice_align, src_offset3,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy(nram_ping_input + 3 * c_slice_align, src_offset4,
|
||||
c_slice * sizeof(T), GDRAM2NRAM);
|
||||
is_ping_empty = false;
|
||||
}
|
||||
int c_offset = 0;
|
||||
int pongc_slice = 0;
|
||||
int pongc_slice_align = 0;
|
||||
while (c_rem > 0) {
|
||||
c_slice = c_slice > c_rem ? c_rem : c_slice;
|
||||
c_slice_align = CEIL_ALIGN(c_slice, type_align);
|
||||
if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) {
|
||||
int iy_tmp = (sample_index + 1) / roi_bin_grid_width;
|
||||
int ix_tmp = (sample_index + 1) % roi_bin_grid_width;
|
||||
T y_tmp = roi_y_min + out_height * bin_height +
|
||||
static_cast<T>(iy_tmp + .5f) * bin_height /
|
||||
static_cast<T>(roi_bin_grid_height);
|
||||
T x_tmp = roi_x_min + out_width * bin_width +
|
||||
static_cast<T>(ix_tmp + .5f) * bin_width /
|
||||
static_cast<T>(roi_bin_grid_width);
|
||||
int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0,
|
||||
y_high_tmp = 0;
|
||||
if (y_tmp < -1.0 || y_tmp > height) {
|
||||
is_empty = true;
|
||||
} else {
|
||||
T w1_tmp, w2_tmp, w3_tmp, w4_tmp;
|
||||
if (y_tmp <= 0) {
|
||||
y_tmp = 0;
|
||||
}
|
||||
y_low_tmp = int(y_tmp);
|
||||
if (y_low_tmp >= height - 1) {
|
||||
y_high_tmp = y_low_tmp = height - 1;
|
||||
y_tmp = T(y_low_tmp);
|
||||
} else {
|
||||
y_high_tmp = y_low_tmp + 1;
|
||||
}
|
||||
bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp,
|
||||
&w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp,
|
||||
y_low_tmp, &is_empty);
|
||||
}
|
||||
pongc_slice = nram_limit / FOURSPLIT / type_align * type_align;
|
||||
pongc_slice =
|
||||
pongc_slice > channels_num ? channels_num : pongc_slice;
|
||||
pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align);
|
||||
__bang_write_value(nram_pong_input,
|
||||
FOURSPLIT * pongc_slice_align, (T)0);
|
||||
__asm__ volatile("sync;");
|
||||
if (!is_empty) {
|
||||
const T *src_offset1 = offset_input +
|
||||
y_low_tmp * width * channels +
|
||||
x_low_tmp * channels + channel_offset;
|
||||
const T *src_offset2 = offset_input +
|
||||
y_low_tmp * width * channels +
|
||||
x_high_tmp * channels + channel_offset;
|
||||
const T *src_offset3 = offset_input +
|
||||
y_high_tmp * width * channels +
|
||||
x_low_tmp * channels + channel_offset;
|
||||
const T *src_offset4 = offset_input +
|
||||
y_high_tmp * width * channels +
|
||||
x_high_tmp * channels + channel_offset;
|
||||
__memcpy_async(nram_pong_input, src_offset1,
|
||||
pongc_slice * sizeof(T), GDRAM2NRAM);
|
||||
__memcpy_async(nram_pong_input + pongc_slice_align,
|
||||
src_offset2, pongc_slice * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
__memcpy_async(nram_pong_input + 2 * pongc_slice_align,
|
||||
src_offset3, pongc_slice * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
__memcpy_async(nram_pong_input + 3 * pongc_slice_align,
|
||||
src_offset4, pongc_slice * sizeof(T),
|
||||
GDRAM2NRAM);
|
||||
}
|
||||
}
|
||||
|
||||
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
|
||||
y - y_low, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
|
||||
y_high - y, c_slice_align);
|
||||
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
|
||||
y_low - y, c_slice_align);
|
||||
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high,
|
||||
c_slice_align);
|
||||
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width,
|
||||
c_slice_align);
|
||||
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
|
||||
const int32_t kernel_width =
|
||||
c_slice_align / nram_sum_tmp_channel +
|
||||
(int32_t)(c_slice_align % nram_sum_tmp_channel > 0);
|
||||
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
|
||||
kernel_width, 1, kernel_width, kernel_width, 1);
|
||||
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
|
||||
nram_sum_tmp_channel);
|
||||
__bang_atomic_add(
|
||||
(T *)nram_sum_tmp,
|
||||
(T *)(grad_offset +
|
||||
out_batch * pooled_width * pooled_height * 2 +
|
||||
out_height * pooled_width + out_width),
|
||||
(T *)nram_sum_tmp, 1);
|
||||
__bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0);
|
||||
__bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align,
|
||||
x - x_low, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align,
|
||||
x_high - x, c_slice_align);
|
||||
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align,
|
||||
x_low - x, c_slice_align);
|
||||
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high,
|
||||
c_slice_align);
|
||||
__bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align);
|
||||
__bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height,
|
||||
c_slice_align);
|
||||
__bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align);
|
||||
__bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1,
|
||||
kernel_width, 1, kernel_width, kernel_width, 1);
|
||||
__bang_reduce_sum(nram_sum_tmp, nram_sum_tmp,
|
||||
NFU_ALIGN_SIZE / sizeof(T));
|
||||
__bang_atomic_add(
|
||||
(T *)nram_sum_tmp,
|
||||
(T *)(grad_offset +
|
||||
out_batch * pooled_width * pooled_height * 2 +
|
||||
pooled_width * pooled_height +
|
||||
out_height * pooled_width + out_width),
|
||||
(T *)nram_sum_tmp, 1);
|
||||
|
||||
T *nram_tmp = nram_ping_input;
|
||||
nram_ping_input = nram_pong_input;
|
||||
nram_pong_input = nram_tmp;
|
||||
c_rem -= c_slice;
|
||||
c_offset += c_slice;
|
||||
__asm__ volatile("sync;");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__mlu_global__ void MLUKernelDeformRoIPoolBackward(
|
||||
cnrtDataType_t data_type, const void *grad_output, const void *input,
|
||||
const void *rois, const void *offset, void *grad_input, void *grad_offset,
|
||||
const int channels, const int height, const int width, const int num_rois,
|
||||
const int pooled_height, const int pooled_width, const float spatial_scale,
|
||||
const int sampling_ratio, const float gamma) {
|
||||
switch (data_type) {
|
||||
case CNRT_FLOAT16: {
|
||||
MLUUnion1DeformRoIPoolBackward(
|
||||
(half *)grad_output, (half *)input, (half *)rois, (half *)offset,
|
||||
(half *)grad_input, (half *)grad_offset, channels, height, width,
|
||||
num_rois, pooled_height, pooled_width,
|
||||
static_cast<half>(spatial_scale), sampling_ratio,
|
||||
static_cast<half>(gamma));
|
||||
}; break;
|
||||
case CNRT_FLOAT32: {
|
||||
MLUUnion1DeformRoIPoolBackward(
|
||||
(float *)grad_output, (float *)input, (float *)rois, (float *)offset,
|
||||
(float *)grad_input, (float *)grad_offset, channels, height, width,
|
||||
num_rois, pooled_height, pooled_width,
|
||||
static_cast<float>(spatial_scale), sampling_ratio,
|
||||
static_cast<float>(gamma));
|
||||
}; break;
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KernelDeformRoIPoolBackward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
cnrtDataType_t data_type, const void *grad_output, const void *input,
|
||||
const void *rois, const void *offset, void *grad_input, void *grad_offset,
|
||||
const int channels, const int height, const int width, const int num_rois,
|
||||
const int pooled_height, const int pooled_width, const float spatial_scale,
|
||||
const int sampling_ratio, const float gamma) {
|
||||
MLUKernelDeformRoIPoolBackward<<<k_dim, k_type, queue>>>(
|
||||
data_type, grad_output, input, rois, offset, grad_input, grad_offset,
|
||||
channels, height, width, num_rois, pooled_height, pooled_width,
|
||||
spatial_scale, sampling_ratio, gamma);
|
||||
}
|
|
@ -0,0 +1,343 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "pytorch_device_registry.hpp"
|
||||
#include "pytorch_mlu_helper.hpp"
|
||||
|
||||
void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
|
||||
cnrtQueue_t queue, cnrtDataType_t data_type,
|
||||
const void *input, const void *rois,
|
||||
const void *offset, void *output,
|
||||
const int channels, const int height,
|
||||
const int width, const int num_rois,
|
||||
const int pooled_height, const int pooled_width,
|
||||
const float spatial_scale,
|
||||
const int sampling_ratio, const float gamma);
|
||||
|
||||
void KernelDeformRoIPoolBackward(
|
||||
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
|
||||
cnrtDataType_t data_type, const void *grad_output, const void *input,
|
||||
const void *rois, const void *offset, void *grad_input, void *grad_offset,
|
||||
const int channels, const int height, const int width, const int num_rois,
|
||||
const int pooled_height, const int pooled_width, const float spatial_scale,
|
||||
const int sampling_ratio, const float gamma);
|
||||
|
||||
// policy function for forward and backward
|
||||
static void policyFunc(const int bin_num, cnrtDim3_t *k_dim,
|
||||
cnrtFunctionType_t *k_type) {
|
||||
const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
|
||||
;
|
||||
const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
|
||||
const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit);
|
||||
k_dim->x = core_limit;
|
||||
k_dim->y = (bin_num_align / core_limit) > cluster_limit
|
||||
? cluster_limit
|
||||
: (bin_num_align / core_limit);
|
||||
k_dim->z = 1;
|
||||
*k_type = CNRT_FUNC_TYPE_UNION1;
|
||||
}
|
||||
|
||||
void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois,
|
||||
Tensor offset, Tensor output,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale,
|
||||
int sampling_ratio, float gamma) {
|
||||
// Check dtype.
|
||||
TORCH_CHECK(
|
||||
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
|
||||
"input type should be Float or Half, got ", input.scalar_type());
|
||||
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
|
||||
"rois should have the same type as input");
|
||||
|
||||
// Check shape.
|
||||
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
|
||||
"D.");
|
||||
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
|
||||
"D.");
|
||||
if (offset.defined() && offset.numel() > 0) {
|
||||
TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
|
||||
"offset should have the same type as input");
|
||||
TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
|
||||
offset.dim(), "D.");
|
||||
TORCH_CHECK(
|
||||
(offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
|
||||
"while rois.size(0)) = ", rois.size(0), ". They should be the same.");
|
||||
TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
|
||||
"but now offset.size(1) = ", offset.size(1), ".");
|
||||
TORCH_CHECK((offset.size(2) == output.size(2)),
|
||||
"offset.size(2) = ", offset.size(2),
|
||||
"while output.size(2)) = ", output.size(2),
|
||||
". They should be the same.");
|
||||
TORCH_CHECK((offset.size(3) == output.size(3)),
|
||||
"offset.size(3) = ", offset.size(3),
|
||||
"while output.size(3)) = ", output.size(3),
|
||||
". They should be the same.");
|
||||
}
|
||||
|
||||
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
|
||||
"spatial_scale should be within (0, 1], got ", spatial_scale,
|
||||
".");
|
||||
|
||||
// compute kernel params
|
||||
auto height = input.size(2);
|
||||
auto width = input.size(3);
|
||||
auto channels = input.size(1);
|
||||
auto num_rois = output.size(0);
|
||||
|
||||
if (output.numel() == 0) {
|
||||
output = at::zeros({num_rois, channels, pooled_height, pooled_width},
|
||||
input.options());
|
||||
return;
|
||||
}
|
||||
|
||||
// zero element check
|
||||
TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
|
||||
input.size(0));
|
||||
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
|
||||
rois.numel());
|
||||
if (input.numel() == 0 || output.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// large tensor check
|
||||
const size_t max_input_num = 2147483648; // 2^31, 2G num
|
||||
TORCH_CHECK(input.numel() < max_input_num,
|
||||
"input.numel() should be less than 2147483648, got ",
|
||||
input.numel());
|
||||
TORCH_CHECK(rois.numel() < max_input_num,
|
||||
"rois.numel() should be less than 2147483648, got ",
|
||||
rois.numel());
|
||||
TORCH_CHECK(output.numel() < max_input_num,
|
||||
"output.numel() should be less than 2147483648, got ",
|
||||
output.numel());
|
||||
TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
|
||||
"offset.numel() should be less than 2147483648, got ",
|
||||
offset.numel());
|
||||
|
||||
auto memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
|
||||
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
|
||||
|
||||
at::Tensor output_ =
|
||||
at::empty({num_rois, channels, pooled_height, pooled_width},
|
||||
input.options(), memory_format);
|
||||
|
||||
// calculate task dimension
|
||||
cnrtDim3_t k_dim;
|
||||
cnrtFunctionType_t k_type;
|
||||
policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
|
||||
// get ptr of tensors
|
||||
auto input_impl = torch_mlu::getMluTensorImpl(input_);
|
||||
auto input_ptr = input_impl->cnnlMalloc();
|
||||
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
|
||||
auto rois_ptr = rois_impl->cnnlMalloc();
|
||||
auto offset_impl = torch_mlu::getMluTensorImpl(offset);
|
||||
auto offset_ptr = offset_impl->cnnlMalloc();
|
||||
auto output_impl = torch_mlu::getMluTensorImpl(output_);
|
||||
auto output_ptr = output_impl->cnnlMalloc();
|
||||
|
||||
// get comput dtype of input
|
||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype());
|
||||
|
||||
// launch kernel
|
||||
CNLOG(INFO) << "Launch Kernel MLUKernelDeformRoIPoolForward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
|
||||
KernelDeformRoIPoolForward(k_dim, k_type, queue, data_type, input_ptr,
|
||||
rois_ptr, offset_ptr, output_ptr, channels, height,
|
||||
width, num_rois, pooled_height, pooled_width,
|
||||
spatial_scale, sampling_ratio, gamma);
|
||||
|
||||
output.copy_(output_);
|
||||
}
|
||||
|
||||
void DeformRoIPoolBackwardMLUKernelLauncher(
|
||||
Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
|
||||
Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
|
||||
float spatial_scale, int sampling_ratio, float gamma) {
|
||||
// Check dtype.
|
||||
TORCH_CHECK(
|
||||
input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
|
||||
"input type should be Float or Half, got ", input.scalar_type());
|
||||
TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(),
|
||||
"grad_output should have the same type as input");
|
||||
TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
|
||||
"rois should have the same type as input");
|
||||
TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(),
|
||||
"grad_input should have the same type as input");
|
||||
|
||||
// Check shape.
|
||||
TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
|
||||
grad_output.dim(), "D.");
|
||||
TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
|
||||
"D.");
|
||||
TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
|
||||
"D.");
|
||||
if (offset.defined() && offset.numel() > 0) {
|
||||
TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
|
||||
"offset should have the same type as input");
|
||||
TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
|
||||
offset.dim(), "D.");
|
||||
TORCH_CHECK(
|
||||
(offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0),
|
||||
"while rois.size(0)) = ", rois.size(0), ". They should be the same.");
|
||||
TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
|
||||
"but now offset.size(1) = ", offset.size(1), ".");
|
||||
TORCH_CHECK((offset.size(2) == grad_output.size(2)),
|
||||
"offset.size(2) = ", offset.size(2),
|
||||
"while grad_output.size(2)) = ", grad_output.size(2),
|
||||
". They should be the same.");
|
||||
TORCH_CHECK((offset.size(3) == grad_output.size(3)),
|
||||
"offset.size(3) = ", offset.size(3),
|
||||
"while grad_output.size(3)) = ", grad_output.size(3),
|
||||
". They should be the same.");
|
||||
}
|
||||
|
||||
TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
|
||||
"spatial_scale should be within (0, 1], got ", spatial_scale);
|
||||
|
||||
// Check relationship between tensor.
|
||||
TORCH_CHECK((grad_output.size(0) == rois.size(0)),
|
||||
"grad_output.size(0) = ", grad_output.size(0),
|
||||
"while rois.size(0)) = ", rois.size(0),
|
||||
". They should be the same.");
|
||||
TORCH_CHECK((grad_output.size(1) == input.size(1)),
|
||||
"grad_output.size(1) = ", grad_output.size(1),
|
||||
"while input.size(1)) = ", input.size(1),
|
||||
". They should be the same.");
|
||||
TORCH_CHECK((grad_output.size(2) == pooled_height),
|
||||
"grad_output.size(2) = ", grad_output.size(2),
|
||||
"while pooled_height = ", pooled_height,
|
||||
". They should be the same.");
|
||||
TORCH_CHECK((grad_output.size(3) == pooled_width),
|
||||
"grad_output.size(3) = ", grad_output.size(3),
|
||||
"while pooled_width = ", pooled_width,
|
||||
". They should be the same.");
|
||||
|
||||
// compute kernel params
|
||||
auto batch = input.size(0);
|
||||
auto channels = input.size(1);
|
||||
auto height = input.size(2);
|
||||
auto width = input.size(3);
|
||||
auto num_rois = grad_output.size(0);
|
||||
|
||||
// zero element check
|
||||
TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
|
||||
input.size(0));
|
||||
TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
|
||||
rois.numel());
|
||||
if (input.numel() == 0 || grad_output.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// large tensor check
|
||||
const size_t max_input_num = 2147483648; // 2^31, 2G num
|
||||
TORCH_CHECK(input.numel() < max_input_num,
|
||||
"input.numel() should be less than 2147483648, got ",
|
||||
input.numel());
|
||||
TORCH_CHECK(rois.numel() < max_input_num,
|
||||
"rois.numel() should be less than 2147483648, got ",
|
||||
rois.numel());
|
||||
TORCH_CHECK(grad_output.numel() < max_input_num,
|
||||
"grad_output.numel() should be less than 2147483648, got ",
|
||||
grad_output.numel());
|
||||
TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
|
||||
"offset.numel() should be less than 2147483648, got ",
|
||||
offset.numel());
|
||||
|
||||
auto memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
|
||||
auto grad_output_ =
|
||||
torch_mlu::cnnl::ops::cnnl_contiguous(grad_output, memory_format);
|
||||
memory_format =
|
||||
torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
|
||||
auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
|
||||
at::Tensor grad_input_ = at::empty({batch, channels, height, width},
|
||||
input.options(), memory_format)
|
||||
.zero_();
|
||||
|
||||
// calculate task dimension
|
||||
cnrtDim3_t k_dim;
|
||||
cnrtFunctionType_t k_type;
|
||||
policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type);
|
||||
|
||||
// get compute queue
|
||||
auto queue = torch_mlu::getCurQueue();
|
||||
|
||||
// get ptr of tensors
|
||||
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
|
||||
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
|
||||
auto input_impl = torch_mlu::getMluTensorImpl(input_);
|
||||
auto input_ptr = input_impl->cnnlMalloc();
|
||||
auto rois_impl = torch_mlu::getMluTensorImpl(rois);
|
||||
auto rois_ptr = rois_impl->cnnlMalloc();
|
||||
auto offset_impl = torch_mlu::getMluTensorImpl(offset);
|
||||
auto offset_ptr = offset_impl->cnnlMalloc();
|
||||
auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
|
||||
auto grad_input_ptr = grad_input_impl->cnnlMalloc();
|
||||
auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset);
|
||||
auto grad_offset_ptr = grad_offset_impl->cnnlMalloc();
|
||||
|
||||
// get comput dtype of input
|
||||
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype());
|
||||
|
||||
// launch kernel
|
||||
CNLOG(INFO) << "Launch Kernel KernelDeformRoIPoolBackward<<<" << k_dim.x
|
||||
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
|
||||
|
||||
KernelDeformRoIPoolBackward(k_dim, k_type, queue, data_type, grad_output_ptr,
|
||||
input_ptr, rois_ptr, offset_ptr, grad_input_ptr,
|
||||
grad_offset_ptr, channels, height, width,
|
||||
num_rois, pooled_height, pooled_width,
|
||||
spatial_scale, sampling_ratio, gamma);
|
||||
|
||||
grad_input.copy_(grad_input_);
|
||||
}
|
||||
|
||||
void deform_roi_pool_forward_mlu(Tensor input, Tensor rois, Tensor offset,
|
||||
Tensor output, int pooled_height,
|
||||
int pooled_width, float spatial_scale,
|
||||
int sampling_ratio, float gamma) {
|
||||
DeformRoIPoolForwardMLUKernelLauncher(input, rois, offset, output,
|
||||
pooled_height, pooled_width,
|
||||
spatial_scale, sampling_ratio, gamma);
|
||||
}
|
||||
|
||||
void deform_roi_pool_backward_mlu(Tensor grad_output, Tensor input, Tensor rois,
|
||||
Tensor offset, Tensor grad_input,
|
||||
Tensor grad_offset, int pooled_height,
|
||||
int pooled_width, float spatial_scale,
|
||||
int sampling_ratio, float gamma) {
|
||||
DeformRoIPoolBackwardMLUKernelLauncher(
|
||||
grad_output, input, rois, offset, grad_input, grad_offset, pooled_height,
|
||||
pooled_width, spatial_scale, sampling_ratio, gamma);
|
||||
}
|
||||
|
||||
void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset,
|
||||
Tensor output, int pooled_height,
|
||||
int pooled_width, float spatial_scale,
|
||||
int sampling_ratio, float gamma);
|
||||
|
||||
void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input,
|
||||
Tensor rois, Tensor offset,
|
||||
Tensor grad_input, Tensor grad_offset,
|
||||
int pooled_height, int pooled_width,
|
||||
float spatial_scale, int sampling_ratio,
|
||||
float gamma);
|
||||
|
||||
REGISTER_DEVICE_IMPL(deform_roi_pool_forward_impl, MLU,
|
||||
deform_roi_pool_forward_mlu);
|
||||
REGISTER_DEVICE_IMPL(deform_roi_pool_backward_impl, MLU,
|
||||
deform_roi_pool_backward_mlu);
|
|
@ -2,8 +2,11 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
_USING_PARROTS = True
|
||||
try:
|
||||
from parrots.autograd import gradcheck
|
||||
|
@ -93,3 +96,53 @@ class TestDeformRoIPool:
|
|||
gradcheck(droipool, (x, rois), no_grads=[rois])
|
||||
else:
|
||||
gradcheck(droipool, (x, rois), eps=1e-2, atol=1e-2)
|
||||
|
||||
def _test_deform_roi_pool_allclose(self, device, dtype=torch.float):
|
||||
from mmcv.ops import DeformRoIPoolPack
|
||||
pool_h = 2
|
||||
pool_w = 2
|
||||
spatial_scale = 1.0
|
||||
sampling_ratio = 2
|
||||
|
||||
for case, output in zip(inputs, outputs):
|
||||
np_input = np.array(case[0])
|
||||
np_rois = np.array(case[1])
|
||||
np_output = np.array(output[0])
|
||||
np_grad = np.array(output[1])
|
||||
|
||||
x = torch.tensor(
|
||||
np_input, device=device, dtype=torch.float, requires_grad=True)
|
||||
rois = torch.tensor(np_rois, device=device, dtype=torch.float)
|
||||
output_c = x.size(1)
|
||||
droipool = DeformRoIPoolPack(
|
||||
(pool_h, pool_w),
|
||||
output_c,
|
||||
spatial_scale=spatial_scale,
|
||||
sampling_ratio=sampling_ratio).to(device)
|
||||
|
||||
output = droipool(x, rois)
|
||||
output.backward(torch.ones_like(output))
|
||||
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
|
||||
assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
|
||||
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
])
|
||||
@pytest.mark.parametrize('dtype', [
|
||||
torch.float,
|
||||
pytest.param(
|
||||
torch.double,
|
||||
marks=pytest.mark.skipif(
|
||||
IS_MLU_AVAILABLE,
|
||||
reason='MLU does not support for 64-bit floating point')),
|
||||
torch.half
|
||||
])
|
||||
def test_deform_roi_pool_allclose(self, device, dtype):
|
||||
self._test_deform_roi_pool_allclose(device, dtype)
|
||||
|
|
Loading…
Reference in New Issue