diff --git a/timm/utils/onnx.py b/timm/utils/onnx.py index 5c4fd16f..2ee920ff 100644 --- a/timm/utils/onnx.py +++ b/timm/utils/onnx.py @@ -78,9 +78,8 @@ def onnx_export( export_options=export_options, ) export_output.save(output_file) - torch_out = None else: - torch_out = torch.onnx.export( + torch.onnx.export( model, example_input, output_file, @@ -101,9 +100,5 @@ def onnx_export( if check_forward and not training: import numpy as np onnx_out = onnx_forward(output_file, example_input) - if torch_out is not None: - np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3) - np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5) - else: - np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3) + np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)