mmcv/tests/test_ops/test_tensorrt.py
q.yao 0de9e149c0
[Feature]: Add custom operators support for TensorRT in mmcv (#686)
* start trt plugin prototype

* Add test module, modify roialign convertor

* finish roi_align trt plugin

* fix conflict of RoiAlign and MMCVRoiAlign

* fix for lint

* fix test tensorrt module

* test_tensorrt move import to test func

* add except error type

* add tensorrt to setup.cfg

* code format with yapf

* fix for clang-format

* move tensorrt_utils to mmcv/tensorrt, add comments, better test module

* fix line endings, docformatter

* isort init, remove trailing whitespace

* add except type

* fix setup.py

* put import extension inside trt setup

* change c++ guard, update pytest script, better setup, etc

* sort import with isort

* sort import with isort

* move init of plugin lib to init_plugins.py

* resolve format and add test dependency: tensorrt

* tensorrt should be installed from source not from pypi

* update naming style and input check

* resolve lint error

Co-authored-by: maningsheng <maningsheng@sensetime.com>
2021-01-06 11:05:19 +08:00

94 lines
3.0 KiB
Python

import os
import numpy as np
import onnx
import pytest
import torch
# Temporary artifact paths used by the export -> engine-build round trip;
# the test removes them again before asserting.
onnx_file = 'tmp.onnx'
trt_file = 'tmp.engine'
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='CUDA is required for test_roialign')
def test_roialign():
    """Check the TensorRT RoIAlign plugin against the PyTorch op.

    Exports ``mmcv.ops.RoIAlign`` to ONNX, builds a TensorRT engine from the
    exported model and asserts the engine output is close to the eager
    PyTorch output for a few small hand-written inputs. Skipped when CUDA,
    TensorRT or the compiled mmcv extensions are unavailable.
    """
    try:
        from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded,
                                   onnx2trt, save_trt_engine)
        if not is_tensorrt_plugin_loaded():
            # Fixed typo in the original message ('complie').
            pytest.skip('test requires to compile TensorRT plugins in mmcv')
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires to install TensorRT from source.')

    try:
        from mmcv.ops import RoIAlign
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    # TensorRT engine-building configuration.
    fp16_mode = False
    max_workspace_size = 1 << 30

    # RoIAlign layer configuration.
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2

    # Each case is (feature map, rois) as nested lists; rois rows are
    # (batch_idx, x1, y1, x2, y2).
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.],
                                        [2., 1.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    wrapped_model = RoIAlign((pool_w, pool_h), spatial_scale, sampling_ratio,
                             'avg', True).cuda()
    for case in inputs:
        np_feat = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        # Named `feat` (not `input`) to avoid shadowing the builtin.
        feat = torch.from_numpy(np_feat).cuda()
        rois = torch.from_numpy(np_rois).cuda()
        try:
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (feat, rois),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['input', 'rois'],
                    output_names=['roi_feat'],
                    opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # Create trt engine and wraper. Static shapes: min/opt/max are
            # all the actual input shape.
            opt_shape_dict = {
                'input': [list(feat.shape),
                          list(feat.shape),
                          list(feat.shape)],
                'rois': [list(rois.shape),
                         list(rois.shape),
                         list(rois.shape)]
            }
            trt_engine = onnx2trt(
                onnx_model,
                opt_shape_dict,
                fp16_mode=fp16_mode,
                max_workspace_size=max_workspace_size)
            save_trt_engine(trt_engine, trt_file)
            trt_model = TRTWraper(trt_file, ['input', 'rois'], ['roi_feat'])

            with torch.no_grad():
                trt_outputs = trt_model({'input': feat, 'rois': rois})
                trt_roi_feat = trt_outputs['roi_feat']

            # Compute the reference output with eager PyTorch.
            with torch.no_grad():
                pytorch_roi_feat = wrapped_model(feat, rois)
        finally:
            # Always clean up temporary artifacts, even when a step above
            # raises (the original leaked tmp files on failure).
            if os.path.exists(onnx_file):
                os.remove(onnx_file)
            if os.path.exists(trt_file):
                os.remove(trt_file)

        assert torch.allclose(pytorch_roi_feat, trt_roi_feat)