Add nms and agnostic nms to export.py (#5938)

* add nms and agnostic nms to export.py

* fix agnostic implies nms

* reorder args to group TF args

* PEP8 120 char

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Diego Montes 2021-12-10 12:24:32 -05:00 committed by GitHub
parent a42af30d8a
commit 2c6317547a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -328,6 +328,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
opset=14, # ONNX: opset version opset=14, # ONNX: opset version
verbose=False, # TensorRT: verbose log verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB) workspace=4, # TensorRT: workspace size (GB)
nms=False, # TF: add NMS to model
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold iou_thres=0.45, # TF.js NMS: IoU threshold
@ -381,9 +383,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
if any(tf_exports): if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:] pb, tflite, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs, model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres, agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
iou_thres=iou_thres) # keras model conf_thres=conf_thres, iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file) export_pb(model, im, file)
if tflite: if tflite:
@ -414,6 +416,8 @@ def parse_opt():
parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version') parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')