mmcv/tests/test_ops/test_onnx.py
Wang Xinjiang ee041cec6a
[Feature]: Add Rotated ROI align op for pytorch (cpu&cuda), parrots (cpu&cuda) and onnxruntime (cpu) (#933)
* add roi_align_rotated

* code format

* Add align key to roi align rotated

* Add clockwise for rotated roi align

* fix bugs in onnx export

* Add docstring for RoIAlignRotated

* remove cuda unittest

* Reformat c++ code

* add onnx roi align rotated file

* fix unittest

* Add cpu and float64 cuda support for parrots

* code format

* Add unified header to roi align rotated

Co-authored-by: luopeichao <luopeichao@sensetime.com>
2021-04-19 15:04:01 +08:00

451 lines
16 KiB
Python

import os
import warnings
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import pytest
import torch
import torch.nn as nn
from packaging import version
onnx_file = 'tmp.onnx'
class WrapFunction(nn.Module):
def __init__(self, wrapped_function):
super(WrapFunction, self).__init__()
self.wrapped_function = wrapped_function
def forward(self, *args, **kwargs):
return self.wrapped_function(*args, **kwargs)
@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
@pytest.mark.parametrize('align_corners', [True, False])
def test_grid_sample(mode, padding_mode, align_corners):
from mmcv.onnx.symbolic import register_extra_symbolics
opset_version = 11
register_extra_symbolics(opset_version)
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')
input = torch.rand(1, 1, 10, 10)
grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
grid = nn.functional.affine_grid(grid, (1, 1, 15, 15)).type_as(input)
def func(input, grid):
return nn.functional.grid_sample(
input,
grid,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners)
wrapped_model = WrapFunction(func).eval()
input_names = ['input', 'grid']
output_names = ['output']
with torch.no_grad():
torch.onnx.export(
wrapped_model, (input, grid),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=input_names,
output_names=output_names,
opset_version=11)
onnx_model = onnx.load(onnx_file)
session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
ort_result = sess.run(None, {
'input': input.detach().numpy(),
'grid': grid.detach().numpy()
})
pytorch_results = wrapped_model(input.clone(), grid.clone())
os.remove(onnx_file)
assert np.allclose(pytorch_results, ort_result, atol=1e-3)
def test_nms():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.ops import get_onnxruntime_op_path, nms
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
pytorch_score = pytorch_dets[:, 4]
nms = partial(nms, iou_threshold=0.3, offset=0)
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (boxes, scores),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('nms for onnxruntime is not compiled.')
session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
onnx_dets, _ = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
onnx_score = onnx_dets[:, 4]
os.remove(onnx_file)
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_softnms():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.ops import get_onnxruntime_op_path, soft_nms
# only support pytorch >= 1.7.0
if version.parse(torch.__version__) < version.parse('1.7.0'):
warnings.warn('test_softnms should be ran with pytorch >= 1.7.0')
return
# only support onnxruntime >= 1.5.1
assert version.parse(rt.__version__) >= version.parse(
'1.5.1'), 'test_softnms should be ran with onnxruntime >= 1.5.1'
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('softnms for onnxruntime is not compiled.')
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
configs = [[0.3, 0.5, 0.01, 'linear'], [0.3, 0.5, 0.01, 'gaussian'],
[0.3, 0.5, 0.01, 'naive']]
session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
for _iou_threshold, _sigma, _min_score, _method in configs:
pytorch_dets, pytorch_inds = soft_nms(
boxes,
scores,
iou_threshold=_iou_threshold,
sigma=_sigma,
min_score=_min_score,
method=_method)
nms = partial(
soft_nms,
iou_threshold=_iou_threshold,
sigma=_sigma,
min_score=_min_score,
method=_method)
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (boxes, scores),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
onnx_dets, onnx_inds = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
os.remove(onnx_file)
assert np.allclose(pytorch_dets, onnx_dets, atol=1e-3)
assert np.allclose(onnx_inds, onnx_inds, atol=1e-3)
def test_roialign():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
try:
from mmcv.ops import roi_align
from mmcv.ops import get_onnxruntime_op_path
except (ImportError, ModuleNotFoundError):
pytest.skip('roi_align op is not successfully compiled')
ort_custom_op_path = get_onnxruntime_op_path()
# roi align config
pool_h = 2
pool_w = 2
spatial_scale = 1.0
sampling_ratio = 2
inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2.], [3., 4.]], [[4., 3.],
[2., 1.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
[11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]
def warpped_function(torch_input, torch_rois):
return roi_align(torch_input, torch_rois, (pool_w, pool_h),
spatial_scale, sampling_ratio, 'avg', True)
for case in inputs:
np_input = np.array(case[0], dtype=np.float32)
np_rois = np.array(case[1], dtype=np.float32)
input = torch.from_numpy(np_input)
rois = torch.from_numpy(np_rois)
# compute pytorch_output
with torch.no_grad():
pytorch_output = roi_align(input, rois, (pool_w, pool_h),
spatial_scale, sampling_ratio, 'avg',
True)
# export and load onnx model
wrapped_model = WrapFunction(warpped_function)
with torch.no_grad():
torch.onnx.export(
wrapped_model, (input, rois),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input', 'rois'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
session_options = rt.SessionOptions()
if os.path.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
# compute onnx_output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
onnx_output = sess.run(None, {
'input': input.detach().numpy(),
'rois': rois.detach().numpy()
})
onnx_output = onnx_output[0]
# allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
def test_roialign_rotated():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
try:
from mmcv.ops import roi_align_rotated
from mmcv.ops import get_onnxruntime_op_path
except (ImportError, ModuleNotFoundError):
pytest.skip('roi_align_aligned op is not successfully compiled')
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')
# roi align config
pool_h = 2
pool_w = 2
spatial_scale = 1.0
sampling_ratio = 2
inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0.5, 0.5, 1., 1., 0]]),
([[[[1., 2.], [3., 4.]]]], [[0., 0.5, 0.5, 1., 1., np.pi / 2]]),
([[[[1., 2.], [3., 4.]],
[[4., 3.], [2., 1.]]]], [[0., 0.5, 0.5, 1., 1., 0]]),
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
[11., 12., 15., 16.]]]], [[0., 1.5, 1.5, 3., 3., 0]]),
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
[11., 12., 15., 16.]]]], [[0., 1.5, 1.5, 3., 3.,
np.pi / 2]])]
def warpped_function(torch_input, torch_rois):
return roi_align_rotated(torch_input, torch_rois, (pool_w, pool_h),
spatial_scale, sampling_ratio, True, False)
for case in inputs:
np_input = np.array(case[0], dtype=np.float32)
np_rois = np.array(case[1], dtype=np.float32)
input = torch.from_numpy(np_input)
rois = torch.from_numpy(np_rois)
# compute pytorch_output
with torch.no_grad():
pytorch_output = roi_align_rotated(input, rois, (pool_w, pool_h),
spatial_scale, sampling_ratio,
True, False)
# export and load onnx model
wrapped_model = WrapFunction(warpped_function)
with torch.no_grad():
torch.onnx.export(
wrapped_model, (input, rois),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['features', 'rois'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
session_options = rt.SessionOptions()
if os.path.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
# compute onnx_output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
onnx_output = sess.run(None, {
'features': input.detach().numpy(),
'rois': rois.detach().numpy()
})
onnx_output = onnx_output[0]
# allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires GPU')
def test_roipool():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
from mmcv.ops import roi_pool
# roi pool config
pool_h = 2
pool_w = 2
spatial_scale = 1.0
inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2.], [3., 4.]], [[4., 3.],
[2., 1.]]]], [[0., 0., 0., 1., 1.]]),
([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
[11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]
def warpped_function(torch_input, torch_rois):
return roi_pool(torch_input, torch_rois, (pool_w, pool_h),
spatial_scale)
for case in inputs:
np_input = np.array(case[0], dtype=np.float32)
np_rois = np.array(case[1], dtype=np.float32)
input = torch.from_numpy(np_input).cuda()
rois = torch.from_numpy(np_rois).cuda()
# compute pytorch_output
with torch.no_grad():
pytorch_output = roi_pool(input, rois, (pool_w, pool_h),
spatial_scale)
pytorch_output = pytorch_output.cpu()
# export and load onnx model
wrapped_model = WrapFunction(warpped_function)
with torch.no_grad():
torch.onnx.export(
wrapped_model, (input, rois),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input', 'rois'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
# compute onnx_output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file)
onnx_output = sess.run(
None, {
'input': input.detach().cpu().numpy(),
'rois': rois.detach().cpu().numpy()
})
onnx_output = onnx_output[0]
# allclose
os.remove(onnx_file)
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
def test_interpolate():
from mmcv.onnx.symbolic import register_extra_symbolics
opset_version = 11
register_extra_symbolics(opset_version)
def func(feat, scale_factor=2):
out = nn.functional.interpolate(feat, scale_factor=scale_factor)
return out
net = WrapFunction(func)
net = net.cpu().eval()
dummy_input = torch.randn(2, 4, 8, 8).cpu()
torch.onnx.export(
net,
dummy_input,
onnx_file,
input_names=['input'],
opset_version=opset_version)
sess = rt.InferenceSession(onnx_file)
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
pytorch_result = func(dummy_input).detach().numpy()
if os.path.exists(onnx_file):
os.remove(onnx_file)
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)