mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Fix RoiAlign Unittest (#90)
* fix roi_align unittest * fix lint * remove non must code * fix isort * reply for review * fix lint * reply code review * fix docformatter * fix review * reuse mmcv.roialignfunction
This commit is contained in:
parent
5b8750b83b
commit
01e2240b94
@ -1 +1,4 @@
|
||||
from .nms import * # noqa: F401,F403
|
||||
from .roi_align import roi_align_default
|
||||
|
||||
__all__ = ['roi_align_default']
|
||||
|
20
mmdeploy/mmcv/ops/roi_align.py
Normal file
20
mmdeploy/mmcv/ops/roi_align.py
Normal file
@ -0,0 +1,20 @@
|
||||
from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
|
||||
|
||||
# Here using mmcv.ops.roi_align.__self__ to find
|
||||
# mmcv.ops.roi_align.RoIAlignFunction, because RoIAlignFunction is not
|
||||
# visiable in mmcv.
|
||||
@SYMBOLIC_REGISTER.register_symbolic(
|
||||
'mmcv.ops.roi_align.__self__', backend='default')
|
||||
def roi_align_default(ctx, g, input, rois, output_size, spatial_scale,
|
||||
sampling_ratio, pool_mode, aligned):
|
||||
return g.op(
|
||||
'mmcv::MMCVRoiAlign',
|
||||
input,
|
||||
rois,
|
||||
output_height_i=output_size[0],
|
||||
output_width_i=output_size[1],
|
||||
spatial_scale_f=spatial_scale,
|
||||
sampling_ratio_i=sampling_ratio,
|
||||
mode_s=pool_mode,
|
||||
aligned_i=aligned)
|
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import register_extra_symbolics
|
||||
from mmdeploy.utils.test import WrapFunction
|
||||
from .utils import TestOnnxRTExporter, TestTensorRTExporter
|
||||
|
||||
@ -20,7 +21,8 @@ def test_roi_align(backend,
|
||||
inputs=None,
|
||||
work_dir=None):
|
||||
backend.check_env()
|
||||
# TODO: check if mmcv-full is installed
|
||||
# using rewriter of roi_align to bypass mmcv has_custom_ops check.
|
||||
register_extra_symbolics(cfg=dict(), backend='default', opset=11)
|
||||
from mmcv.ops import roi_align
|
||||
|
||||
def wrapped_function(torch_input, torch_rois):
|
||||
@ -33,7 +35,7 @@ def test_roi_align(backend,
|
||||
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
|
||||
else:
|
||||
input = torch.tensor(inputs[0], dtype=torch.float32)
|
||||
single_roi = torch.tensor(input[1], dtype=torch.float32)
|
||||
single_roi = torch.tensor(inputs[1], dtype=torch.float32)
|
||||
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input, single_roi],
|
||||
|
@ -101,17 +101,22 @@ class TestTensorRTExporter:
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend='tensorrt',
|
||||
tensorrt_params=dict(model_params=[
|
||||
dict(
|
||||
opt_shape_dict=dict(
|
||||
zip(input_names, [[
|
||||
list(data.shape),
|
||||
list(data.shape),
|
||||
list(data.shape)
|
||||
] for data in inputs_list])),
|
||||
max_workspace_size=0)
|
||||
])))
|
||||
backend_config=dict(
|
||||
type='tensorrt',
|
||||
common_config=dict(
|
||||
fp16_mode=False, max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
zip(input_names, [
|
||||
dict(
|
||||
min_shape=data.shape,
|
||||
opt_shape=data.shape,
|
||||
max_shape=data.shape)
|
||||
for data in inputs_list
|
||||
])))
|
||||
])))
|
||||
|
||||
onnx_model = onnx.load(onnx_file_path)
|
||||
trt_apis.onnx2tensorrt(
|
||||
os.path.dirname(trt_file_path),
|
||||
|
Loading…
x
Reference in New Issue
Block a user