mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove torch_out from onnx export, no point without the export_ fn
This commit is contained in:
parent
0cae8a4cd8
commit
ceca5efdec
@ -78,9 +78,8 @@ def onnx_export(
|
|||||||
export_options=export_options,
|
export_options=export_options,
|
||||||
)
|
)
|
||||||
export_output.save(output_file)
|
export_output.save(output_file)
|
||||||
torch_out = None
|
|
||||||
else:
|
else:
|
||||||
torch_out = torch.onnx.export(
|
torch.onnx.export(
|
||||||
model,
|
model,
|
||||||
example_input,
|
example_input,
|
||||||
output_file,
|
output_file,
|
||||||
@ -101,9 +100,5 @@ def onnx_export(
|
|||||||
if check_forward and not training:
|
if check_forward and not training:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
onnx_out = onnx_forward(output_file, example_input)
|
onnx_out = onnx_forward(output_file, example_input)
|
||||||
if torch_out is not None:
|
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
|
||||||
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
|
|
||||||
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
|
|
||||||
else:
|
|
||||||
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user