Remove torch_out from onnx export, no point without the export_ fn
parent
0cae8a4cd8
commit
ceca5efdec
|
@ -78,9 +78,8 @@ def onnx_export(
|
|||
export_options=export_options,
|
||||
)
|
||||
export_output.save(output_file)
|
||||
torch_out = None
|
||||
else:
|
||||
torch_out = torch.onnx.export(
|
||||
torch.onnx.export(
|
||||
model,
|
||||
example_input,
|
||||
output_file,
|
||||
|
@ -101,9 +100,5 @@ def onnx_export(
|
|||
if check_forward and not training:
|
||||
import numpy as np
|
||||
onnx_out = onnx_forward(output_file, example_input)
|
||||
if torch_out is not None:
|
||||
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
|
||||
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
|
||||
else:
|
||||
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue