Add PaddlePaddle export and inference (#9240)
* Add PaddlePaddle Model Export Test on YOLOv5 Docker Environment with paddlepaddle-gpu v2.2 (Signed-off-by: Katteria <39751846+kisaragychihaya@users.noreply.github.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Cleanup Paddle Export (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update export.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update export.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update export.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update export.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Use PyTorch2Paddle (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Paddle no longer requires ONNX (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update export.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update export.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update benchmarks.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Add inference code of PaddlePaddle (Signed-off-by: Katteria <39751846+kisaragychihaya@users.noreply.github.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update common.py (Signed-off-by: Katteria <39751846+kisaragychihaya@users.noreply.github.com>)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Add paddlepaddle-gpu install if cuda (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)
* Update common.py (Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>)

Signed-off-by: Katteria <39751846+kisaragychihaya@users.noreply.github.com>
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 57ef676af2
commit e3e5122f82
export.py (72 changed lines)
@@ -15,6 +15,7 @@ TensorFlow GraphDef | `pb` | yolov5s.pb
 TensorFlow Lite | `tflite` | yolov5s.tflite
 TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
 TensorFlow.js | `tfjs` | yolov5s_web_model/
+PaddlePaddle | `paddle` | yolov5s_paddle_model/
 
 Requirements:
     $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
@@ -54,7 +55,6 @@ from pathlib import Path
 
 import pandas as pd
 import torch
-import yaml
 from torch.utils.mobile_optimizer import optimize_for_mobile
 
 FILE = Path(__file__).resolve()
@@ -68,7 +68,7 @@ from models.experimental import attempt_load
 from models.yolo import ClassificationModel, Detect
 from utils.dataloaders import LoadImages
 from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
-                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
+                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
 from utils.torch_utils import select_device, smart_inference_mode
 
 
@@ -85,7 +85,8 @@ def export_formats():
         ['TensorFlow GraphDef', 'pb', '.pb', True, True],
         ['TensorFlow Lite', 'tflite', '.tflite', True, False],
         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
-        ['TensorFlow.js', 'tfjs', '_web_model', False, False],]
+        ['TensorFlow.js', 'tfjs', '_web_model', False, False],
+        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
 
 
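export_formats() is the single registry the rest of the tooling reads: --include validation in export.py and the iteration in benchmarks.py both come from this table, so the added row is what registers the new format. A quick sketch to confirm the row is visible (assumes the repo root is the working directory):

    from export import export_formats

    fmts = export_formats()
    print(fmts)  # Format / Argument / Suffix / CPU / GPU table, now with a PaddlePaddle row
    assert 'paddle' in tuple(fmts['Argument'])  # accepted as a --include value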
@@ -180,7 +181,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
 
 
 @try_export
-def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
+def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
     # YOLOv5 OpenVINO export
     check_requirements('openvino-dev')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
     import openvino.inference_engine as ie
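export_openvino() now receives a prebuilt metadata dict (model stride and class names) instead of the model itself, and the same dict is reused by the new export_paddle() below; yaml_save() writes it as a sidecar YAML next to the exported files. A hedged sketch of what that sidecar contains, using plain PyYAML and illustrative values:

    import yaml

    metadata = {'stride': 32, 'names': ['person', 'bicycle', 'car']}  # illustrative values only
    print(yaml.safe_dump(metadata))  # roughly what yolov5s_paddle_model/yolov5s.yaml ends up holding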
@@ -189,9 +190,23 @@ def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
     f = str(file).replace('.pt', f'_openvino_model{os.sep}')
 
     cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
-    subprocess.check_output(cmd.split())  # export
-    with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
-        yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g)  # add metadata.yaml
+    subprocess.run(cmd.split(), check=True, env=os.environ)  # export
+    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
     return f, None
 
 
+@try_export
+def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
+    # YOLOv5 Paddle export
+    check_requirements(('paddlepaddle', 'x2paddle'))
+    import x2paddle
+    from x2paddle.convert import pytorch2paddle
+
+    LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
+    f = str(file).replace('.pt', f'_paddle_model{os.sep}')
+
+    pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im])  # export
+    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
+    return f, None
+
+
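export_paddle() is a single traced conversion through X2Paddle, with no intermediate ONNX step. A self-contained sketch of the same call outside export.py, using a small stand-in module (since loading the real YOLOv5 model needs the repo's own code) and assuming paddlepaddle and x2paddle are installed; the module and save_dir names are illustrative:

    import torch
    import torch.nn as nn
    from x2paddle.convert import pytorch2paddle

    model = nn.Sequential(nn.Conv2d(3, 16, 3, 2), nn.ReLU()).eval()  # stand-in for the loaded model
    im = torch.zeros(1, 3, 640, 640)                                 # example input used for tracing
    pytorch2paddle(module=model, save_dir='toy_paddle_model', jit_type='trace', input_examples=[im])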
@@ -464,7 +479,7 @@ def run(
     fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
     flags = [x in include for x in fmts]
     assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
-    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
+    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleans
     file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights
 
     # Load PyTorch model
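Because the export booleans are unpacked positionally from the export_formats() table, the trailing paddle name is the only change needed here. A short sketch of that mapping (repo root on the path):

    from export import export_formats

    include = ('onnx', 'paddle')                    # e.g. --include onnx paddle
    fmts = tuple(export_formats()['Argument'][1:])  # skip the PyTorch source row
    flags = [x in include for x in fmts]
    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags
    print(onnx, paddle, tfjs)  # True True False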
@@ -497,47 +512,48 @@ def run(
     if half and not coreml:
         im, model = im.half(), model.half()  # to FP16
     shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
+    metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
     LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
 
     # Exports
-    f = [''] * 10  # exported filenames
+    f = [''] * len(fmts)  # exported filenames
     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
-    if jit:
+    if jit:  # TorchScript
         f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
         f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
-        f[3], _ = export_openvino(model, file, half)
-    if coreml:
+        f[3], _ = export_openvino(file, metadata, half)
+    if coreml:  # CoreML
         f[4], _ = export_coreml(model, im, file, int8, half)
 
     # TensorFlow Exports
-    if any((saved_model, pb, tflite, edgetpu, tfjs)):
+    if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
         if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
             check_requirements('flatbuffers==1.12')  # required before `import tensorflow`
         assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
         assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
-        f[5], model = export_saved_model(model.cpu(),
-                                         im,
-                                         file,
-                                         dynamic,
-                                         tf_nms=nms or agnostic_nms or tfjs,
-                                         agnostic_nms=agnostic_nms or tfjs,
-                                         topk_per_class=topk_per_class,
-                                         topk_all=topk_all,
-                                         iou_thres=iou_thres,
-                                         conf_thres=conf_thres,
-                                         keras=keras)
+        f[5], s_model = export_saved_model(model.cpu(),
+                                           im,
+                                           file,
+                                           dynamic,
+                                           tf_nms=nms or agnostic_nms or tfjs,
+                                           agnostic_nms=agnostic_nms or tfjs,
+                                           topk_per_class=topk_per_class,
+                                           topk_all=topk_all,
+                                           iou_thres=iou_thres,
+                                           conf_thres=conf_thres,
+                                           keras=keras)
         if pb or tfjs:  # pb prerequisite to tfjs
-            f[6], _ = export_pb(model, file)
+            f[6], _ = export_pb(s_model, file)
         if tflite or edgetpu:
-            f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+            f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
         if edgetpu:
             f[8], _ = export_edgetpu(file)
         if tfjs:
             f[9], _ = export_tfjs(file)
+    if paddle:  # PaddlePaddle
+        f[10], _ = export_paddle(model, im, file, metadata)
 
     # Finish
     f = [str(x) for x in f if x]  # filter out '' and None
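With the export list sized by len(fmts) and the new f[10] slot wired up, Paddle export is requested like any other format. A minimal usage sketch, assuming a local yolov5s.pt checkpoint (the CLI equivalent would be python export.py --weights yolov5s.pt --include paddle):

    from pathlib import Path

    from export import run

    run(weights='yolov5s.pt', include=('paddle',))
    # Expect a *.pdmodel/*.pdiparams pair somewhere under the directory plus a yolov5s.yaml metadata file
    print([p.name for p in Path('yolov5s_paddle_model').rglob('*')])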
models/common.py (112 changed lines)
@@ -320,14 +320,16 @@ class DetectMultiBackend(nn.Module):
         # TensorFlow GraphDef: *.pb
         # TensorFlow Lite: *.tflite
         # TensorFlow Edge TPU: *_edgetpu.tflite
+        # PaddlePaddle: *_paddle_model
         from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import
 
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w)  # get backend
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w)  # type
         w = attempt_download(w)  # download if not local
         fp16 &= pt or jit or onnx or engine  # FP16
         stride = 32  # default stride
+        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
 
         if pt:  # PyTorch
             model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -351,7 +353,6 @@ class DetectMultiBackend(nn.Module):
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            cuda = torch.cuda.is_available() and device.type != 'cpu'
             check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
             import onnxruntime
             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
@@ -408,48 +409,60 @@ class DetectMultiBackend(nn.Module):
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
             model = ct.models.MLModel(w)
-        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-            if saved_model:  # SavedModel
-                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
-                import tensorflow as tf
-                keras = False  # assume TF1 saved_model
-                model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
-            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
-                import tensorflow as tf
+        elif saved_model:  # TF SavedModel
+            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+            import tensorflow as tf
+            keras = False  # assume TF1 saved_model
+            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+            import tensorflow as tf
 
-                def wrap_frozen_graph(gd, inputs, outputs):
-                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
-                    ge = x.graph.as_graph_element
-                    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
+            def wrap_frozen_graph(gd, inputs, outputs):
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                ge = x.graph.as_graph_element
+                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
 
-                gd = tf.Graph().as_graph_def()  # graph_def
-                with open(w, 'rb') as f:
-                    gd.ParseFromString(f.read())
-                frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
-            elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
-                try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
-                    from tflite_runtime.interpreter import Interpreter, load_delegate
-                except ImportError:
-                    import tensorflow as tf
-                    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
-                if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
-                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
-                    delegate = {
-                        'Linux': 'libedgetpu.so.1',
-                        'Darwin': 'libedgetpu.1.dylib',
-                        'Windows': 'edgetpu.dll'}[platform.system()]
-                    interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
-                else:  # Lite
-                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
-                    interpreter = Interpreter(model_path=w)  # load TFLite model
-                interpreter.allocate_tensors()  # allocate
-                input_details = interpreter.get_input_details()  # inputs
-                output_details = interpreter.get_output_details()  # outputs
-            elif tfjs:
-                raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
-            else:
-                raise NotImplementedError(f'ERROR: {w} is not a supported format')
+            gd = tf.Graph().as_graph_def()  # TF GraphDef
+            with open(w, 'rb') as f:
+                gd.ParseFromString(f.read())
+            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
+        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
+                from tflite_runtime.interpreter import Interpreter, load_delegate
+            except ImportError:
+                import tensorflow as tf
+                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
+            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
+                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
+                delegate = {
+                    'Linux': 'libedgetpu.so.1',
+                    'Darwin': 'libedgetpu.1.dylib',
+                    'Windows': 'edgetpu.dll'}[platform.system()]
+                interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
+            else:  # TFLite
+                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                interpreter = Interpreter(model_path=w)  # load TFLite model
+            interpreter.allocate_tensors()  # allocate
+            input_details = interpreter.get_input_details()  # inputs
+            output_details = interpreter.get_output_details()  # outputs
+        elif tfjs:  # TF.js
+            raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
+        elif paddle:  # PaddlePaddle
+            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
+            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
+            import paddle.inference as pdi
+            if not Path(w).is_file():  # if not *.pdmodel
+                w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
+            weights = Path(w).with_suffix('.pdiparams')
+            config = pdi.Config(str(w), str(weights))
+            if cuda:
+                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
+            predictor = pdi.create_predictor(config)
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+        else:
+            raise NotImplementedError(f'ERROR: {w} is not a supported format')
 
         # class names
         if 'names' not in locals():
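The new branch is plain paddle.inference API usage: locate the *.pdmodel graph inside the exported directory, pair it with its *.pdiparams weights, and build a predictor with input/output handles. A self-contained sketch of the same load-and-run path outside DetectMultiBackend (assumes paddlepaddle is installed and a yolov5s_paddle_model/ directory was exported as above; the input shape is illustrative):

    from pathlib import Path

    import numpy as np
    import paddle.inference as pdi

    w = next(Path('yolov5s_paddle_model').rglob('*.pdmodel'))        # graph file
    config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))    # graph + weights
    predictor = pdi.create_predictor(config)
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])

    input_handle.copy_from_cpu(np.zeros((1, 3, 640, 640), dtype='float32'))  # dummy BCHW image
    predictor.run()
    y = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
    print(y.shape)  # raw prediction array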
@@ -502,6 +515,13 @@ class DetectMultiBackend(nn.Module):
             else:
                 k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
                 y = y[k]  # output
+        elif self.paddle:  # PaddlePaddle
+            im = im.cpu().numpy().astype("float32")
+            self.input_handle.copy_from_cpu(im)
+            self.predictor.run()
+            output_names = self.predictor.get_output_names()
+            output_handle = self.predictor.get_output_handle(output_names[0])
+            y = output_handle.copy_to_cpu()
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             if self.saved_model:  # SavedModel
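For users nothing changes: detect.py and val.py already route every backend through DetectMultiBackend, so pointing --weights at the exported directory (e.g. python detect.py --weights yolov5s_paddle_model/) should be enough. A hedged minimal sketch of the same call from Python, assuming the exported directory from above exists and paddlepaddle is installed:

    import torch

    from models.common import DetectMultiBackend

    model = DetectMultiBackend('yolov5s_paddle_model', device=torch.device('cpu'))
    im = torch.zeros(1, 3, 640, 640)  # dummy normalized BCHW batch
    y = model(im)                     # dispatches to the elif self.paddle branch above
    print(model.stride, len(model.names))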
@@ -542,13 +562,13 @@ class DetectMultiBackend(nn.Module):
     def _model_type(p='path/to/model.pt'):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
         from export import export_formats
-        suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes
-        check_suffix(p, suffixes)  # checks
+        sf = list(export_formats().Suffix) + ['.xml']  # export suffixes
+        check_suffix(p, sf)  # checks
         p = Path(p).name  # eliminate trailing separators
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf)
         xml |= xml2  # *_openvino_model or *.xml
         tflite &= not edgetpu  # *.tflite
-        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
+        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle
 
     @staticmethod
     def _load_metadata(f=Path('path/to/meta.yaml')):
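Backend detection stays pure suffix matching against the Suffix column of export_formats(), so a *_paddle_model path selects the Paddle branch without any CLI changes. A small sketch (only the name is inspected, so the directory does not need to exist):

    from models.common import DetectMultiBackend

    flags = DetectMultiBackend._model_type('yolov5s_paddle_model')  # pt, jit, ..., tfjs, paddle
    print(flags[-1])  # True -> PaddlePaddle backend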
benchmarks.py

@@ -61,7 +61,7 @@ def run(
     device = select_device(device)
     for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, CPU, GPU)
         try:
-            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
+            assert i not in (9, 10, 11), 'inference not supported'  # Edge TPU, TF.js and Paddle are unsupported
             assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
             if 'cpu' in device.type:
                 assert cpu, 'inference not supported on CPU'
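The skipped indices are row positions in export_formats(), where the new PaddlePaddle entry lands at index 11 (after Edge TPU at 9 and TF.js at 10). A quick way to see the mapping the benchmark relies on (repo root on the path):

    from export import export_formats

    print(export_formats().loc[[9, 10, 11], ['Format', 'Argument']])  # Edge TPU, TF.js, PaddlePaddle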