[Fix] Fix RoiAlign Unittest (#90)

* fix roi_align unittest

* fix lint

* remove non must code

* fix isort

* reply for review

* fix lint

* reply code review

* fix docformatter

* fix review

* reuse mmcv.roialignfunction
This commit is contained in:
hanrui1sensetime 2021-09-26 11:21:18 +08:00 committed by GitHub
parent 5b8750b83b
commit 01e2240b94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 13 deletions

View File

@ -1 +1,4 @@
from .nms import * # noqa: F401,F403
from .roi_align import roi_align_default
__all__ = ['roi_align_default']

View File

@ -0,0 +1,20 @@
from mmdeploy.core import SYMBOLIC_REGISTER
# Here using mmcv.ops.roi_align.__self__ to find
# mmcv.ops.roi_align.RoIAlignFunction, because RoIAlignFunction is not
# visiable in mmcv.
@SYMBOLIC_REGISTER.register_symbolic(
'mmcv.ops.roi_align.__self__', backend='default')
def roi_align_default(ctx, g, input, rois, output_size, spatial_scale,
sampling_ratio, pool_mode, aligned):
return g.op(
'mmcv::MMCVRoiAlign',
input,
rois,
output_height_i=output_size[0],
output_width_i=output_size[1],
spatial_scale_f=spatial_scale,
sampling_ratio_i=sampling_ratio,
mode_s=pool_mode,
aligned_i=aligned)

View File

@ -1,6 +1,7 @@
import pytest
import torch
from mmdeploy.core import register_extra_symbolics
from mmdeploy.utils.test import WrapFunction
from .utils import TestOnnxRTExporter, TestTensorRTExporter
@ -20,7 +21,8 @@ def test_roi_align(backend,
inputs=None,
work_dir=None):
backend.check_env()
# TODO: check if mmcv-full is installed
# using rewriter of roi_align to bypass mmcv has_custom_ops check.
register_extra_symbolics(cfg=dict(), backend='default', opset=11)
from mmcv.ops import roi_align
def wrapped_function(torch_input, torch_rois):
@ -33,7 +35,7 @@ def test_roi_align(backend,
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
else:
input = torch.tensor(inputs[0], dtype=torch.float32)
single_roi = torch.tensor(input[1], dtype=torch.float32)
single_roi = torch.tensor(inputs[1], dtype=torch.float32)
backend.run_and_validate(
wrapped_model, [input, single_roi],

View File

@ -101,17 +101,22 @@ class TestTensorRTExporter:
deploy_cfg = mmcv.Config(
dict(
backend='tensorrt',
tensorrt_params=dict(model_params=[
dict(
opt_shape_dict=dict(
zip(input_names, [[
list(data.shape),
list(data.shape),
list(data.shape)
] for data in inputs_list])),
max_workspace_size=0)
])))
backend_config=dict(
type='tensorrt',
common_config=dict(
fp16_mode=False, max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
zip(input_names, [
dict(
min_shape=data.shape,
opt_shape=data.shape,
max_shape=data.shape)
for data in inputs_list
])))
])))
onnx_model = onnx.load(onnx_file_path)
trt_apis.onnx2tensorrt(
os.path.dirname(trt_file_path),