[Enhancement] Optimize multilevel roi align (#167)

* optimize multilevel roi align

* add pool mode
This commit is contained in:
q.yao 2022-03-14 10:26:27 +08:00 committed by GitHub
parent df4e9e6cae
commit 1f5e670421
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 160 additions and 115 deletions

View File

@ -49,44 +49,31 @@ cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
const scalar_t* beta, scalar_t* C, int ldc);
template <typename scalar_t>
__device__ scalar_t bilinear_interpolate(const scalar_t* input, const int height, const int width,
scalar_t y, scalar_t x) {
__device__ __forceinline__ scalar_t bilinear_interpolate(const scalar_t* __restrict__ input,
const int height, const int width,
scalar_t y, scalar_t x) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
if (y <= 0) y = 0;
if (x <= 0) x = 0;
y = min(scalar_t(height - 1), max(scalar_t(0), y));
x = min(scalar_t(width - 1), max(scalar_t(0), x));
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
const int y_low = floor(y);
const int x_low = floor(x);
const int y_high = ceil(y);
const int x_high = ceil(x);
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (scalar_t)y_low;
} else {
y_high = y_low + 1;
}
const scalar_t v1 = input[y_low * width + x_low];
const scalar_t v2 = input[y_low * width + x_high];
const scalar_t v3 = input[y_high * width + x_low];
const scalar_t v4 = input[y_high * width + x_high];
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (scalar_t)x_low;
} else {
x_high = x_low + 1;
}
scalar_t ly = y - y_low;
scalar_t lx = x - x_low;
scalar_t hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
scalar_t v1 = input[y_low * width + x_low];
scalar_t v2 = input[y_low * width + x_high];
scalar_t v3 = input[y_high * width + x_low];
scalar_t v4 = input[y_high * width + x_high];
scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
// lerp can be performed by fma
const scalar_t ly = y - y_low;
const scalar_t lx = x - x_low;
const scalar_t v_low = fma(v2 - v1, lx, v1);
const scalar_t v_high = fma(v4 - v3, lx, v3);
const scalar_t val = fma(v_high - v_low, ly, v_low);
return val;
}

View File

@ -16,12 +16,13 @@ static const char *PLUGIN_NAME{"MMCVMultiLevelRoiAlign"};
} // namespace
TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight,
int alignedWidth, int sampleNum,
int alignedWidth, int poolMode, int sampleNum,
const std::vector<float> &featmapStrides,
float roiScaleFactor, int finestScale, bool aligned)
: TRTPluginBase(name),
mAlignedHeight(alignedHeight),
mAlignedWidth(alignedWidth),
mPoolMode(poolMode),
mSampleNum(sampleNum),
mFeatmapStrides(featmapStrides),
mRoiScaleFactor(roiScaleFactor),
@ -33,6 +34,7 @@ TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void
: TRTPluginBase(name) {
deserialize_value(&data, &length, &mAlignedHeight);
deserialize_value(&data, &length, &mAlignedWidth);
deserialize_value(&data, &length, &mPoolMode);
deserialize_value(&data, &length, &mSampleNum);
deserialize_value(&data, &length, &mRoiScaleFactor);
deserialize_value(&data, &length, &mFinestScale);
@ -42,7 +44,7 @@ TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void
nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const TRT_NOEXCEPT {
TRTMultiLevelRoiAlign *plugin =
new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mSampleNum,
new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum,
mFeatmapStrides, mRoiScaleFactor, mFinestScale, mAligned);
plugin->setPluginNamespace(getPluginNamespace());
@ -113,8 +115,8 @@ int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
multi_level_roi_align<float>((float *)outputs[0], (const float *)rois, num_rois, feats, num_feats,
batch_size, channels, &heights[0], &widths[0], &strides[0],
mAlignedHeight, mAlignedWidth, mSampleNum, mRoiScaleFactor,
mFinestScale, mAligned, stream);
mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum,
mRoiScaleFactor, mFinestScale, mAligned, stream);
return 0;
}
@ -134,7 +136,7 @@ int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) +
serialized_size(mAlignedWidth) + serialized_size(mSampleNum) +
serialized_size(mAlignedWidth) + serialized_size(mPoolMode) + serialized_size(mSampleNum) +
serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) +
serialized_size(mAligned);
}
@ -142,6 +144,7 @@ size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT {
void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mAlignedHeight);
serialize_value(&buffer, mAlignedWidth);
serialize_value(&buffer, mPoolMode);
serialize_value(&buffer, mSampleNum);
serialize_value(&buffer, mRoiScaleFactor);
serialize_value(&buffer, mFinestScale);
@ -152,9 +155,9 @@ void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT {
TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() {
mPluginAttributes = std::vector<nvinfer1::PluginField>(
{nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"),
nvinfer1::PluginField("sampling_ratio"), nvinfer1::PluginField("featmap_strides"),
nvinfer1::PluginField("roi_scale_factor"), nvinfer1::PluginField("finest_scale"),
nvinfer1::PluginField("aligned")});
nvinfer1::PluginField("pool_mode"), nvinfer1::PluginField("sampling_ratio"),
nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"),
nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")});
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
@ -169,6 +172,7 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
int alignedHeight = 7;
int alignedWidth = 7;
int poolMode = 0;
int sampleNum = 2;
std::vector<float> featmapStrides;
float roiScaleFactor = -1;
@ -185,6 +189,8 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
alignedHeight = static_cast<const int *>(fc->fields[i].data)[0];
} else if (field_name.compare("output_width") == 0) {
alignedWidth = static_cast<const int *>(fc->fields[i].data)[0];
} else if (field_name.compare("pool_mode") == 0) {
poolMode = static_cast<const int *>(fc->fields[i].data)[0];
} else if (field_name.compare("sampling_ratio") == 0) {
sampleNum = static_cast<const int *>(fc->fields[i].data)[0];
} else if (field_name.compare("roi_scale_factor") == 0) {
@ -204,8 +210,8 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
ASSERT(featmapStrides.size() != 0);
TRTMultiLevelRoiAlign *plugin =
new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, sampleNum, featmapStrides,
roiScaleFactor, finestScale, aligned);
new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, poolMode, sampleNum,
featmapStrides, roiScaleFactor, finestScale, aligned);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

