From 1f5e6704217dfd7c581e49bb869e28b6c0147b44 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 14 Mar 2022 10:26:27 +0800 Subject: [PATCH] [Enhancement] Optimize multilevel roi align (#167) * optimize multilevel roi align * add pool mode --- .../tensorrt/common/common_cuda_helper.hpp | 51 +++---- .../trt_multi_level_roi_align.cpp | 26 ++-- .../trt_multi_level_roi_align.hpp | 7 +- .../trt_multi_level_roi_align_kernel.cu | 140 +++++++++++------- .../trt_multi_level_roi_align_kernel.hpp | 2 +- .../roi_heads/single_level_roi_extractor.py | 20 ++- tests/test_ops/test_ops.py | 29 +++- 7 files changed, 160 insertions(+), 115 deletions(-) diff --git a/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp b/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp index 02c57c62e..920a636fd 100644 --- a/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp +++ b/csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp @@ -49,44 +49,31 @@ cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, const scalar_t* beta, scalar_t* C, int ldc); template -__device__ scalar_t bilinear_interpolate(const scalar_t* input, const int height, const int width, - scalar_t y, scalar_t x) { +__device__ __forceinline__ scalar_t bilinear_interpolate(const scalar_t* __restrict__ input, + const int height, const int width, + scalar_t y, scalar_t x) { // deal with cases that inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; - if (y <= 0) y = 0; - if (x <= 0) x = 0; + y = min(scalar_t(height - 1), max(scalar_t(0), y)); + x = min(scalar_t(width - 1), max(scalar_t(0), x)); - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; + const int y_low = floor(y); + const int x_low = floor(x); + const int y_high = ceil(y); + const int x_high = ceil(x); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (scalar_t)y_low; - } else { - y_high = y_low + 1; - } + const scalar_t v1 = input[y_low * width + x_low]; + const scalar_t v2 = input[y_low * width + x_high]; + const scalar_t v3 = input[y_high * width + x_low]; + const scalar_t v4 = input[y_high * width + x_high]; - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (scalar_t)x_low; - } else { - x_high = x_low + 1; - } - - scalar_t ly = y - y_low; - scalar_t lx = x - x_low; - scalar_t hy = 1. - ly, hx = 1. - lx; - // do bilinear interpolation - scalar_t v1 = input[y_low * width + x_low]; - scalar_t v2 = input[y_low * width + x_high]; - scalar_t v3 = input[y_high * width + x_low]; - scalar_t v4 = input[y_high * width + x_high]; - scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + // lerp can be performed by fma + const scalar_t ly = y - y_low; + const scalar_t lx = x - x_low; + const scalar_t v_low = fma(v2 - v1, lx, v1); + const scalar_t v_high = fma(v4 - v3, lx, v3); + const scalar_t val = fma(v_high - v_low, ly, v_low); return val; } diff --git a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp index 256471c1a..823efbccf 100644 --- a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp +++ b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp @@ -16,12 +16,13 @@ static const char *PLUGIN_NAME{"MMCVMultiLevelRoiAlign"}; } // namespace TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, - int alignedWidth, int sampleNum, + int alignedWidth, int poolMode, int sampleNum, const std::vector &featmapStrides, float roiScaleFactor, int finestScale, bool aligned) : TRTPluginBase(name), mAlignedHeight(alignedHeight), mAlignedWidth(alignedWidth), + mPoolMode(poolMode), mSampleNum(sampleNum), mFeatmapStrides(featmapStrides), mRoiScaleFactor(roiScaleFactor), @@ -33,6 +34,7 @@ TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void : TRTPluginBase(name) { deserialize_value(&data, &length, &mAlignedHeight); deserialize_value(&data, &length, &mAlignedWidth); + deserialize_value(&data, &length, &mPoolMode); deserialize_value(&data, &length, &mSampleNum); deserialize_value(&data, &length, &mRoiScaleFactor); deserialize_value(&data, &length, &mFinestScale); @@ -42,7 +44,7 @@ TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const TRT_NOEXCEPT { TRTMultiLevelRoiAlign *plugin = - new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mSampleNum, + new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, mFeatmapStrides, mRoiScaleFactor, mFinestScale, mAligned); plugin->setPluginNamespace(getPluginNamespace()); @@ -113,8 +115,8 @@ int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, multi_level_roi_align((float *)outputs[0], (const float *)rois, num_rois, feats, num_feats, batch_size, channels, &heights[0], &widths[0], &strides[0], - mAlignedHeight, mAlignedWidth, mSampleNum, mRoiScaleFactor, - mFinestScale, mAligned, stream); + mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, + mRoiScaleFactor, mFinestScale, mAligned, stream); return 0; } @@ -134,7 +136,7 @@ int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + - serialized_size(mAlignedWidth) + serialized_size(mSampleNum) + + serialized_size(mAlignedWidth) + serialized_size(mPoolMode) + serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) + serialized_size(mAligned); } @@ -142,6 +144,7 @@ size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT { void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT { serialize_value(&buffer, mAlignedHeight); serialize_value(&buffer, mAlignedWidth); + serialize_value(&buffer, mPoolMode); serialize_value(&buffer, mSampleNum); serialize_value(&buffer, mRoiScaleFactor); serialize_value(&buffer, mFinestScale); @@ -152,9 +155,9 @@ void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT { TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() { mPluginAttributes = std::vector( {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), - nvinfer1::PluginField("sampling_ratio"), nvinfer1::PluginField("featmap_strides"), - nvinfer1::PluginField("roi_scale_factor"), nvinfer1::PluginField("finest_scale"), - nvinfer1::PluginField("aligned")}); + nvinfer1::PluginField("pool_mode"), nvinfer1::PluginField("sampling_ratio"), + nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), + nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); } @@ -169,6 +172,7 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { int alignedHeight = 7; int alignedWidth = 7; + int poolMode = 0; int sampleNum = 2; std::vector featmapStrides; float roiScaleFactor = -1; @@ -185,6 +189,8 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( alignedHeight = static_cast(fc->fields[i].data)[0]; } else if (field_name.compare("output_width") == 0) { alignedWidth = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("pool_mode") == 0) { + poolMode = static_cast(fc->fields[i].data)[0]; } else if (field_name.compare("sampling_ratio") == 0) { sampleNum = static_cast(fc->fields[i].data)[0]; } else if (field_name.compare("roi_scale_factor") == 0) { @@ -204,8 +210,8 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( ASSERT(featmapStrides.size() != 0); TRTMultiLevelRoiAlign *plugin = - new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, sampleNum, featmapStrides, - roiScaleFactor, finestScale, aligned); + new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, poolMode, sampleNum, + featmapStrides, roiScaleFactor, finestScale, aligned); plugin->setPluginNamespace(getPluginNamespace()); return plugin; } diff --git a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp index 2b13a1473..a9a06236e 100644 --- a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp +++ b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp @@ -13,9 +13,9 @@ namespace mmdeploy { class TRTMultiLevelRoiAlign : public TRTPluginBase { public: - TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, int sampleNum, - const std::vector &featmapStrides, float roiScaleFactor = -1, - int finestScale = 56, bool aligned = false); + TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, int poolMode, + int sampleNum, const std::vector &featmapStrides, + float roiScaleFactor = -1, int finestScale = 56, bool aligned = false); TRTMultiLevelRoiAlign(const std::string name, const void *data, size_t length); @@ -52,6 +52,7 @@ class TRTMultiLevelRoiAlign : public TRTPluginBase { private: int mAlignedHeight; int mAlignedWidth; + int mPoolMode; int mSampleNum; std::vector mFeatmapStrides; float mRoiScaleFactor; diff --git a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu index 3beeb9f3d..9eefbe3f3 100644 --- a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu +++ b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu @@ -1,4 +1,5 @@ // Copyright (c) OpenMMLab. All rights reserved. +#include #include #include @@ -19,21 +20,17 @@ struct FeatData { int num_featmap; }; -template -__device__ scalar_t roi_align_single(const scalar_t *bottom_data, const int roi_batch_ind, - const scalar_t roi_start_w, const scalar_t roi_start_h, - const scalar_t roi_end_w, const scalar_t roi_end_h, - const scalar_t spatial_scale, const int pw, const int ph, - const int c, const int sample_num, const int channels, - const int height, const int width, const int pooled_height, - const int pooled_width, const bool aligned) { +template +__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data, + const int roi_batch_ind, const scalar_t roi_start_w, + const scalar_t roi_start_h, const scalar_t roi_end_w, + const scalar_t roi_end_h, const scalar_t spatial_scale, + const int pw, const int ph, const int c, const int sample_num, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width) { // Force malformed ROIs to be 1x1 - scalar_t roi_width = fmaxf((scalar_t)roi_end_w - (scalar_t)roi_start_w, 0.); - scalar_t roi_height = fmaxf((scalar_t)roi_end_h - (scalar_t)roi_start_h, 0.); - if (!aligned) { - roi_width = max(roi_width, (scalar_t)1.); - roi_height = max(roi_height, (scalar_t)1.); - } + scalar_t roi_width = max(roi_end_w - roi_start_w, (scalar_t)(aligned ? 0. : 1.)); + scalar_t roi_height = max(roi_end_h - roi_start_h, (scalar_t)(aligned ? 0. : 1.)); const scalar_t bin_size_h = roi_height / pooled_height; const scalar_t bin_size_w = roi_width / pooled_width; @@ -41,39 +38,49 @@ __device__ scalar_t roi_align_single(const scalar_t *bottom_data, const int roi_ const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; - int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); // e.g., = 2 - int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); + const int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); + const int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); - scalar_t output_val = 0; -#pragma unroll + scalar_t output_val = (pool_mode == 0) ? -FLT_MAX : 0; + const scalar_t y_offset = roi_start_h + ph * bin_size_h; + const scalar_t y_scale = bin_size_h / (scalar_t)(sample_num_h); + const scalar_t x_offset = roi_start_w + pw * bin_size_w; + const scalar_t x_scale = bin_size_w / (scalar_t)(sample_num_w); for (int iy = 0; iy < sample_num_h; iy++) { - const scalar_t y = roi_start_h + ph * bin_size_h + - (scalar_t)(iy + scalar_t(.5f)) * bin_size_h / (scalar_t)(sample_num_h); -#pragma unroll + const scalar_t y = fma(scalar_t(iy) + scalar_t(.5f), y_scale, y_offset); for (int ix = 0; ix < sample_num_w; ix++) { - const scalar_t x = roi_start_w + pw * bin_size_w + - (scalar_t)(ix + scalar_t(.5f)) * bin_size_w / (scalar_t)(sample_num_w); + const scalar_t x = fma(scalar_t(ix) + scalar_t(.5f), x_scale, x_offset); scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); - output_val += val; + if (pool_mode == 0) { + output_val = max(output_val, val); + } else { + output_val += val; + } } } - output_val /= max(sample_num_h * sample_num_w, 1); + if (pool_mode != 0) { + output_val /= max(sample_num_h * sample_num_w, 1); + } return output_val; } -template -__global__ void roi_extractor_kernel(scalar_t *output, const scalar_t *bottom_rois, - FeatData feat_data, const int sample_num, +template +__global__ void roi_extractor_kernel(scalar_t *__restrict__ output, + const scalar_t *__restrict__ bottom_rois, FeatData feat_data, + const int pool_mode, const int sample_num, const float roi_scale_factor, const int finest_scale, const int pooled_height, const int pooled_width, - const bool aligned, int nThreads) { + int nThreads) { CUDA_1D_KERNEL_LOOP(index, nThreads) { const int channels = feat_data.channels; - const int pw = index % pooled_width; - const int ph = (index / pooled_width) % pooled_height; - const int c = (index / pooled_width / pooled_height) % channels; - const int n = index / pooled_width / pooled_height / channels; + int tmp_index = index; + const int pw = tmp_index % pooled_width; + tmp_index /= pooled_width; + const int ph = tmp_index % pooled_height; + tmp_index /= pooled_height; + const int c = tmp_index % channels; + const int n = tmp_index / channels; const scalar_t *offset_bottom_rois = bottom_rois + n * 5; @@ -84,19 +91,23 @@ __global__ void roi_extractor_kernel(scalar_t *output, const scalar_t *bottom_ro const scalar_t scale = sqrtf((roi_offset_y1 - roi_offset_y0) * (roi_offset_x1 - roi_offset_x0)); - const int target_lvls = fminf(feat_data.num_featmap - 1, - fmaxf(0, floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6)))); + const int target_lvls = + min(feat_data.num_featmap - 1, + max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); if (roi_scale_factor > 0.) { const scalar_t roi_off_cx = (roi_offset_x0 + roi_offset_x1) * 0.5; const scalar_t roi_off_cy = (roi_offset_y0 + roi_offset_y1) * 0.5; - const scalar_t roi_off_w = (roi_offset_x1 - roi_offset_x0 + 1) * roi_scale_factor; - const scalar_t roi_off_h = (roi_offset_y1 - roi_offset_y0 + 1) * roi_scale_factor; + const scalar_t half_scale_factor = roi_scale_factor * 0.5; + const scalar_t half_roi_off_w = + fma(roi_offset_x1 - roi_offset_x0 + 1, half_scale_factor, scalar_t(-0.5)); + const scalar_t half_roi_off_h = + fma(roi_offset_y1 - roi_offset_y0 + 1, half_scale_factor, scalar_t(-0.5)); - roi_offset_x0 = roi_off_cx - roi_off_w * 0.5 + 0.5; - roi_offset_x1 = roi_off_cx + roi_off_w * 0.5 - 0.5; - roi_offset_y0 = roi_off_cy - roi_off_h * 0.5 + 0.5; - roi_offset_y1 = roi_off_cy + roi_off_h * 0.5 - 0.5; + roi_offset_x0 = roi_off_cx - half_roi_off_w; + roi_offset_x1 = roi_off_cx + half_roi_off_w; + roi_offset_y0 = roi_off_cy - half_roi_off_h; + roi_offset_y1 = roi_off_cy + half_roi_off_h; } const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; @@ -105,24 +116,34 @@ __global__ void roi_extractor_kernel(scalar_t *output, const scalar_t *bottom_ro const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls]; const int roi_batch_ind = offset_bottom_rois[0]; - const scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0; - const scalar_t roi_start_w = roi_offset_x0 * spatial_scale - offset; - const scalar_t roi_start_h = roi_offset_y0 * spatial_scale - offset; - const scalar_t roi_end_w = (roi_offset_x1)*spatial_scale - offset; - const scalar_t roi_end_h = (roi_offset_y1)*spatial_scale - offset; + const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0; + const scalar_t roi_start_w = + fma(roi_offset_x0, spatial_scale, offset); // roi_offset_x0 * spatial_scale + offset; + const scalar_t roi_start_h = + fma(roi_offset_y0, spatial_scale, offset); // roi_offset_y0 * spatial_scale + offset; + const scalar_t roi_end_w = + fma(roi_offset_x1, spatial_scale, offset); // (roi_offset_x1) * spatial_scale - offset; + const scalar_t roi_end_h = + fma(roi_offset_y1, spatial_scale, offset); // (roi_offset_y1)*spatial_scale - offset; - const scalar_t output_val = roi_align_single( - bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, - pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width, aligned); - - output[index] = output_val; + if (pool_mode == 0) { + const scalar_t output_val = roi_align_single( + bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, + pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); + output[index] = output_val; + } else { + const scalar_t output_val = roi_align_single( + bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, + pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); + output[index] = output_val; + } } } template void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats, int num_feats, int n, int c, int *h, int *w, float *strides, - int aligned_height, int aligned_width, int sample_num, + int aligned_height, int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream) { FeatData feat_data; @@ -136,15 +157,20 @@ void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *c feat_data.spatial_scale[i] = 1. / float(strides[i]); } int nThreads = num_rois * c * aligned_height * aligned_width; - // bool aligned = true; - roi_extractor_kernel<<>>( - output, rois, feat_data, sample_num, roi_scale_factor, finest_scale, aligned_height, - aligned_width, aligned, nThreads); + if (aligned) { + roi_extractor_kernel<<>>( + output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale, + aligned_height, aligned_width, nThreads); + } else { + roi_extractor_kernel<<>>( + output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale, + aligned_height, aligned_width, nThreads); + } } template void multi_level_roi_align(float *output, const float *rois, int num_rois, const void *const *feats, int num_feats, int n, int c, int *h, int *w, float *strides, int aligned_height, - int aligned_width, int sample_num, + int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream); diff --git a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp index 872d203d9..5f7220dbf 100644 --- a/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp +++ b/csrc/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp @@ -6,7 +6,7 @@ template void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats, int num_feats, int n, int c, int *h, int *w, float *strides, - int aligned_height, int aligned_width, int sample_num, + int aligned_height, int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream); diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py index c997a5ab6..f87bdeabe 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmcv.ops import RoIAlign from torch.autograd import Function from mmdeploy.core.optimizers import mark @@ -24,7 +25,9 @@ class MultiLevelRoiAlign(Function): finest_scale = args[-3] roi_scale_factor = args[-4] sampling_ratio = args[-5] - output_size = args[-6] + pool_mode = args[-6] + pool_mode_flag = 0 if pool_mode == 'max' else 1 + output_size = args[-7] inputs = args[:len(featmap_strides)] rois = args[len(featmap_strides)] return g.op( @@ -33,6 +36,7 @@ class MultiLevelRoiAlign(Function): *inputs, output_height_i=output_size[1], output_width_i=output_size[0], + pool_mode_i=pool_mode_flag, sampling_ratio_i=sampling_ratio, roi_scale_factor_f=roi_scale_factor, finest_scale_i=finest_scale, @@ -47,7 +51,7 @@ class MultiLevelRoiAlign(Function): # finest_scale = args[-3] # roi_scale_factor = args[-4] # sampling_ratio = args[-5] - output_size = args[-6] + output_size = args[-7] inputs = args[:len(featmap_strides)] rois = args[len(featmap_strides)] @@ -75,17 +79,23 @@ def single_roi_extractor__forward__tensorrt(ctx, featmap_strides = self.featmap_strides finest_scale = self.finest_scale + for roi_layer in self.roi_layers: + assert isinstance( + roi_layer, + RoIAlign), f'{type(roi_layer)} is not supported in TensorRT.' + roi_layer = self.roi_layers[0] out_size = roi_layer.output_size sampling_ratio = roi_layer.sampling_ratio + pool_mode = roi_layer.pool_mode aligned = roi_layer.aligned if roi_scale_factor is None: roi_scale_factor = 1.0 featmap_strides = [float(s) for s in featmap_strides] - return MultiLevelRoiAlign.apply(*feats, rois, out_size, sampling_ratio, - roi_scale_factor, finest_scale, - featmap_strides, aligned) + return MultiLevelRoiAlign.apply(*feats, rois, out_size, pool_mode, + sampling_ratio, roi_scale_factor, + finest_scale, featmap_strides, aligned) @FUNCTION_REWRITER.register_rewriter( diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index 60143d878..22a4640d6 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -334,11 +334,14 @@ def test_batched_nms(backend, @pytest.mark.parametrize('backend', [TEST_TENSORRT]) -@pytest.mark.parametrize('out_size, sampling_ratio,roi_scale_factor,' - ' finest_scale,featmap_strides, aligned', - [(tuple([2, 2]), 2, 1.0, 2, list([2.0, 4.0]), 1)]) +@pytest.mark.parametrize( + 'out_size, pool_mode, sampling_ratio,roi_scale_factor,' + ' finest_scale,featmap_strides, aligned', + [(tuple([2, 2]), 0, 2, 1.0, 2, list([2.0, 4.0]), 1), + (tuple([2, 2]), 1, 2, 1.0, 2, list([2.0, 4.0]), 1)]) def test_multi_level_roi_align(backend, out_size, + pool_mode, sampling_ratio, roi_scale_factor, finest_scale, @@ -376,10 +379,21 @@ def test_multi_level_roi_align(backend, [0.9178, 0.7282, 0.0291, 0.3028]]]]) ] rois = torch.tensor([[0., 0., 0., 4., 4.]]) - expected_result = torch.tensor([[[[0.1939, 0.3950], [0.3437, 0.4543]], - [[0.0778, 0.1641], [0.1305, 0.2301]], - [[0.1542, 0.2413], [0.2094, - 0.2688]]]]) + if pool_mode == 1: + expected_result = torch.tensor([[[[0.1939, 0.3950], + [0.3437, 0.4543]], + [[0.0778, 0.1641], + [0.1305, 0.2301]], + [[0.1542, 0.2413], + [0.2094, 0.2688]]]]) + else: + expected_result = torch.tensor([[[[0.1939, 0.4956], + [0.4185, 0.5167]], + [[0.0778, 0.2073], + [0.1569, 0.3162]], + [[0.1542, 0.2849], + [0.2370, 0.3053]]]]) + else: input = input_list[0] rois = input_list[1] @@ -405,6 +419,7 @@ def test_multi_level_roi_align(backend, 'MMCVMultiLevelRoiAlign_0', None, 'mmdeploy', + pool_mode=pool_mode, aligned=aligned, featmap_strides=featmap_strides, finest_scale=finest_scale,