diff --git a/export.py b/export.py index 574fee050..c9ad158c5 100644 --- a/export.py +++ b/export.py @@ -475,9 +475,9 @@ def run( ): t = time.time() include = [x.lower() for x in include] # to lowercase - formats = tuple(export_formats()['Argument'][1:]) # --include arguments - flags = [x in include for x in formats] - assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}' + fmts = tuple(export_formats()['Argument'][1:]) # --include arguments + flags = [x in include for x in fmts] + assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}' jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights @@ -499,7 +499,7 @@ def run( im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection # Update model - if half and not (coreml or xml): + if half and not coreml and not xml: im, model = im.half(), model.half() # to FP16 model.train() if train else model.eval() # training mode = no Detect() layer grid construction for k, m in model.named_modules(): @@ -531,7 +531,7 @@ def run( if any((saved_model, pb, tflite, edgetpu, tfjs)): if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` - assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' + assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.' model, f[5] = export_saved_model(model.cpu(), im, file, diff --git a/utils/benchmarks.py b/utils/benchmarks.py index 99910050d..d0f2a2529 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -56,9 +56,8 @@ def run( pt_only=False, # test PyTorch only ): y, t = [], time.time() - formats = export.export_formats() device = select_device(device) - for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) + for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable) try: assert i != 9, 'Edge TPU not supported' assert i != 10, 'TF.js not supported' @@ -104,9 +103,8 @@ def test( pt_only=False, # test PyTorch only ): y, t = [], time.time() - formats = export.export_formats() device = select_device(device) - for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) + for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable) try: w = weights if f == '-' else \ export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights