mmcv/tests/test_ops/test_onnx.py

185 lines
6.3 KiB
Python

import os
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import torch
import torch.nn as nn
# Scratch path (relative, i.e. in the current working directory) that every
# export test in this file writes its ONNX model to and removes afterwards.
onnx_file = 'tmp.onnx'
class WrapFunction(nn.Module):
    """Adapt a plain callable into an ``nn.Module``.

    The tests below hand an instance of this class to ``torch.onnx.export``
    as the model; calling the module simply forwards every positional and
    keyword argument to the stored callable.

    Args:
        wrapped_function (callable): Function invoked by :meth:`forward`.
    """

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Invoke the wrapped callable and return its result unchanged."""
        target = self.wrapped_function
        return target(*args, **kwargs)
def test_nms():
    """Check that ``mmcv.ops.nms`` exported to ONNX matches eager PyTorch.

    The NMS call (``iou_threshold=0.3``, ``offset=0``) is wrapped in a
    module and exported with opset 11. The exported model is then run with
    ONNX Runtime. Column 4 of the returned ``dets`` (the scores) must agree
    with the PyTorch result within ``atol=1e-3``. The scratch file
    ``onnx_file`` is removed even if export, the input check or inference
    fails.
    """
    from mmcv.ops import nms

    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
                        dtype=np.float32)
    np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
    boxes = torch.from_numpy(np_boxes)
    scores = torch.from_numpy(np_scores)

    # Reference result from eager PyTorch.
    pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
    pytorch_score = pytorch_dets[:, 4]

    # Bind the scalar options so the exported graph only takes the two
    # tensors. Use a new name rather than shadowing the imported ``nms``.
    nms_with_cfg = partial(nms, iou_threshold=0.3, offset=0)
    wrapped_model = WrapFunction(nms_with_cfg)
    wrapped_model.cpu().eval()

    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (boxes, scores),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['boxes', 'scores'],
                opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # With keep_initializers_as_inputs=True, initializers are also
        # listed as graph inputs. After removing them, only the two
        # tensors fed at run time should remain.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert len(net_feed_input) == 2

        sess = rt.InferenceSession(onnx_file)
        onnx_dets, _ = sess.run(None, {
            'scores': scores.detach().numpy(),
            'boxes': boxes.detach().numpy()
        })
    finally:
        # Always clean up. Guard the removal in case export never wrote the
        # file, so the original error is not masked.
        if os.path.exists(onnx_file):
            os.remove(onnx_file)

    onnx_score = onnx_dets[:, 4]
    assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
def test_roialign():
    """Check that ``mmcv.ops.roi_align`` exported to ONNX matches PyTorch.

    Each case pairs a small feature map with one RoI row. The feature maps
    are shaped (1, 1, 2, 2), (1, 2, 2, 2) and (1, 1, 4, 4), as given by the
    nested literals. Each case is pooled to ``pool_h`` x ``pool_w`` in
    ``'avg'`` mode with ``aligned=True``, then exported with opset 11. The
    ONNX Runtime output must match the eager PyTorch output within
    ``atol=1e-3``. The scratch file is removed even when a case fails.
    """
    from mmcv.ops import roi_align

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.], [2., 1.]]]],
               [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    def wrapped_function(torch_input, torch_rois):
        # mmcv's output_size is (out_h, out_w).
        return roi_align(torch_input, torch_rois, (pool_h, pool_w),
                         spatial_scale, sampling_ratio, 'avg', True)

    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        feat = torch.from_numpy(np_input)
        rois = torch.from_numpy(np_rois)

        # compute pytorch_output with the very function that gets exported
        with torch.no_grad():
            pytorch_output = wrapped_function(feat, rois)

        wrapped_model = WrapFunction(wrapped_function)
        try:
            # export and load onnx model
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (feat, rois),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['input', 'rois'],
                    opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # Initializers are listed as inputs too. After removing them,
            # only 'input' and 'rois' should remain.
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert len(net_feed_input) == 2

            # compute onnx_output
            sess = rt.InferenceSession(onnx_file)
            onnx_output = sess.run(None, {
                'input': feat.detach().numpy(),
                'rois': rois.detach().numpy()
            })
            onnx_output = onnx_output[0]
        finally:
            # Clean up even on failure, without masking the original error.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
def test_roipool():
    """Check that ``mmcv.ops.roi_pool`` exported to ONNX matches PyTorch.

    Runs only when CUDA is available. The PyTorch reference is computed on
    the GPU and the model is exported from CUDA inputs. ONNX Runtime is then
    fed CPU copies of the same inputs, and its output must match within
    ``atol=1e-3`` for every case. The scratch file is removed even when a
    case fails.
    """
    # NOTE(review): this early return reports as "passed" on CPU-only
    # machines. The reason CUDA is required is not visible here (presumably
    # roi_pool has no CPU kernel in this mmcv build) — confirm, and prefer
    # pytest.skip if pytest becomes available to this module.
    if not torch.cuda.is_available():
        return
    from mmcv.ops import roi_pool

    # roi pool config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.], [2., 1.]]]],
               [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    def wrapped_function(torch_input, torch_rois):
        # mmcv's output_size is (out_h, out_w).
        return roi_pool(torch_input, torch_rois, (pool_h, pool_w),
                        spatial_scale)

    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        feat = torch.from_numpy(np_input).cuda()
        rois = torch.from_numpy(np_rois).cuda()

        # compute pytorch_output with the very function that gets exported
        with torch.no_grad():
            pytorch_output = wrapped_function(feat, rois)
            pytorch_output = pytorch_output.cpu()

        wrapped_model = WrapFunction(wrapped_function)
        try:
            # export and load onnx model
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (feat, rois),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['input', 'rois'],
                    opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # Initializers are listed as inputs too. After removing them,
            # only 'input' and 'rois' should remain.
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert len(net_feed_input) == 2

            # compute onnx_output (ONNX Runtime is fed CPU numpy arrays)
            sess = rt.InferenceSession(onnx_file)
            onnx_output = sess.run(
                None, {
                    'input': feat.detach().cpu().numpy(),
                    'rois': rois.detach().cpu().numpy()
                })
            onnx_output = onnx_output[0]
        finally:
            # Clean up even on failure, without masking the original error.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)