[Fix] Fix a potential bug in prroipool op (#2200)

pull/2220/head
Jingwei Zhang 2022-08-18 15:02:20 +08:00 committed by Zaida Zhou
parent 86f9dc7a40
commit f6fd6c212f
1 changed files with 13 additions and 13 deletions

View File

@ -223,13 +223,13 @@ __global__ void prroi_pool_backward_cuda_kernel(
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
rois += n * 5;
auto rois_cur = rois + n * 5;
int roi_batch_ind = rois[0];
T roi_x1 = rois[1] * spatial_scale;
T roi_y1 = rois[2] * spatial_scale;
T roi_x2 = rois[3] * spatial_scale;
T roi_y2 = rois[4] * spatial_scale;
int roi_batch_ind = rois_cur[0];
T roi_x1 = rois_cur[1] * spatial_scale;
T roi_y1 = rois_cur[2] * spatial_scale;
T roi_x2 = rois_cur[3] * spatial_scale;
T roi_y2 = rois_cur[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, (T)0);
T roi_height = max(roi_y2 - roi_y1, (T)0);
@ -278,13 +278,13 @@ __global__ void prroi_pool_coor_backward_cuda_kernel(
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
rois += n * 5;
auto rois_cur = rois + n * 5;
int roi_batch_ind = rois[0];
T roi_x1 = rois[1] * spatial_scale;
T roi_y1 = rois[2] * spatial_scale;
T roi_x2 = rois[3] * spatial_scale;
T roi_y2 = rois[4] * spatial_scale;
int roi_batch_ind = rois_cur[0];
T roi_x1 = rois_cur[1] * spatial_scale;
T roi_y1 = rois_cur[2] * spatial_scale;
T roi_x2 = rois_cur[3] * spatial_scale;
T roi_y2 = rois_cur[4] * spatial_scale;
T roi_width = max(roi_x2 - roi_x1, (T)0);
T roi_height = max(roi_y2 - roi_y1, (T)0);
@ -307,7 +307,7 @@ __global__ void prroi_pool_coor_backward_cuda_kernel(
T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;
// WARNING: to be discussed
if (sum_out == 0) return;
if (sum_out == 0) continue;
int start_x, start_y, end_x, end_y;