View File

@ -13,9 +13,9 @@
namespace mmdeploy {
class TRTMultiLevelRoiAlign : public TRTPluginBase {
public:
TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, int sampleNum,
const std::vector<float> &featmapStrides, float roiScaleFactor = -1,
int finestScale = 56, bool aligned = false);
TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, int poolMode,
int sampleNum, const std::vector<float> &featmapStrides,
float roiScaleFactor = -1, int finestScale = 56, bool aligned = false);
TRTMultiLevelRoiAlign(const std::string name, const void *data, size_t length);
@ -52,6 +52,7 @@ class TRTMultiLevelRoiAlign : public TRTPluginBase {
private:
int mAlignedHeight;
int mAlignedWidth;
int mPoolMode;
int mSampleNum;
std::vector<float> mFeatmapStrides;
float mRoiScaleFactor;

View File

@ -1,4 +1,5 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <float.h>
#include <stdio.h>
#include <algorithm>
@ -19,21 +20,17 @@ struct FeatData {
int num_featmap;
};
template <typename scalar_t>
__device__ scalar_t roi_align_single(const scalar_t *bottom_data, const int roi_batch_ind,
const scalar_t roi_start_w, const scalar_t roi_start_h,
const scalar_t roi_end_w, const scalar_t roi_end_h,
const scalar_t spatial_scale, const int pw, const int ph,
const int c, const int sample_num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const bool aligned) {
template <typename scalar_t, bool aligned, int pool_mode>
__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data,
const int roi_batch_ind, const scalar_t roi_start_w,
const scalar_t roi_start_h, const scalar_t roi_end_w,
const scalar_t roi_end_h, const scalar_t spatial_scale,
const int pw, const int ph, const int c, const int sample_num,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width) {
// Force malformed ROIs to be 1x1
scalar_t roi_width = fmaxf((scalar_t)roi_end_w - (scalar_t)roi_start_w, 0.);
scalar_t roi_height = fmaxf((scalar_t)roi_end_h - (scalar_t)roi_start_h, 0.);
if (!aligned) {
roi_width = max(roi_width, (scalar_t)1.);
roi_height = max(roi_height, (scalar_t)1.);
}
scalar_t roi_width = max(roi_end_w - roi_start_w, (scalar_t)(aligned ? 0. : 1.));
scalar_t roi_height = max(roi_end_h - roi_start_h, (scalar_t)(aligned ? 0. : 1.));
const scalar_t bin_size_h = roi_height / pooled_height;
const scalar_t bin_size_w = roi_width / pooled_width;
@ -41,39 +38,49 @@ __device__ scalar_t roi_align_single(const scalar_t *bottom_data, const int roi_
const scalar_t *offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;
int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); // e.g., = 2
int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);
const int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height);
const int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width);
scalar_t output_val = 0;
#pragma unroll
scalar_t output_val = (pool_mode == 0) ? -FLT_MAX : 0;
const scalar_t y_offset = roi_start_h + ph * bin_size_h;
const scalar_t y_scale = bin_size_h / (scalar_t)(sample_num_h);
const scalar_t x_offset = roi_start_w + pw * bin_size_w;
const scalar_t x_scale = bin_size_w / (scalar_t)(sample_num_w);
for (int iy = 0; iy < sample_num_h; iy++) {
const scalar_t y = roi_start_h + ph * bin_size_h +
(scalar_t)(iy + scalar_t(.5f)) * bin_size_h / (scalar_t)(sample_num_h);
#pragma unroll
const scalar_t y = fma(scalar_t(iy) + scalar_t(.5f), y_scale, y_offset);
for (int ix = 0; ix < sample_num_w; ix++) {
const scalar_t x = roi_start_w + pw * bin_size_w +
(scalar_t)(ix + scalar_t(.5f)) * bin_size_w / (scalar_t)(sample_num_w);
const scalar_t x = fma(scalar_t(ix) + scalar_t(.5f), x_scale, x_offset);
scalar_t val = bilinear_interpolate<scalar_t>(offset_bottom_data, height, width, y, x);
output_val += val;
if (pool_mode == 0) {
output_val = max(output_val, val);
} else {
output_val += val;
}
}
}
output_val /= max(sample_num_h * sample_num_w, 1);
if (pool_mode != 0) {
output_val /= max(sample_num_h * sample_num_w, 1);
}
return output_val;
}
template <typename scalar_t>
__global__ void roi_extractor_kernel(scalar_t *output, const scalar_t *bottom_rois,
FeatData feat_data, const int sample_num,
template <typename scalar_t, bool aligned>
__global__ void roi_extractor_kernel(scalar_t *__restrict__ output,
const scalar_t *__restrict__ bottom_rois, FeatData feat_data,
const int pool_mode, const int sample_num,
const float roi_scale_factor, const int finest_scale,
const int pooled_height, const int pooled_width,
const bool aligned, int nThreads) {
int nThreads) {
CUDA_1D_KERNEL_LOOP(index, nThreads) {
const int channels = feat_data.channels;
const int pw = index % pooled_width;
const int ph = (index / pooled_width) % pooled_height;
const int c = (index / pooled_width / pooled_height) % channels;
const int n = index / pooled_width / pooled_height / channels;
int tmp_index = index;
const int pw = tmp_index % pooled_width;
tmp_index /= pooled_width;
const int ph = tmp_index % pooled_height;
tmp_index /= pooled_height;
const int c = tmp_index % channels;
const int n = tmp_index / channels;
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
@ -84,19 +91,23 @@ __global__ void roi_extractor_kernel(scalar_t *output, const scalar_t *bottom_ro
const scalar_t scale = sqrtf((roi_offset_y1 - roi_offset_y0) * (roi_offset_x1 - roi_offset_x0));
const int target_lvls = fminf(feat_data.num_featmap - 1,
fmaxf(0, floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))));
const int target_lvls =
min(feat_data.num_featmap - 1,
max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6)))));
if (roi_scale_factor > 0.) {
const scalar_t roi_off_cx = (roi_offset_x0 + roi_offset_x1) * 0.5;
const scalar_t roi_off_cy = (roi_offset_y0 + roi_offset_y1) * 0.5;
const scalar_t roi_off_w = (roi_offset_x1 - roi_offset_x0 + 1) * roi_scale_factor;
const scalar_t roi_off_h = (roi_offset_y1 - roi_offset_y0 + 1) * roi_scale_factor;
const scalar_t half_scale_factor = roi_scale_factor * 0.5;
const scalar_t half_roi_off_w =
fma(roi_offset_x1 - roi_offset_x0 + 1, half_scale_factor, scalar_t(-0.5));
const scalar_t half_roi_off_h =
fma(roi_offset_y1 - roi_offset_y0 + 1, half_scale_factor, scalar_t(-0.5));
roi_offset_x0 = roi_off_cx - roi_off_w * 0.5 + 0.5;
roi_offset_x1 = roi_off_cx + roi_off_w * 0.5 - 0.5;
roi_offset_y0 = roi_off_cy - roi_off_h * 0.5 + 0.5;
roi_offset_y1 = roi_off_cy + roi_off_h * 0.5 - 0.5;
roi_offset_x0 = roi_off_cx - half_roi_off_w;
roi_offset_x1 = roi_off_cx + half_roi_off_w;
roi_offset_y0 = roi_off_cy - half_roi_off_h;
roi_offset_y1 = roi_off_cy + half_roi_off_h;
}
const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls];
@ -105,24 +116,34 @@ __global__ void roi_extractor_kernel(scalar_t *output, const scalar_t *bottom_ro
const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls];
const int roi_batch_ind = offset_bottom_rois[0];
const scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
const scalar_t roi_start_w = roi_offset_x0 * spatial_scale - offset;
const scalar_t roi_start_h = roi_offset_y0 * spatial_scale - offset;
const scalar_t roi_end_w = (roi_offset_x1)*spatial_scale - offset;
const scalar_t roi_end_h = (roi_offset_y1)*spatial_scale - offset;
const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0;
const scalar_t roi_start_w =
fma(roi_offset_x0, spatial_scale, offset); // roi_offset_x0 * spatial_scale + offset;
const scalar_t roi_start_h =
fma(roi_offset_y0, spatial_scale, offset); // roi_offset_y0 * spatial_scale + offset;
const scalar_t roi_end_w =
fma(roi_offset_x1, spatial_scale, offset); // (roi_offset_x1) * spatial_scale - offset;
const scalar_t roi_end_h =
fma(roi_offset_y1, spatial_scale, offset); // (roi_offset_y1)*spatial_scale - offset;
const scalar_t output_val = roi_align_single<scalar_t>(
bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale,
pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width, aligned);
output[index] = output_val;
if (pool_mode == 0) {
const scalar_t output_val = roi_align_single<scalar_t, aligned, 0>(
bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale,
pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width);
output[index] = output_val;
} else {
const scalar_t output_val = roi_align_single<scalar_t, aligned, 1>(
bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale,
pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width);
output[index] = output_val;
}
}
}
template <typename T>
void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
int num_feats, int n, int c, int *h, int *w, float *strides,
int aligned_height, int aligned_width, int sample_num,
int aligned_height, int aligned_width, int pool_mode, int sample_num,
float roi_scale_factor, int finest_scale, bool aligned,
cudaStream_t stream) {
FeatData feat_data;
@ -136,15 +157,20 @@ void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *c
feat_data.spatial_scale[i] = 1. / float(strides[i]);
}
int nThreads = num_rois * c * aligned_height * aligned_width;
// bool aligned = true;
roi_extractor_kernel<T><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
output, rois, feat_data, sample_num, roi_scale_factor, finest_scale, aligned_height,
aligned_width, aligned, nThreads);
if (aligned) {
roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale,
aligned_height, aligned_width, nThreads);
} else {
roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale,
aligned_height, aligned_width, nThreads);
}
}
template void multi_level_roi_align<float>(float *output, const float *rois, int num_rois,
const void *const *feats, int num_feats, int n, int c,
int *h, int *w, float *strides, int aligned_height,
int aligned_width, int sample_num,
int aligned_width, int pool_mode, int sample_num,
float roi_scale_factor, int finest_scale, bool aligned,
cudaStream_t stream);

