diff --git a/mmcv/ops/csrc/parrots/roi_align.cpp b/mmcv/ops/csrc/parrots/roi_align.cpp index 47149cac8..3adf2b019 100644 --- a/mmcv/ops/csrc/parrots/roi_align.cpp +++ b/mmcv/ops/csrc/parrots/roi_align.cpp @@ -1,19 +1,90 @@ // Copyright (c) 2018, SenseTime. #include "parrots_cpp_helper.hpp" -void ROIAlignForwardCUDAKernelLauncher(const DArrayLite input, - const DArrayLite rois, DArrayLite output, - DArrayLite argmax_y, DArrayLite argmax_x, - int aligned_height, int aligned_width, - float spatial_scale, int sampling_ratio, - int pool_mode, bool aligned, - cudaStream_t stream); +void ROIAlignForwardCPULauncher(DArrayLite input, DArrayLite rois, + DArrayLite output, DArrayLite argmax_y, + DArrayLite argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned); + +void ROIAlignBackwardCPULauncher(DArrayLite grad_output, DArrayLite rois, + DArrayLite argmax_y, DArrayLite argmax_x, + DArrayLite grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned); + +void ROIAlignForwardCUDAKernelLauncher(DArrayLite input, DArrayLite rois, + DArrayLite output, DArrayLite argmax_y, + DArrayLite argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned, cudaStream_t stream); void ROIAlignBackwardCUDAKernelLauncher( - const DArrayLite grad_output, const DArrayLite rois, - const DArrayLite argmax_y, const DArrayLite argmax_x, DArrayLite grad_input, - int aligned_height, int aligned_width, float spatial_scale, - int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream); + DArrayLite grad_output, DArrayLite rois, DArrayLite argmax_y, + DArrayLite argmax_x, DArrayLite grad_input, int aligned_height, + int aligned_width, float spatial_scale, int sampling_ratio, int pool_mode, + bool aligned, cudaStream_t stream); + +void roi_align_forward_cpu(HostContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int aligned_height; + int aligned_width; + float spatial_scale; + int sampling_ratio; + int pool_mode; + bool aligned; + SSAttrs(attr) + .get("aligned_height", aligned_height) + .get("aligned_width", aligned_width) + .get("spatial_scale", spatial_scale) + .get("sampling_ratio", sampling_ratio) + .get("pool_mode", pool_mode) + .get("aligned", aligned) + .done(); + + auto& input = ins[0]; + auto& rois = ins[1]; + auto& output = outs[0]; + auto& argmax_y = outs[1]; + auto& argmax_x = outs[2]; + + ROIAlignForwardCPULauncher(input, rois, output, argmax_y, argmax_x, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); +} + +void roi_align_backward_cpu(HostContext& ctx, const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + int aligned_height; + int aligned_width; + float spatial_scale; + int sampling_ratio; + int pool_mode; + bool aligned; + SSAttrs(attr) + .get("aligned_height", aligned_height) + .get("aligned_width", aligned_width) + .get("spatial_scale", spatial_scale) + .get("sampling_ratio", sampling_ratio) + .get("pool_mode", pool_mode) + .get("aligned", aligned) + .done(); + + auto& grad_output = ins[0]; + auto& rois = ins[1]; + auto& argmax_y = ins[2]; + auto& argmax_x = ins[3]; + auto& grad_input = outs[0]; + + ROIAlignBackwardCPULauncher(grad_output, rois, argmax_y, argmax_x, grad_input, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, pool_mode, aligned); +} void roi_align_forward_cuda(CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins, @@ -33,8 +104,8 @@ void roi_align_forward_cuda(CudaContext& ctx, const SSElement& attr, .get("aligned", aligned) .done(); - const auto& input = ins[0]; - const auto& rois = ins[1]; + auto& input = ins[0]; + auto& rois = ins[1]; auto& output = outs[0]; auto& argmax_y = outs[1]; auto& argmax_x = outs[2]; @@ -63,10 +134,10 @@ void roi_align_backward_cuda(CudaContext& ctx, const SSElement& attr, .get("aligned", aligned) .done(); - const auto& grad_output = ins[0]; - const auto& rois = ins[1]; - const auto& argmax_y = ins[2]; - const auto& argmax_x = ins[3]; + auto& grad_output = ins[0]; + auto& rois = ins[1]; + auto& argmax_y = ins[2]; + auto& argmax_x = ins[3]; auto& grad_input = outs[0]; cudaStream_t stream = getStreamNative(ctx.getStream()); @@ -84,7 +155,10 @@ PARROTS_EXTENSION_REGISTER(roi_align_forward) .attr("aligned") .input(2) .output(3) + .apply(roi_align_forward_cpu) +#ifdef PARROTS_USE_CUDA .apply(roi_align_forward_cuda) +#endif .done(); PARROTS_EXTENSION_REGISTER(roi_align_backward) @@ -96,5 +170,8 @@ PARROTS_EXTENSION_REGISTER(roi_align_backward) .attr("aligned") .input(4) .output(1) + .apply(roi_align_backward_cpu) +#ifdef PARROTS_USE_CUDA .apply(roi_align_backward_cuda) +#endif .done(); diff --git a/mmcv/ops/csrc/parrots/roi_align_cpu.cpp b/mmcv/ops/csrc/parrots/roi_align_cpu.cpp new file mode 100644 index 000000000..39d440dff --- /dev/null +++ b/mmcv/ops/csrc/parrots/roi_align_cpu.cpp @@ -0,0 +1,430 @@ +// Modified from +// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +#include + +#include "parrots_cpp_helper.hpp" + +// implementation taken from Caffe2 +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template +void pre_calc_for_bilinear_interpolate( + const int height, const int width, const int pooled_height, + const int pooled_width, const int iy_upper, const int ix_upper, + T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w, + int roi_bin_grid_h, int roi_bin_grid_w, std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +void ROIAlignForward(const int nthreads, const T* input, const T* rois, + T* output, T* argmax_y, T* argmax_x, + const int pooled_height, const int pooled_width, + const T spatial_scale, const int sampling_ratio, + const int pool_mode, // 0 - max pool, 1 - avg pool + const bool aligned, const int channels, const int height, + const int width) { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (aligned) { + PARROTS_CHECKARGS(roi_width >= 0 && roi_height >= 0) + << "ROIs in ROIAlign cannot have non-negative size!"; + } else { // for backward-compatibility only + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // When the grid is empty, output zeros == 0/1, instead of NaN. + const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector> pre_calc(roi_bin_grid_h * roi_bin_grid_w * + pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate( + height, width, pooled_height, pooled_width, roi_bin_grid_h, + roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w, + roi_bin_grid_h, roi_bin_grid_w, pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + T maxval = -10000; + T maxidx_y = -1.f, maxidx_x = -1.f; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + PreCalc pc = pre_calc[pre_calc_index]; + T val = pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + + pc.w4 * offset_input[pc.pos4]; + if (val > maxval) { + maxval = val; + maxidx_y = y; + maxidx_x = x; + } + output_val += val; + pre_calc_index += 1; + } + } + if (pool_mode == 0) { + // We do max pooling inside a bin + output[index] = maxval; + argmax_y[index] = maxidx_y; + argmax_x[index] = maxidx_x; + } else if (pool_mode == 1) { + // We do average (integral) pooling inside a bin + output[index] = output_val / count; + } // if + } // for pw + } // for ph + } // for c + } // for n +} + +template +void bilinear_interpolate_gradient(const int height, const int width, T y, T x, + T& w1, T& w2, T& w3, T& w4, int& x_low, + int& x_high, int& y_low, int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void ROIAlignBackward(const int nthreads, const T* grad_output, const T* rois, + const T* argmax_y, const T* argmax_x, T* grad_input, + const int pooled_height, const int pooled_width, + const T spatial_scale, const int sampling_ratio, + const int pool_mode, // 0 - max pool, 1 - avg pool + const bool aligned, const int channels, const int height, + const int width, const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (aligned) { + PARROTS_CHECKARGS(roi_width >= 0 && roi_height >= 0) + << "ROIs in ROIAlign do not have non-negative size!"; + } else { // for backward-compatibility only + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + if (pool_mode == 0) { + // We do max pooling inside a bin + T y = argmax_y[index], x = argmax_x[index]; + if (y != -1.f) { + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + T g1 = grad_output_this_bin * w1; + T g2 = grad_output_this_bin * w2; + T g3 = grad_output_this_bin * w3; + T g4 = grad_output_this_bin * w4; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // mode + } else if (pool_mode == 1) { + // We do average (integral) pooling inside a bin + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_width / pooled_width); + + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, + static_cast(g4)); + } // if + } // ix + } // iy + } // mode + } // for +} // ROIAlignBackward + +void ROIAlignForwardCPULauncher(DArrayLite input, DArrayLite rois, + DArrayLite output, DArrayLite argmax_y, + DArrayLite argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned) { + int output_size = output.size(); + int channels = input.dim(1); + int height = input.dim(2); + int width = input.dim(3); + + PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF( + input.elemType().prim(), ([&] { + ROIAlignForward( + output_size, input.ptr(), rois.ptr(), + output.ptr(), argmax_y.ptr(), + argmax_x.ptr(), aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, pool_mode, + aligned, channels, height, width); + })); +} + +void ROIAlignBackwardCPULauncher(DArrayLite grad_output, DArrayLite rois, + DArrayLite argmax_y, DArrayLite argmax_x, + DArrayLite grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned) { + int output_size = grad_output.size(); + int channels = grad_input.dim(1); + int height = grad_input.dim(2); + int width = grad_input.dim(3); + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad_output.stride(0); + int c_stride = grad_output.stride(1); + int h_stride = grad_output.stride(2); + int w_stride = grad_output.stride(3); + + PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.elemType().prim(), ([&] { + ROIAlignBackward( + output_size, grad_output.ptr(), rois.ptr(), + argmax_y.ptr(), argmax_x.ptr(), + grad_input.ptr(), aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, pool_mode, + aligned, channels, height, width, n_stride, c_stride, h_stride, + w_stride); + })); +} diff --git a/mmcv/ops/csrc/parrots/roi_align_cuda.cu b/mmcv/ops/csrc/parrots/roi_align_cuda.cu index 2c3983111..05eb36d2c 100644 --- a/mmcv/ops/csrc/parrots/roi_align_cuda.cu +++ b/mmcv/ops/csrc/parrots/roi_align_cuda.cu @@ -1,13 +1,12 @@ #include "parrots_cuda_helper.hpp" #include "roi_align_cuda_kernel.cuh" -void ROIAlignForwardCUDAKernelLauncher(const DArrayLite input, - const DArrayLite rois, DArrayLite output, - DArrayLite argmax_y, DArrayLite argmax_x, - int aligned_height, int aligned_width, - float spatial_scale, int sampling_ratio, - int pool_mode, bool aligned, - cudaStream_t stream) { +void ROIAlignForwardCUDAKernelLauncher(DArrayLite input, DArrayLite rois, + DArrayLite output, DArrayLite argmax_y, + DArrayLite argmax_x, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned, cudaStream_t stream) { int output_size = output.size(); int channels = input.dim(1); int height = input.dim(2); @@ -20,18 +19,18 @@ void ROIAlignForwardCUDAKernelLauncher(const DArrayLite input, output_size, input.ptr(), rois.ptr(), output.ptr(), argmax_y.ptr(), argmax_x.ptr(), aligned_height, aligned_width, - spatial_scale, sampling_ratio, pool_mode, aligned, channels, - height, width); + static_cast(spatial_scale), sampling_ratio, pool_mode, + aligned, channels, height, width); })); PARROTS_CUDA_CHECK(cudaGetLastError()); } void ROIAlignBackwardCUDAKernelLauncher( - const DArrayLite grad_output, const DArrayLite rois, - const DArrayLite argmax_y, const DArrayLite argmax_x, DArrayLite grad_input, - int aligned_height, int aligned_width, float spatial_scale, - int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream) { + DArrayLite grad_output, DArrayLite rois, DArrayLite argmax_y, + DArrayLite argmax_x, DArrayLite grad_input, int aligned_height, + int aligned_width, float spatial_scale, int sampling_ratio, int pool_mode, + bool aligned, cudaStream_t stream) { int output_size = grad_output.size(); int channels = grad_input.dim(1); int height = grad_input.dim(2); @@ -44,8 +43,8 @@ void ROIAlignBackwardCUDAKernelLauncher( output_size, grad_output.ptr(), rois.ptr(), argmax_y.ptr(), argmax_x.ptr(), grad_input.ptr(), aligned_height, aligned_width, - spatial_scale, sampling_ratio, pool_mode, aligned, channels, - height, width); + static_cast(spatial_scale), sampling_ratio, pool_mode, + aligned, channels, height, width); })); PARROTS_CUDA_CHECK(cudaGetLastError()); diff --git a/mmcv/ops/csrc/parrots_cpp_helper.hpp b/mmcv/ops/csrc/parrots_cpp_helper.hpp index 5ff2e8f1c..72701890d 100644 --- a/mmcv/ops/csrc/parrots_cpp_helper.hpp +++ b/mmcv/ops/csrc/parrots_cpp_helper.hpp @@ -8,4 +8,33 @@ using namespace parrots; +#define PARROTS_PRIVATE_CASE_TYPE(prim_type, type, ...) \ + case prim_type: { \ + using scalar_t = type; \ + return __VA_ARGS__(); \ + } + +#define PARROTS_DISPATCH_FLOATING_TYPES(TYPE, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + switch (the_type) { \ + PARROTS_PRIVATE_CASE_TYPE(Prim::Float64, double, __VA_ARGS__) \ + PARROTS_PRIVATE_CASE_TYPE(Prim::Float32, float, __VA_ARGS__) \ + default: \ + PARROTS_NOTSUPPORTED; \ + } \ + }() + +#define PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + switch (the_type) { \ + PARROTS_PRIVATE_CASE_TYPE(Prim::Float64, double, __VA_ARGS__) \ + PARROTS_PRIVATE_CASE_TYPE(Prim::Float32, float, __VA_ARGS__) \ + PARROTS_PRIVATE_CASE_TYPE(Prim::Float16, float16, __VA_ARGS__) \ + default: \ + PARROTS_NOTSUPPORTED; \ + } \ + }() + #endif // PARROTS_CPP_HELPER diff --git a/tests/test_ops/test_roi_align.py b/tests/test_ops/test_roi_align.py index a94e6f988..db7c03740 100644 --- a/tests/test_ops/test_roi_align.py +++ b/tests/test_ops/test_roi_align.py @@ -56,7 +56,11 @@ def _test_roialign_gradcheck(device, dtype): froipool = RoIAlign((pool_h, pool_w), spatial_scale, sampling_ratio) - gradcheck(froipool, (x, rois), eps=1e-5, atol=1e-5) + if torch.__version__ == 'parrots': + gradcheck( + froipool, (x, rois), no_grads=[rois], delta=1e-5, pt_atol=1e-5) + else: + gradcheck(froipool, (x, rois), eps=1e-5, atol=1e-5) def _test_roialign_allclose(device, dtype):