New `@try_export` decorator (#9096)

* New export decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci


* Cleanup


* rename fcn to func

* rename to @try_export

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2022-08-23 13:06:33 +02:00 committed by GitHub
parent eab35f66f9
commit d0fa0042bd
2 changed files with 285 additions and 299 deletions

export.py

@@ -67,8 +67,8 @@ if platform.system() != 'Windows':
 from models.experimental import attempt_load
 from models.yolo import Detect
 from utils.dataloaders import LoadImages
-from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
-                           colorstr, file_size, print_args, url2file)
+from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
+                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
 from utils.torch_utils import select_device, smart_inference_mode
@@ -89,200 +89,199 @@ def export_formats():
     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


+def try_export(inner_func):
+    # YOLOv5 export decorator, i.e. @try_export
+    inner_args = get_default_args(inner_func)
+
+    def outer_func(*args, **kwargs):
+        prefix = inner_args['prefix']
+        try:
+            with Profile() as dt:
+                f, model = inner_func(*args, **kwargs)
+            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
+            return f, model
+        except Exception as e:
+            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+            return None, None
+
+    return outer_func
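
Note: the decorator gives every exporter a single contract. The wrapped function is timed by Profile, any exception is caught and logged instead of raised, and callers always receive a (file, model) tuple, (None, None) on failure. Because outer_func reads inner_args['prefix'], each decorated exporter must declare prefix as a keyword argument with a default. A minimal sketch of the contract (export_dummy is hypothetical, for illustration only):

    @try_export
    def export_dummy(model, file, prefix=colorstr('Dummy:')):  # hypothetical exporter; prefix default is required
        f = file.with_suffix('.dummy')
        f.touch()  # stand-in for real serialization work
        return f, None  # (saved path, in-memory model or None)

    f, _ = export_dummy(model, Path('yolov5s.pt'))  # logs timed success/failure, never raises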
+@try_export
 def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
     # YOLOv5 TorchScript model export
-    try:
-        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
-        f = file.with_suffix('.torchscript')
+    LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
+    f = file.with_suffix('.torchscript')

     ts = torch.jit.trace(model, im, strict=False)
     d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
     extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
     if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
         optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
     else:
         ts.save(str(f), _extra_files=extra_files)
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'{prefix} export failure: {e}')
+    return f, None
+@try_export
 def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
     # YOLOv5 ONNX export
-    try:
-        check_requirements(('onnx',))
-        import onnx
+    check_requirements(('onnx',))
+    import onnx

     LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
     f = file.with_suffix('.onnx')

     torch.onnx.export(
         model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
         im.cpu() if dynamic else im,
         f,
         verbose=False,
         opset_version=opset,
         training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
         do_constant_folding=not train,
         input_names=['images'],
         output_names=['output'],
         dynamic_axes={
             'images': {
                 0: 'batch',
                 2: 'height',
                 3: 'width'},  # shape(1,3,640,640)
             'output': {
                 0: 'batch',
                 1: 'anchors'}  # shape(1,25200,85)
         } if dynamic else None)

     # Checks
     model_onnx = onnx.load(f)  # load onnx model
     onnx.checker.check_model(model_onnx)  # check onnx model

     # Metadata
     d = {'stride': int(max(model.stride)), 'names': model.names}
     for k, v in d.items():
         meta = model_onnx.metadata_props.add()
         meta.key, meta.value = k, str(v)
     onnx.save(model_onnx, f)

     # Simplify
     if simplify:
         try:
             cuda = torch.cuda.is_available()
             check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
             import onnxsim

             LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
             model_onnx, check = onnxsim.simplify(model_onnx)
             assert check, 'assert check failed'
             onnx.save(model_onnx, f)
         except Exception as e:
             LOGGER.info(f'{prefix} simplifier failure: {e}')
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'{prefix} export failure: {e}')
+    return f, model_onnx
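
export_onnx now returns the in-memory ModelProto alongside the path, and the stride/names metadata written above can be read back to verify it. A small sketch (the file path is illustrative):

    import onnx

    m = onnx.load('yolov5s.onnx')  # illustrative path
    meta = {p.key: p.value for p in m.metadata_props}
    print(meta)  # values are stored as strings, e.g. {'stride': '32', 'names': "{0: 'person', ...}"}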
+@try_export
 def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
     # YOLOv5 OpenVINO export
-    try:
-        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
-        import openvino.inference_engine as ie
+    check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+    import openvino.inference_engine as ie

     LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
     f = str(file).replace('.pt', f'_openvino_model{os.sep}')

     cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
     subprocess.check_output(cmd.split())  # export
     with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
         yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g)  # add metadata.yaml
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
+    return f, None
+@try_export
 def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
     # YOLOv5 CoreML export
-    try:
-        check_requirements(('coremltools',))
-        import coremltools as ct
+    check_requirements(('coremltools',))
+    import coremltools as ct

     LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
     f = file.with_suffix('.mlmodel')

     ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
     ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
     bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
     if bits < 32:
         if platform.system() == 'Darwin':  # quantization only supported on macOS
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
         else:
             print(f'{prefix} quantization only supported on macOS, skipping...')
     ct_model.save(f)
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return ct_model, f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
-        return None, None
+    return f, ct_model
-def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, verbose=False):
+@try_export
+def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
     # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
-    prefix = colorstr('TensorRT:')
-    try:
-        assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
-        try:
-            import tensorrt as trt
-        except Exception:
-            if platform.system() == 'Linux':
-                check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
-            import tensorrt as trt
+    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
+    try:
+        import tensorrt as trt
+    except Exception:
+        if platform.system() == 'Linux':
+            check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
+        import tensorrt as trt

     if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
         grid = model.model[-1].anchor_grid
         model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
-        export_onnx(model, im, file, 12, train, dynamic, simplify)  # opset 12
+        export_onnx(model, im, file, 12, False, dynamic, simplify)  # opset 12
         model.model[-1].anchor_grid = grid
     else:  # TensorRT >= 8
         check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
-        export_onnx(model, im, file, 13, train, dynamic, simplify)  # opset 13
+        export_onnx(model, im, file, 13, False, dynamic, simplify)  # opset 13
     onnx = file.with_suffix('.onnx')

     LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
     assert onnx.exists(), f'failed to export ONNX file: {onnx}'
     f = file.with_suffix('.engine')  # TensorRT engine file
     logger = trt.Logger(trt.Logger.INFO)
     if verbose:
         logger.min_severity = trt.Logger.Severity.VERBOSE

     builder = trt.Builder(logger)
     config = builder.create_builder_config()
     config.max_workspace_size = workspace * 1 << 30
     # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

     flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
     network = builder.create_network(flag)
     parser = trt.OnnxParser(network, logger)
     if not parser.parse_from_file(str(onnx)):
         raise RuntimeError(f'failed to load ONNX file: {onnx}')

     inputs = [network.get_input(i) for i in range(network.num_inputs)]
     outputs = [network.get_output(i) for i in range(network.num_outputs)]
     LOGGER.info(f'{prefix} Network Description:')
     for inp in inputs:
         LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
     for out in outputs:
         LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

     if dynamic:
         if im.shape[0] <= 1:
             LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
         profile = builder.create_optimization_profile()
         for inp in inputs:
             profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
         config.add_optimization_profile(profile)

     LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
     if builder.platform_has_fast_fp16 and half:
         config.set_flag(trt.BuilderFlag.FP16)
     with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
         t.write(engine.serialize())
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
+    return f, None
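
An engine serialized this way is loaded back through the TensorRT runtime at inference time. A minimal sketch, assuming a TensorRT 8.x runtime and an illustrative path:

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    with open('yolov5s.engine', 'rb') as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())  # ICudaEngine; create execution contexts from it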
+@try_export
 def export_saved_model(model,
                        im,
                        file,
@@ -296,162 +295,142 @@ def export_saved_model(model,
                        keras=False,
                        prefix=colorstr('TensorFlow SavedModel:')):
     # YOLOv5 TensorFlow SavedModel export
-    try:
-        import tensorflow as tf
-        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+    import tensorflow as tf
+    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

-        from models.tf import TFDetect, TFModel
+    from models.tf import TFModel

     LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
     f = str(file).replace('.pt', '_saved_model')
     batch_size, ch, *imgsz = list(im.shape)  # BCHW

     tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
     im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
     _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
     inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
     outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
     keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
     keras_model.trainable = False
     keras_model.summary()
     if keras:
         keras_model.save(f, save_format='tf')
     else:
         spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
         m = tf.function(lambda x: keras_model(x))  # full model
         m = m.get_concrete_function(spec)
         frozen_func = convert_variables_to_constants_v2(m)
         tfm = tf.Module()
         tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
         tfm.__call__(im)
         tf.saved_model.save(tfm,
                             f,
-                            options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
-                            if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
+                            options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
+                                tf.__version__, '2.6') else tf.saved_model.SaveOptions())
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return keras_model, f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
-        return None, None
+    return f, keras_model
+@try_export
 def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
     # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
-    try:
-        import tensorflow as tf
-        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+    import tensorflow as tf
+    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

     LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
     f = file.with_suffix('.pb')

     m = tf.function(lambda x: keras_model(x))  # full model
     m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
     frozen_func = convert_variables_to_constants_v2(m)
     frozen_func.graph.as_graph_def()
     tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
+    return f, None
+@try_export
 def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
     # YOLOv5 TensorFlow Lite export
-    try:
-        import tensorflow as tf
+    import tensorflow as tf

     LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
     batch_size, ch, *imgsz = list(im.shape)  # BCHW

     f = str(file).replace('.pt', '-fp16.tflite')
     converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
     converter.target_spec.supported_types = [tf.float16]
     converter.optimizations = [tf.lite.Optimize.DEFAULT]
     if int8:
         from models.tf import representative_dataset_gen
         dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
         converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
         converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
         converter.target_spec.supported_types = []
         converter.inference_input_type = tf.uint8  # or tf.int8
         converter.inference_output_type = tf.uint8  # or tf.int8
         converter.experimental_new_quantizer = True
         f = str(file).replace('.pt', '-int8.tflite')
     if nms or agnostic_nms:
         converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

     tflite_model = converter.convert()
     open(f, "wb").write(tflite_model)
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
+    return f, None
+@try_export
 def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
     # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
-    try:
-        cmd = 'edgetpu_compiler --version'
-        help_url = 'https://coral.ai/docs/edgetpu/compiler/'
+    cmd = 'edgetpu_compiler --version'
+    help_url = 'https://coral.ai/docs/edgetpu/compiler/'
     assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
     if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
         LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
         sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
         for c in (
                 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
                 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
                 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
             subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
     ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

     LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
     f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU model
     f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model

     cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
     subprocess.run(cmd.split(), check=True)
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
+    return f, None
+@try_export
 def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
     # YOLOv5 TensorFlow.js export
-    try:
-        check_requirements(('tensorflowjs',))
-        import re
+    check_requirements(('tensorflowjs',))
+    import re

     import tensorflowjs as tfjs

     LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
     f = str(file).replace('.pt', '_web_model')  # js dir
     f_pb = file.with_suffix('.pb')  # *.pb path
     f_json = f'{f}/model.json'  # *.json path

     cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
           f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
     subprocess.run(cmd.split())

     json = Path(f_json).read_text()
     with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
         subst = re.sub(
             r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
             r'"Identity.?.?": {"name": "Identity.?.?"}, '
             r'"Identity.?.?": {"name": "Identity.?.?"}, '
             r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
             r'"Identity_1": {"name": "Identity_1"}, '
             r'"Identity_2": {"name": "Identity_2"}, '
             r'"Identity_3": {"name": "Identity_3"}}}', json)
         j.write(subst)
-
-        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
-        return f
-    except Exception as e:
-        LOGGER.info(f'\n{prefix} export failure: {e}')
+    return f, None
 @smart_inference_mode()
@@ -524,22 +503,22 @@ def run(
     f = [''] * 10  # exported filenames
     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
     if jit:
-        f[0] = export_torchscript(model, im, file, optimize)
+        f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1] = export_engine(model, im, file, train, half, dynamic, simplify, workspace, verbose)
+        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
-        f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
+        f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
-        f[3] = export_openvino(model, file, half)
+        f[3], _ = export_openvino(model, file, half)
     if coreml:
-        _, f[4] = export_coreml(model, im, file, int8, half)
+        f[4], _ = export_coreml(model, im, file, int8, half)

     # TensorFlow Exports
     if any((saved_model, pb, tflite, edgetpu, tfjs)):
         if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
             check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`
         assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
-        model, f[5] = export_saved_model(model.cpu(),
+        f[5], model = export_saved_model(model.cpu(),
                                          im,
                                          file,
                                          dynamic,
@@ -551,19 +530,19 @@ def run(
                                          conf_thres=conf_thres,
                                          keras=keras)
         if pb or tfjs:  # pb prerequisite to tfjs
-            f[6] = export_pb(model, file)
+            f[6], _ = export_pb(model, file)
         if tflite or edgetpu:
-            f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+            f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
         if edgetpu:
-            f[8] = export_edgetpu(file)
+            f[8], _ = export_edgetpu(file)
         if tfjs:
-            f[9] = export_tfjs(file)
+            f[9], _ = export_tfjs(file)

     # Finish
     f = [str(x) for x in f if x]  # filter out '' and None
     if any(f):
         h = '--half' if half else ''  # --half FP16 inference arg
-        LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
+        LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                     f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                     f"\nDetect: python detect.py --weights {f[-1]} {h}"
                     f"\nValidate: python val.py --weights {f[-1]} {h}"

utils/general.py

@@ -148,6 +148,7 @@ class Profile(contextlib.ContextDecorator):

     def __enter__(self):
         self.start = self.time()
+        return self

     def __exit__(self, type, value, traceback):
         self.dt = self.time() - self.start  # delta-time
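
The added return self is what makes `with Profile() as dt:` in try_export work; without it, __enter__ binds None to dt and dt.t would fail. A condensed sketch of the pattern (timing simplified and illustrative only; the real class also accumulates elapsed time into self.t):

    import contextlib
    import time

    class Profile(contextlib.ContextDecorator):
        def __init__(self, t=0.0):
            self.t = t  # accumulated seconds across uses

        def __enter__(self):
            self.start = time.time()
            return self  # bind the instance to `as dt`

        def __exit__(self, type, value, traceback):
            self.dt = time.time() - self.start  # delta-time
            self.t += self.dt

    with Profile() as dt:
        sum(range(10 ** 6))
    print(f'{dt.t:.3f}s')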
@@ -220,10 +221,10 @@ def methods(instance):
     return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


-def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
+def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
     # Print function arguments (optional args dict)
     x = inspect.currentframe().f_back  # previous frame
-    file, _, fcn, _, _ = inspect.getframeinfo(x)
+    file, _, func, _, _ = inspect.getframeinfo(x)
     if args is None:  # get args automatically
         args, _, _, frm = inspect.getargvalues(x)
         args = {k: v for k, v in frm.items() if k in args}
@@ -231,7 +232,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
         file = Path(file).resolve().relative_to(ROOT).with_suffix('')
     except ValueError:
         file = Path(file).stem
-    s = (f'{file}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
+    s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
     LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
@@ -255,7 +256,13 @@ def init_seeds(seed=0, deterministic=False):

 def intersect_dicts(da, db, exclude=()):
     # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
-    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
+    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


+def get_default_args(func):
+    # Get func() default arguments
+    signature = inspect.signature(func)
+    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}


 def get_latest_run(search_dir='.'):
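
try_export depends on get_default_args to recover each exporter's prefix default without calling the function. A quick usage sketch (f below is a hypothetical signature):

    import inspect

    def f(model, file, opset=12, prefix='ONNX:'):  # hypothetical exporter signature
        ...

    signature = inspect.signature(f)
    print({k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty})
    # -> {'opset': 12, 'prefix': 'ONNX:'}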