View File

@ -6,7 +6,7 @@
template <typename T>
void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
int num_feats, int n, int c, int *h, int *w, float *strides,
int aligned_height, int aligned_width, int sample_num,
int aligned_height, int aligned_width, int pool_mode, int sample_num,
float roi_scale_factor, int finest_scale, bool aligned,
cudaStream_t stream);

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import RoIAlign
from torch.autograd import Function
from mmdeploy.core.optimizers import mark
@ -24,7 +25,9 @@ class MultiLevelRoiAlign(Function):
finest_scale = args[-3]
roi_scale_factor = args[-4]
sampling_ratio = args[-5]
output_size = args[-6]
pool_mode = args[-6]
pool_mode_flag = 0 if pool_mode == 'max' else 1
output_size = args[-7]
inputs = args[:len(featmap_strides)]
rois = args[len(featmap_strides)]
return g.op(
@ -33,6 +36,7 @@ class MultiLevelRoiAlign(Function):
*inputs,
output_height_i=output_size[1],
output_width_i=output_size[0],
pool_mode_i=pool_mode_flag,
sampling_ratio_i=sampling_ratio,
roi_scale_factor_f=roi_scale_factor,
finest_scale_i=finest_scale,
@ -47,7 +51,7 @@ class MultiLevelRoiAlign(Function):
# finest_scale = args[-3]
# roi_scale_factor = args[-4]
# sampling_ratio = args[-5]
output_size = args[-6]
output_size = args[-7]
inputs = args[:len(featmap_strides)]
rois = args[len(featmap_strides)]
@ -75,17 +79,23 @@ def single_roi_extractor__forward__tensorrt(ctx,
featmap_strides = self.featmap_strides
finest_scale = self.finest_scale
for roi_layer in self.roi_layers:
assert isinstance(
roi_layer,
RoIAlign), f'{type(roi_layer)} is not supported in TensorRT.'
roi_layer = self.roi_layers[0]
out_size = roi_layer.output_size
sampling_ratio = roi_layer.sampling_ratio
pool_mode = roi_layer.pool_mode
aligned = roi_layer.aligned
if roi_scale_factor is None:
roi_scale_factor = 1.0
featmap_strides = [float(s) for s in featmap_strides]
return MultiLevelRoiAlign.apply(*feats, rois, out_size, sampling_ratio,
roi_scale_factor, finest_scale,
featmap_strides, aligned)
return MultiLevelRoiAlign.apply(*feats, rois, out_size, pool_mode,
sampling_ratio, roi_scale_factor,
finest_scale, featmap_strides, aligned)
@FUNCTION_REWRITER.register_rewriter(

View File

@ -334,11 +334,14 @@ def test_batched_nms(backend,
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('out_size, sampling_ratio,roi_scale_factor,'
' finest_scale,featmap_strides, aligned',
[(tuple([2, 2]), 2, 1.0, 2, list([2.0, 4.0]), 1)])
@pytest.mark.parametrize(
'out_size, pool_mode, sampling_ratio,roi_scale_factor,'
' finest_scale,featmap_strides, aligned',
[(tuple([2, 2]), 0, 2, 1.0, 2, list([2.0, 4.0]), 1),
(tuple([2, 2]), 1, 2, 1.0, 2, list([2.0, 4.0]), 1)])
def test_multi_level_roi_align(backend,
out_size,
pool_mode,
sampling_ratio,
roi_scale_factor,
finest_scale,
@ -376,10 +379,21 @@ def test_multi_level_roi_align(backend,
[0.9178, 0.7282, 0.0291, 0.3028]]]])
]
rois = torch.tensor([[0., 0., 0., 4., 4.]])
expected_result = torch.tensor([[[[0.1939, 0.3950], [0.3437, 0.4543]],
[[0.0778, 0.1641], [0.1305, 0.2301]],
[[0.1542, 0.2413], [0.2094,
0.2688]]]])
if pool_mode == 1:
expected_result = torch.tensor([[[[0.1939, 0.3950],
[0.3437, 0.4543]],
[[0.0778, 0.1641],
[0.1305, 0.2301]],
[[0.1542, 0.2413],
[0.2094, 0.2688]]]])
else:
expected_result = torch.tensor([[[[0.1939, 0.4956],
[0.4185, 0.5167]],
[[0.0778, 0.2073],
[0.1569, 0.3162]],
[[0.1542, 0.2849],
[0.2370, 0.3053]]]])
else:
input = input_list[0]
rois = input_list[1]
@ -405,6 +419,7 @@ def test_multi_level_roi_align(backend,
'MMCVMultiLevelRoiAlign_0',
None,
'mmdeploy',
pool_mode=pool_mode,
aligned=aligned,
featmap_strides=featmap_strides,
finest_scale=finest_scale,