Remove torch_out from onnx export, no point without the export_ fn

pull/2418/merge
Ross Wightman 2025-04-15 12:07:19 -07:00 committed by Ross Wightman
parent 0cae8a4cd8
commit ceca5efdec
1 changed files with 2 additions and 7 deletions

View File

@ -78,9 +78,8 @@ def onnx_export(
export_options=export_options,
)
export_output.save(output_file)
torch_out = None
else:
torch_out = torch.onnx.export(
torch.onnx.export(
model,
example_input,
output_file,
@ -101,9 +100,5 @@ def onnx_export(
if check_forward and not training:
import numpy as np
onnx_out = onnx_forward(output_file, example_input)
if torch_out is not None:
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
else:
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)