diff --git a/csrc/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu b/csrc/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu index be70b7582..efb078c43 100644 --- a/csrc/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu +++ b/csrc/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu @@ -89,57 +89,55 @@ __global__ void resize_cubic_kernel_torch(const int num_elements, const scalar_t int srcHeight, scalar_t *dst, int dstWidth, int dstHeight, bool align_corners, float height_scale, float width_scale) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index >= num_elements) { - return; - } - // Special case: input and output are the same size, just copy - const int output_x = index % dstWidth; - const int output_y = index / dstWidth; + CUDA_1D_KERNEL_LOOP(index, num_elements) { + // Special case: input and output are the same size, just copy + const int output_x = index % dstWidth; + const int output_y = index / dstWidth; + + if (srcHeight == dstHeight && srcWidth == dstWidth) { + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; c++) { + const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + output_x]; + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + + output_x] = val; + } + } + return; + } + // Interpolation kernel + scalar_t real_x = + area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true); + int in_x = floorf(real_x); + scalar_t t_x = real_x - in_x; + + scalar_t real_y = + area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true); + int in_y = floorf(real_y); + scalar_t t_y = real_y - in_y; - if (srcHeight == dstHeight && srcWidth == dstWidth) { for (int n = 0; n < batchsize; n++) { for (int c = 0; c < channels; c++) { - const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + - output_y * dstWidth + output_x]; + scalar_t coefficients[4]; + + for (int k = 0; k < 4; k++) { + coefficients[k] = cubic_interp1d( + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, + in_y - 1 + k, in_x - 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, + in_y - 1 + k, in_x + 0), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, + in_y - 1 + k, in_x + 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, + in_y - 1 + k, in_x + 2), + t_x); + } + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + - output_x] = val; + output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], + coefficients[3], t_y)); } } - return; - } - // Interpolation kernel - scalar_t real_x = - area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true); - int in_x = floorf(real_x); - scalar_t t_x = real_x - in_x; - - scalar_t real_y = - area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true); - int in_y = floorf(real_y); - scalar_t t_y = real_y - in_y; - - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; c++) { - scalar_t coefficients[4]; - - for (int k = 0; k < 4; k++) { - coefficients[k] = cubic_interp1d( - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x - 1), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 0), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 1), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 2), - t_x); - } - - dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + - output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], - coefficients[3], t_y)); - } } }