Remove `.train()` mode exports (#9429)

* Remove `.train()` mode exports

No common use cases.

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2022-09-15 19:05:10 +02:00 committed by GitHub
parent a4ed988893
commit 1323b48053
1 changed file with 4 additions and 7 deletions
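
In short, the export path now always runs in inference mode with ONNX constant folding enabled. A minimal sketch of the resulting `torch.onnx.export` call (paraphrased from the diff below; the variables `model`, `im`, `f` and `opset` come from export.py, and the trimmed argument list here is illustrative, not the full call):

    model.eval()  # export always uses inference mode; no training-mode Detect() grid construction
    torch.onnx.export(
        model,
        im,  # example input tensor used for tracing
        f,  # output .onnx path
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,  # was `not train`; now always enabled
        input_names=['images'],
        output_names=['output'])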


@@ -126,7 +126,7 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
 @try_export
-def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
+def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
     # YOLOv5 ONNX export
     check_requirements('onnx')
     import onnx
@@ -140,8 +140,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
         f,
         verbose=False,
         opset_version=opset,
-        training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
-        do_constant_folding=not train,
+        do_constant_folding=True,
         input_names=['images'],
         output_names=['output'],
         dynamic_axes={
@@ -459,7 +458,6 @@ def run(
         include=('torchscript', 'onnx'),  # include formats
         half=False,  # FP16 half-precision export
         inplace=False,  # set YOLOv5 Detect() inplace=True
-        train=False,  # model.train() mode
         keras=False,  # use Keras
         optimize=False,  # TorchScript: optimize for mobile
         int8=False,  # CoreML/TF INT8 quantization
@@ -501,7 +499,7 @@ def run(
     im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

     # Update model
-    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
+    model.eval()
     for k, m in model.named_modules():
         if isinstance(m, Detect):
             m.inplace = inplace
@@ -524,7 +522,7 @@ def run(
     if engine:  # TensorRT required before ONNX
         f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
-        f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
+        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
     if xml:  # OpenVINO
         f[3], _ = export_openvino(file, metadata, half)
     if coreml:  # CoreML
@@ -578,7 +576,6 @@ def parse_opt():
     parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
     parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
-    parser.add_argument('--train', action='store_true', help='model.train() mode')
     parser.add_argument('--keras', action='store_true', help='TF: use Keras')
     parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
     parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
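
With the flag removed, a typical export is simply `python export.py --weights yolov5s.pt --include onnx`, and the Python API drops the keyword as well. A hedged sketch of calling run() after this change (the weights path and format tuple are illustrative; other arguments keep their defaults):

    from export import run  # assumes the YOLOv5 repo root is on sys.path

    # run() no longer accepts train=...; the model is always placed in eval mode before export
    run(weights='yolov5s.pt', include=('onnx',), opset=12)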