mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix a potential bug in prroipool op (#2200)
parent
86f9dc7a40
commit
f6fd6c212f
|
@ -223,13 +223,13 @@ __global__ void prroi_pool_backward_cuda_kernel(
|
|||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
rois += n * 5;
|
||||
auto rois_cur = rois + n * 5;
|
||||
|
||||
int roi_batch_ind = rois[0];
|
||||
T roi_x1 = rois[1] * spatial_scale;
|
||||
T roi_y1 = rois[2] * spatial_scale;
|
||||
T roi_x2 = rois[3] * spatial_scale;
|
||||
T roi_y2 = rois[4] * spatial_scale;
|
||||
int roi_batch_ind = rois_cur[0];
|
||||
T roi_x1 = rois_cur[1] * spatial_scale;
|
||||
T roi_y1 = rois_cur[2] * spatial_scale;
|
||||
T roi_x2 = rois_cur[3] * spatial_scale;
|
||||
T roi_y2 = rois_cur[4] * spatial_scale;
|
||||
|
||||
T roi_width = max(roi_x2 - roi_x1, (T)0);
|
||||
T roi_height = max(roi_y2 - roi_y1, (T)0);
|
||||
|
@ -278,13 +278,13 @@ __global__ void prroi_pool_coor_backward_cuda_kernel(
|
|||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
rois += n * 5;
|
||||
auto rois_cur = rois + n * 5;
|
||||
|
||||
int roi_batch_ind = rois[0];
|
||||
T roi_x1 = rois[1] * spatial_scale;
|
||||
T roi_y1 = rois[2] * spatial_scale;
|
||||
T roi_x2 = rois[3] * spatial_scale;
|
||||
T roi_y2 = rois[4] * spatial_scale;
|
||||
int roi_batch_ind = rois_cur[0];
|
||||
T roi_x1 = rois_cur[1] * spatial_scale;
|
||||
T roi_y1 = rois_cur[2] * spatial_scale;
|
||||
T roi_x2 = rois_cur[3] * spatial_scale;
|
||||
T roi_y2 = rois_cur[4] * spatial_scale;
|
||||
|
||||
T roi_width = max(roi_x2 - roi_x1, (T)0);
|
||||
T roi_height = max(roi_y2 - roi_y1, (T)0);
|
||||
|
@ -307,7 +307,7 @@ __global__ void prroi_pool_coor_backward_cuda_kernel(
|
|||
T sum_out = bin_size == T(0) ? T(0) : output_grad_val / bin_size;
|
||||
|
||||
// WARNING: to be discussed
|
||||
if (sum_out == 0) return;
|
||||
if (sum_out == 0) continue;
|
||||
|
||||
int start_x, start_y, end_x, end_y;
|
||||
|
||||
|
|
Loading…
Reference in New Issue