Remove `.train()` mode exports (#9429)
* Remove `.train()` mode exports No common use cases. Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9430/head
parent
a4ed988893
commit
1323b48053
11
export.py
11
export.py
|
@ -126,7 +126,7 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
|
|||
|
||||
|
||||
@try_export
|
||||
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
|
||||
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
|
||||
# YOLOv5 ONNX export
|
||||
check_requirements('onnx')
|
||||
import onnx
|
||||
|
@ -140,8 +140,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
|
|||
f,
|
||||
verbose=False,
|
||||
opset_version=opset,
|
||||
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
|
||||
do_constant_folding=not train,
|
||||
do_constant_folding=True,
|
||||
input_names=['images'],
|
||||
output_names=['output'],
|
||||
dynamic_axes={
|
||||
|
@ -459,7 +458,6 @@ def run(
|
|||
include=('torchscript', 'onnx'), # include formats
|
||||
half=False, # FP16 half-precision export
|
||||
inplace=False, # set YOLOv5 Detect() inplace=True
|
||||
train=False, # model.train() mode
|
||||
keras=False, # use Keras
|
||||
optimize=False, # TorchScript: optimize for mobile
|
||||
int8=False, # CoreML/TF INT8 quantization
|
||||
|
@ -501,7 +499,7 @@ def run(
|
|||
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
||||
|
||||
# Update model
|
||||
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
|
||||
model.eval()
|
||||
for k, m in model.named_modules():
|
||||
if isinstance(m, Detect):
|
||||
m.inplace = inplace
|
||||
|
@ -524,7 +522,7 @@ def run(
|
|||
if engine: # TensorRT required before ONNX
|
||||
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
|
||||
if onnx or xml: # OpenVINO requires ONNX
|
||||
f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
|
||||
f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
|
||||
if xml: # OpenVINO
|
||||
f[3], _ = export_openvino(file, metadata, half)
|
||||
if coreml: # CoreML
|
||||
|
@ -578,7 +576,6 @@ def parse_opt():
|
|||
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
||||
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
|
||||
parser.add_argument('--train', action='store_true', help='model.train() mode')
|
||||
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
|
||||
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
|
||||
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
|
||||
|
|
Loading…
Reference in New Issue