Remove torch_out from onnx export, no point without the export_ fn

This commit is contained in:
Ross Wightman 2025-04-15 12:07:19 -07:00 committed by Ross Wightman
parent 0cae8a4cd8
commit ceca5efdec

View File

@ -78,9 +78,8 @@ def onnx_export(
export_options=export_options, export_options=export_options,
) )
export_output.save(output_file) export_output.save(output_file)
torch_out = None
else: else:
torch_out = torch.onnx.export( torch.onnx.export(
model, model,
example_input, example_input,
output_file, output_file,
@ -101,9 +100,5 @@ def onnx_export(
if check_forward and not training: if check_forward and not training:
import numpy as np import numpy as np
onnx_out = onnx_forward(output_file, example_input) onnx_out = onnx_forward(output_file, example_input)
if torch_out is not None: np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
else:
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)