TensorFlow.js export enhancements (#4905)

* Add arguments to TensorFlow NMS call

* Add regex substitution to reorder Identity_*

* Delete reorder in docstring

* Cleanup

* Cleanup2

* Removed `+ \` on string ends (not needed)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Jiacong Fang 2021-09-25 05:18:15 +08:00 committed by GitHub
parent ce7fa81d4e
commit 2c2ef25f8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 3 deletions

View File

@ -14,7 +14,6 @@ Inference:
yolov5s.tflite yolov5s.tflite
TensorFlow.js: TensorFlow.js:
$ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install $ npm install
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
@ -213,16 +212,32 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export # YOLOv5 TensorFlow.js export
try: try:
check_requirements(('tensorflowjs',)) check_requirements(('tensorflowjs',))
import re
import tensorflowjs as tfjs import tensorflowjs as tfjs
print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(file).replace('.pt', '_web_model') # js dir f = str(file).replace('.pt', '_web_model') # js dir
f_pb = file.with_suffix('.pb') # *.pb path f_pb = file.with_suffix('.pb') # *.pb path
f_json = f + '/model.json' # *.json path
cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \ cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}" f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
subprocess.run(cmd, shell=True) subprocess.run(cmd, shell=True)
json = open(f_json).read()
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
subst = re.sub(
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity_1": {"name": "Identity_1"}, '
r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}',
json)
j.write(subst)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e: except Exception as e:
print(f'\n{prefix} export failure: {e}') print(f'\n{prefix} export failure: {e}')
@ -243,6 +258,10 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
dynamic=False, # ONNX/TF: dynamic axes dynamic=False, # ONNX/TF: dynamic axes
simplify=False, # ONNX: simplify model simplify=False, # ONNX: simplify model
opset=12, # ONNX: opset version opset=12, # ONNX: opset version
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25 # TF.js NMS: confidence threshold
): ):
t = time.time() t = time.time()
include = [x.lower() for x in include] include = [x.lower() for x in include]
@ -290,7 +309,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
if any(tf_exports): if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:] pb, tflite, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs) # keras model model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs,
topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres,
iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file) export_pb(model, im, file)
if tflite: if tflite:
@ -319,6 +340,10 @@ def parse_opt():
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version') parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include', nargs='+', parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'], default=['torchscript', 'onnx'],
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)') help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')

View File

@ -367,7 +367,7 @@ class AgnosticNMS(keras.layers.Layer):
# TF Agnostic NMS # TF Agnostic NMS
def call(self, input, topk_all, iou_thres, conf_thres): def call(self, input, topk_all, iou_thres, conf_thres):
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450 # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
return tf.map_fn(self._nms, input, return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32), fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
name='agnostic_nms') name='agnostic_nms')