Add TensorFlow formats to export.py (#4479)

* Initial commit

* Remove unused export_torchscript return

* ROOT variable

* Add prefix to fcn arg

* fix ROOT

* check_yaml into run()

* interim fixes

* imgsz=(320, 320)

* Hardcode tf_raw_resize False

* Finish opt elimination

* Update representative_dataset_gen()

* Update export.py with TF methods

* SiLU and GraphDef fixes

* file_size() directory handling feature

* export fixes

* add lambda: to representative_dataset

* Detect training False default

* Fuse false for TF models

* Embed agnostic NMS arguments

* Remove lambda

* TensorFlow.js export success

* Add pb to Usage

* Add *_tfjs_model/ to ignore files

* prepend YOLOv5 to function headers

* Remove end --- comments

* parameterize tfjs export pb file

* update run() data default /ROOT

* update --include help

* update imports

* return ct_model

* Consolidate TFLite export

* pb prerequisite to tfjs

* TF modules CamelCase

* Remove exports from tf.py and cleanup

* pass agnostic NMS arguments

* CI

* CI

* ignore *_web_model/

* Add tensorflow to CI dependencies

* CI tensorflow-cpu

* Update requirements.txt

* Remove tensorflow check_requirement

* CI coreml tfjs

* export only onnx torchscript

* reorder exports torchscript first
This commit is contained in:
Glenn Jocher 2021-09-12 15:52:24 +02:00 committed by GitHub
parent c47be26f34
commit c3a93d783d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 359 additions and 322 deletions

View File

@ -22,6 +22,7 @@ data/samples/*
**/*.h5 **/*.h5
**/*.pb **/*.pb
*_saved_model/ *_saved_model/
*_web_model/
# Below Copied From .gitignore ----------------------------------------------------------------------------------------- # Below Copied From .gitignore -----------------------------------------------------------------------------------------
# Below Copied From .gitignore ----------------------------------------------------------------------------------------- # Below Copied From .gitignore -----------------------------------------------------------------------------------------

View File

@ -48,7 +48,7 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install -qr requirements.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install -qr requirements.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install -q onnx onnx-simplifier coremltools # for export pip install -q onnx tensorflow-cpu # for export
python --version python --version
pip --version pip --version
pip list pip list
@ -75,6 +75,7 @@ jobs:
python val.py --img 128 --batch 16 --weights runs/train/exp/weights/last.pt --device $di python val.py --img 128 --batch 16 --weights runs/train/exp/weights/last.pt --device $di
python hubconf.py # hub python hubconf.py # hub
python models/yolo.py --cfg ${{ matrix.model }}.yaml # inspect python models/yolo.py --cfg ${{ matrix.model }}.yaml # build PyTorch model
python export.py --img 128 --batch 1 --weights ${{ matrix.model }}.pt --include onnx torchscript # export python models/tf.py --weights ${{ matrix.model }}.pt # build TensorFlow model
python export.py --img 128 --batch 1 --weights ${{ matrix.model }}.pt --include torchscript onnx # export
shell: bash shell: bash

1
.gitignore vendored
View File

@ -52,6 +52,7 @@ VOC/
*.tflite *.tflite
*.h5 *.h5
*_saved_model/ *_saved_model/
*_web_model/
darknet53.conv.74 darknet53.conv.74
yolov3-tiny.conv.15 yolov3-tiny.conv.15

View File

@ -253,7 +253,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
def parse_opt(): def parse_opt():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam') parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')

219
export.py
View File

@ -1,12 +1,28 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
""" """
Export a PyTorch model to TorchScript, ONNX, CoreML formats Export a YOLOv5 PyTorch model to TorchScript, ONNX, CoreML, TensorFlow (saved_model, pb, TFLite, TF.js,) formats
TensorFlow exports authored by https://github.com/zldrobit
Usage: Usage:
$ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1 $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
Inference:
$ python path/to/detect.py --weights yolov5s.pt
yolov5s.onnx (must export with --dynamic)
yolov5s_saved_model
yolov5s.pb
yolov5s.tflite
TensorFlow.js:
$ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
$ npm start
""" """
import argparse import argparse
import subprocess
import sys import sys
import time import time
from pathlib import Path from pathlib import Path
@ -16,40 +32,42 @@ import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile from torch.utils.mobile_optimizer import optimize_for_mobile
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path ROOT = FILE.parents[0] # yolov5/ dir
sys.path.append(ROOT.as_posix()) # add yolov5/ to path
from models.common import Conv from models.common import Conv
from models.yolo import Detect
from models.experimental import attempt_load from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU from models.yolo import Detect
from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging from utils.activations import SiLU
from utils.datasets import LoadImages
from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging
from utils.torch_utils import select_device from utils.torch_utils import select_device
def export_torchscript(model, img, file, optimize): def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# TorchScript model export # YOLOv5 TorchScript model export
prefix = colorstr('TorchScript:')
try: try:
print(f'\n{prefix} starting export with torch {torch.__version__}...') print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript.pt') f = file.with_suffix('.torchscript.pt')
ts = torch.jit.trace(model, img, strict=False)
ts = torch.jit.trace(model, im, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f) (optimize_for_mobile(ts) if optimize else ts).save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return ts
except Exception as e: except Exception as e:
print(f'{prefix} export failure: {e}') print(f'{prefix} export failure: {e}')
def export_onnx(model, img, file, opset, train, dynamic, simplify): def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
# ONNX model export # YOLOv5 ONNX export
prefix = colorstr('ONNX:')
try: try:
check_requirements(('onnx',)) check_requirements(('onnx',))
import onnx import onnx
print(f'\n{prefix} starting export with onnx {onnx.__version__}...') print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx') f = file.with_suffix('.onnx')
torch.onnx.export(model, img, f, verbose=False, opset_version=opset,
torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL, training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not train, do_constant_folding=not train,
input_names=['images'], input_names=['images'],
@ -73,7 +91,7 @@ def export_onnx(model, img, file, opset, train, dynamic, simplify):
model_onnx, check = onnxsim.simplify( model_onnx, check = onnxsim.simplify(
model_onnx, model_onnx,
dynamic_input_shape=dynamic, dynamic_input_shape=dynamic,
input_shapes={'images': list(img.shape)} if dynamic else None) input_shapes={'images': list(im.shape)} if dynamic else None)
assert check, 'assert check failed' assert check, 'assert check failed'
onnx.save(model_onnx, f) onnx.save(model_onnx, f)
except Exception as e: except Exception as e:
@ -84,26 +102,131 @@ def export_onnx(model, img, file, opset, train, dynamic, simplify):
print(f'{prefix} export failure: {e}') print(f'{prefix} export failure: {e}')
def export_coreml(model, img, file): def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
# CoreML model export # YOLOv5 CoreML export
prefix = colorstr('CoreML:') ct_model = None
try: try:
check_requirements(('coremltools',)) check_requirements(('coremltools',))
import coremltools as ct import coremltools as ct
print(f'\n{prefix} starting export with coremltools {ct.__version__}...') print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = file.with_suffix('.mlmodel') f = file.with_suffix('.mlmodel')
model.train() # CoreML exports should be placed in model.train() mode model.train() # CoreML exports should be placed in model.train() mode
ts = torch.jit.trace(model, img, strict=False) # TorchScript model ts = torch.jit.trace(model, im, strict=False) # TorchScript model
model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255.0, bias=[0, 0, 0])])
model.save(f) ct_model.save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
return ct_model
def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')):
# YOLOv5 TensorFlow saved_model export
keras_model = None
try:
import tensorflow as tf
from tensorflow import keras
from models.tf import TFModel, TFDetect
print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(file).replace('.pt', '_saved_model')
batch_size, ch, *imgsz = list(im.shape) # BCHW
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = keras.Model(inputs=inputs, outputs=outputs)
keras_model.summary()
keras_model.save(f, save_format='tf')
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
return keras_model
def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
# YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
try:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = file.with_suffix('.pb')
m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(m)
frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e: except Exception as e:
print(f'\n{prefix} export failure: {e}') print(f'\n{prefix} export failure: {e}')
def run(weights='./yolov5s.pt', # weights path def export_tflite(keras_model, im, file, tfl_int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
img_size=(640, 640), # image (height, width) # YOLOv5 TensorFlow Lite export
try:
import tensorflow as tf
from models.tf import representative_dataset_gen
print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
batch_size, ch, *imgsz = list(im.shape) # BCHW
f = file.with_suffix('.tflite')
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
if tfl_int8:
dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
converter.experimental_new_quantizer = False
f = str(file).replace('.pt', '-int8.tflite')
tflite_model = converter.convert()
open(f, "wb").write(tflite_model)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export
try:
check_requirements(('tensorflowjs',))
import tensorflowjs as tfjs
print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(file).replace('.pt', '_web_model') # js dir
f_pb = file.with_suffix('.pb') # *.pb path
cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
subprocess.run(cmd, shell=True)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
print(f'\n{prefix} export failure: {e}')
@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
imgsz=(640, 640), # image (height, width)
batch_size=1, # batch size batch_size=1, # batch size
device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
include=('torchscript', 'onnx', 'coreml'), # include formats include=('torchscript', 'onnx', 'coreml'), # include formats
@ -117,29 +240,28 @@ def run(weights='./yolov5s.pt', # weights path
): ):
t = time.time() t = time.time()
include = [x.lower() for x in include] include = [x.lower() for x in include]
img_size *= 2 if len(img_size) == 1 else 1 # expand tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
file = Path(weights) file = Path(weights)
# Load PyTorch model # Load PyTorch model
device = select_device(device) device = select_device(device)
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
model = attempt_load(weights, map_location=device) # load FP32 model model = attempt_load(weights, map_location=device, inplace=True, fuse=not any(tf_exports)) # load FP32 model
names = model.names nc, names = model.nc, model.names # number of classes, class names
# Input # Input
gs = int(max(model.stride)) # grid size (max stride) gs = int(max(model.stride)) # grid size (max stride)
img_size = [check_img_size(x, gs) for x in img_size] # verify img_size are gs-multiples imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
img = torch.zeros(batch_size, 3, *img_size).to(device) # image size(1,3,320,192) iDetection im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
# Update model # Update model
if half: if half:
img, model = img.half(), model.half() # to FP16 im, model = im.half(), model.half() # to FP16
model.train() if train else model.eval() # training mode = no Detect() layer grid construction model.train() if train else model.eval() # training mode = no Detect() layer grid construction
for k, m in model.named_modules(): for k, m in model.named_modules():
if isinstance(m, Conv): # assign export-friendly activations if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.Hardswish): if isinstance(m.act, nn.SiLU):
m.act = Hardswish()
elif isinstance(m.act, nn.SiLU):
m.act = SiLU() m.act = SiLU()
elif isinstance(m, Detect): elif isinstance(m, Detect):
m.inplace = inplace m.inplace = inplace
@ -147,16 +269,28 @@ def run(weights='./yolov5s.pt', # weights path
# m.forward = m.forward_export # assign forward (optional) # m.forward = m.forward_export # assign forward (optional)
for _ in range(2): for _ in range(2):
y = model(img) # dry runs y = model(im) # dry runs
print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)") print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")
# Exports # Exports
if 'torchscript' in include: if 'torchscript' in include:
export_torchscript(model, img, file, optimize) export_torchscript(model, im, file, optimize)
if 'onnx' in include: if 'onnx' in include:
export_onnx(model, img, file, opset, train, dynamic, simplify) export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'coreml' in include: if 'coreml' in include:
export_coreml(model, img, file) export_coreml(model, im, file)
# TensorFlow Exports
if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs) # keras model
if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file)
if tflite:
export_tflite(model, im, file, tfl_int8=False, data=data, ncalib=100)
if tfjs:
export_tfjs(model, im, file)
# Finish # Finish
print(f'\nExport complete ({time.time() - t:.2f}s)' print(f'\nExport complete ({time.time() - t:.2f}s)'
@ -166,18 +300,21 @@ def run(weights='./yolov5s.pt', # weights path
def parse_opt(): def parse_opt():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image (height, width)') parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml'], help='include formats')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export') parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True') parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
parser.add_argument('--train', action='store_true', help='model.train() mode') parser.add_argument('--train', action='store_true', help='model.train() mode')
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile') parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
parser.add_argument('--dynamic', action='store_true', help='ONNX: dynamic axes') parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version') parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'],
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
opt = parser.parse_args() opt = parser.parse_args()
return opt return opt

View File

@ -1,67 +1,44 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
""" """
TensorFlow/Keras and TFLite versions of YOLOv5 TensorFlow, Keras and TFLite versions of YOLOv5
Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127 Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
Usage: Usage:
$ python models/tf.py --weights yolov5s.pt --cfg yolov5s.yaml $ python models/tf.py --weights yolov5s.pt
Export int8 TFLite models: Export:
$ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 \ $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
--source path/to/images/ --ncalib 100
Detection:
$ python detect.py --weights yolov5s.pb --img 320
$ python detect.py --weights yolov5s_saved_model --img 320
$ python detect.py --weights yolov5s-fp16.tflite --img 320
$ python detect.py --weights yolov5s-int8.tflite --img 320 --tfl-int8
For TensorFlow.js:
$ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms
$ pip install tensorflowjs
$ tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
yolov5s.pb \
web_model
$ # Edit web_model/model.json to sort Identity* in ascending order
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install
$ ln -s ../../yolov5/web_model public/web_model
$ npm start
""" """
import argparse import argparse
import logging import logging
import os
import sys import sys
import traceback
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
sys.path.append('./') # to run '$ python *.py' files in subdirectories FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # yolov5/ dir
sys.path.append(ROOT.as_posix()) # add yolov5/ to path
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
import torch.nn as nn import torch.nn as nn
import yaml
from tensorflow import keras from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3 from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.experimental import MixConv2d, CrossConv, attempt_load from models.experimental import MixConv2d, CrossConv, attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.datasets import LoadImages from utils.general import colorstr, make_divisible, set_logging
from utils.general import check_dataset, check_yaml, make_divisible from utils.activations import SiLU
logger = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class tf_BN(keras.layers.Layer): class TFBN(keras.layers.Layer):
# TensorFlow BatchNormalization wrapper # TensorFlow BatchNormalization wrapper
def __init__(self, w=None): def __init__(self, w=None):
super(tf_BN, self).__init__() super(TFBN, self).__init__()
self.bn = keras.layers.BatchNormalization( self.bn = keras.layers.BatchNormalization(
beta_initializer=keras.initializers.Constant(w.bias.numpy()), beta_initializer=keras.initializers.Constant(w.bias.numpy()),
gamma_initializer=keras.initializers.Constant(w.weight.numpy()), gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
@ -73,20 +50,20 @@ class tf_BN(keras.layers.Layer):
return self.bn(inputs) return self.bn(inputs)
class tf_Pad(keras.layers.Layer): class TFPad(keras.layers.Layer):
def __init__(self, pad): def __init__(self, pad):
super(tf_Pad, self).__init__() super(TFPad, self).__init__()
self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]) self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
def call(self, inputs): def call(self, inputs):
return tf.pad(inputs, self.pad, mode='constant', constant_values=0) return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
class tf_Conv(keras.layers.Layer): class TFConv(keras.layers.Layer):
# Standard convolution # Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, weights, kernel, stride, padding, groups # ch_in, ch_out, weights, kernel, stride, padding, groups
super(tf_Conv, self).__init__() super(TFConv, self).__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
assert isinstance(k, int), "Convolution with multiple kernels are not allowed." assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding) # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
@ -95,27 +72,29 @@ class tf_Conv(keras.layers.Layer):
conv = keras.layers.Conv2D( conv = keras.layers.Conv2D(
c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False, c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False,
kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy())) kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()))
self.conv = conv if s == 1 else keras.Sequential([tf_Pad(autopad(k, p)), conv]) self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
self.bn = tf_BN(w.bn) if hasattr(w, 'bn') else tf.identity self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
# YOLOv5 activations # YOLOv5 activations
if isinstance(w.act, nn.LeakyReLU): if isinstance(w.act, nn.LeakyReLU):
self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
elif isinstance(w.act, nn.Hardswish): elif isinstance(w.act, nn.Hardswish):
self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
elif isinstance(w.act, nn.SiLU): elif isinstance(w.act, (nn.SiLU, SiLU)):
self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
else:
raise Exception(f'no matching TensorFlow activation found for {w.act}')
def call(self, inputs): def call(self, inputs):
return self.act(self.bn(self.conv(inputs))) return self.act(self.bn(self.conv(inputs)))
class tf_Focus(keras.layers.Layer): class TFFocus(keras.layers.Layer):
# Focus wh information into c-space # Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
# ch_in, ch_out, kernel, stride, padding, groups # ch_in, ch_out, kernel, stride, padding, groups
super(tf_Focus, self).__init__() super(TFFocus, self).__init__()
self.conv = tf_Conv(c1 * 4, c2, k, s, p, g, act, w.conv) self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c) def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
# inputs = inputs / 255. # normalize 0-255 to 0-1 # inputs = inputs / 255. # normalize 0-255 to 0-1
@ -125,23 +104,23 @@ class tf_Focus(keras.layers.Layer):
inputs[:, 1::2, 1::2, :]], 3)) inputs[:, 1::2, 1::2, :]], 3))
class tf_Bottleneck(keras.layers.Layer): class TFBottleneck(keras.layers.Layer):
# Standard bottleneck # Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
super(tf_Bottleneck, self).__init__() super(TFBottleneck, self).__init__()
c_ = int(c2 * e) # hidden channels c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1) self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c_, c2, 3, 1, g=g, w=w.cv2) self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
self.add = shortcut and c1 == c2 self.add = shortcut and c1 == c2
def call(self, inputs): def call(self, inputs):
return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
class tf_Conv2d(keras.layers.Layer): class TFConv2d(keras.layers.Layer):
# Substitution for PyTorch nn.Conv2D # Substitution for PyTorch nn.Conv2D
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
super(tf_Conv2d, self).__init__() super(TFConv2d, self).__init__()
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
self.conv = keras.layers.Conv2D( self.conv = keras.layers.Conv2D(
c2, k, s, 'VALID', use_bias=bias, c2, k, s, 'VALID', use_bias=bias,
@ -152,19 +131,19 @@ class tf_Conv2d(keras.layers.Layer):
return self.conv(inputs) return self.conv(inputs)
class tf_BottleneckCSP(keras.layers.Layer): class TFBottleneckCSP(keras.layers.Layer):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion # ch_in, ch_out, number, shortcut, groups, expansion
super(tf_BottleneckCSP, self).__init__() super(TFBottleneckCSP, self).__init__()
c_ = int(c2 * e) # hidden channels c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1) self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv2d(c1, c_, 1, 1, bias=False, w=w.cv2) self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
self.cv3 = tf_Conv2d(c_, c_, 1, 1, bias=False, w=w.cv3) self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
self.cv4 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv4) self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
self.bn = tf_BN(w.bn) self.bn = TFBN(w.bn)
self.act = lambda x: keras.activations.relu(x, alpha=0.1) self.act = lambda x: keras.activations.relu(x, alpha=0.1)
self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]) self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
def call(self, inputs): def call(self, inputs):
y1 = self.cv3(self.m(self.cv1(inputs))) y1 = self.cv3(self.m(self.cv1(inputs)))
@ -172,28 +151,28 @@ class tf_BottleneckCSP(keras.layers.Layer):
return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3)))) return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
class tf_C3(keras.layers.Layer): class TFC3(keras.layers.Layer):
# CSP Bottleneck with 3 convolutions # CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
# ch_in, ch_out, number, shortcut, groups, expansion # ch_in, ch_out, number, shortcut, groups, expansion
super(tf_C3, self).__init__() super(TFC3, self).__init__()
c_ = int(c2 * e) # hidden channels c_ = int(c2 * e) # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1) self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c1, c_, 1, 1, w=w.cv2) self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv3) self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]) self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
def call(self, inputs): def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
class tf_SPP(keras.layers.Layer): class TFSPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP # Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13), w=None): def __init__(self, c1, c2, k=(5, 9, 13), w=None):
super(tf_SPP, self).__init__() super(TFSPP, self).__init__()
c_ = c1 // 2 # hidden channels c_ = c1 // 2 # hidden channels
self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1) self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = tf_Conv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2) self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k] self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
def call(self, inputs): def call(self, inputs):
@ -201,9 +180,9 @@ class tf_SPP(keras.layers.Layer):
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3)) return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
class tf_Detect(keras.layers.Layer): class TFDetect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), w=None): # detection layer def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
super(tf_Detect, self).__init__() super(TFDetect, self).__init__()
self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32) self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
self.nc = nc # number of classes self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor self.no = nc + 5 # number of outputs per anchor
@ -213,22 +192,20 @@ class tf_Detect(keras.layers.Layer):
self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32) self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32), self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
[self.nl, 1, -1, 1, 2]) [self.nl, 1, -1, 1, 2])
self.m = [tf_Conv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
self.export = False # onnx export self.training = False # set to False after building model
self.training = True # set to False after building model self.imgsz = imgsz
for i in range(self.nl): for i in range(self.nl):
ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i] ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
self.grid[i] = self._make_grid(nx, ny) self.grid[i] = self._make_grid(nx, ny)
def call(self, inputs): def call(self, inputs):
# x = x.copy() # for profiling
z = [] # inference output z = [] # inference output
self.training |= self.export
x = [] x = []
for i in range(self.nl): for i in range(self.nl):
x.append(self.m[i](inputs[i])) x.append(self.m[i](inputs[i]))
# x(bs,20,20,255) to x(bs,3,20,20,85) # x(bs,20,20,255) to x(bs,3,20,20,85)
ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i] ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3]) x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])
if not self.training: # inference if not self.training: # inference
@ -236,8 +213,8 @@ class tf_Detect(keras.layers.Layer):
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
# Normalize xywh to 0-1 to reduce calibration error # Normalize xywh to 0-1 to reduce calibration error
xy /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32) xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
wh /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32) wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
y = tf.concat([xy, wh, y[..., 4:]], -1) y = tf.concat([xy, wh, y[..., 4:]], -1)
z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no])) z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))
@ -251,25 +228,23 @@ class tf_Detect(keras.layers.Layer):
return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32) return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
class tf_Upsample(keras.layers.Layer): class TFUpsample(keras.layers.Layer):
def __init__(self, size, scale_factor, mode, w=None): def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
super(tf_Upsample, self).__init__() super(TFUpsample, self).__init__()
assert scale_factor == 2, "scale_factor must be 2" assert scale_factor == 2, "scale_factor must be 2"
self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
# self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode) # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
if opt.tf_raw_resize: # with default arguments: align_corners=False, half_pixel_centers=False
# with default arguments: align_corners=False, half_pixel_centers=False # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x, # size=(x.shape[1] * 2, x.shape[2] * 2))
size=(x.shape[1] * 2, x.shape[2] * 2))
else:
self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
def call(self, inputs): def call(self, inputs):
return self.upsample(inputs) return self.upsample(inputs)
class tf_Concat(keras.layers.Layer): class TFConcat(keras.layers.Layer):
def __init__(self, dimension=1, w=None): def __init__(self, dimension=1, w=None):
super(tf_Concat, self).__init__() super(TFConcat, self).__init__()
assert dimension == 1, "convert only NCHW to NHWC concat" assert dimension == 1, "convert only NCHW to NHWC concat"
self.d = 3 self.d = 3
@ -277,8 +252,8 @@ class tf_Concat(keras.layers.Layer):
return tf.concat(inputs, self.d) return tf.concat(inputs, self.d)
def parse_model(d, ch, model): # model_dict, input_channels(3) def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5) no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
@ -310,10 +285,11 @@ def parse_model(d, ch, model): # model_dict, input_channels(3)
args.append([ch[x + 1] for x in f]) args.append([ch[x + 1] for x in f])
if isinstance(args[1], int): # number of anchors if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f) args[1] = [list(range(args[1] * 2))] * len(f)
args.append(imgsz)
else: else:
c2 = ch[f] c2 = ch[f]
tf_m = eval('tf_' + m_str.replace('nn.', '')) tf_m = eval('TF' + m_str.replace('nn.', ''))
m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \ m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
else tf_m(*args, w=model.model[i]) # module else tf_m(*args, w=model.model[i]) # module
@ -321,16 +297,16 @@ def parse_model(d, ch, model): # model_dict, input_channels(3)
t = str(m)[8:-2].replace('__main__.', '') # module type t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum([x.numel() for x in torch_m_.parameters()]) # number params np = sum([x.numel() for x in torch_m_.parameters()]) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_) layers.append(m_)
ch.append(c2) ch.append(c2)
return keras.Sequential(layers), sorted(save) return keras.Sequential(layers), sorted(save)
class tf_Model(): class TFModel:
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None): # model, input channels, number of classes def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
super(tf_Model, self).__init__() super(TFModel, self).__init__()
if isinstance(cfg, dict): if isinstance(cfg, dict):
self.yaml = cfg # model dict self.yaml = cfg # model dict
else: # is *.yaml else: # is *.yaml
@ -343,9 +319,10 @@ class tf_Model():
if nc and nc != self.yaml['nc']: if nc and nc != self.yaml['nc']:
print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc)) print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value self.yaml['nc'] = nc # override yaml value
self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model) # model, savelist, ch_out self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
def predict(self, inputs, profile=False): def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25):
y = [] # outputs y = [] # outputs
x = inputs x = inputs
for i, m in enumerate(self.model.layers): for i, m in enumerate(self.model.layers):
@ -356,18 +333,18 @@ class tf_Model():
y.append(x if m.i in self.savelist else None) # save output y.append(x if m.i in self.savelist else None) # save output
# Add TensorFlow NMS # Add TensorFlow NMS
if opt.tf_nms: if tf_nms:
boxes = xywh2xyxy(x[0][..., :4]) boxes = self._xywh2xyxy(x[0][..., :4])
probs = x[0][:, :, 4:5] probs = x[0][:, :, 4:5]
classes = x[0][:, :, 5:] classes = x[0][:, :, 5:]
scores = probs * classes scores = probs * classes
if opt.agnostic_nms: if agnostic_nms:
nms = agnostic_nms_layer()((boxes, classes, scores)) nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
return nms, x[1] return nms, x[1]
else: else:
boxes = tf.expand_dims(boxes, 2) boxes = tf.expand_dims(boxes, 2)
nms = tf.image.combined_non_max_suppression( nms = tf.image.combined_non_max_suppression(
boxes, scores, opt.topk_per_class, opt.topk_all, opt.iou_thres, opt.score_thres, clip_boxes=False) boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)
return nms, x[1] return nms, x[1]
return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...] return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
@ -377,182 +354,94 @@ class tf_Model():
# cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
# return tf.concat([conf, cls, xywh], 1) # return tf.concat([conf, cls, xywh], 1)
@staticmethod
def _xywh2xyxy(xywh):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
class agnostic_nms_layer(keras.layers.Layer):
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450 class AgnosticNMS(keras.layers.Layer):
def call(self, input): # TF Agnostic NMS
return tf.map_fn(agnostic_nms, input, def call(self, input, topk_all, iou_thres, conf_thres):
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
return tf.map_fn(self._nms, input,
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32), fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
name='agnostic_nms') name='agnostic_nms')
@staticmethod
def agnostic_nms(x): def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
boxes, classes, scores = x boxes, classes, scores = x
class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32) class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
scores_inp = tf.reduce_max(scores, -1) scores_inp = tf.reduce_max(scores, -1)
selected_inds = tf.image.non_max_suppression( selected_inds = tf.image.non_max_suppression(
boxes, scores_inp, max_output_size=opt.topk_all, iou_threshold=opt.iou_thres, score_threshold=opt.score_thres) boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)
selected_boxes = tf.gather(boxes, selected_inds) selected_boxes = tf.gather(boxes, selected_inds)
padded_boxes = tf.pad(selected_boxes, padded_boxes = tf.pad(selected_boxes,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]], [0, 0]], paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
mode="CONSTANT", constant_values=0.0) mode="CONSTANT", constant_values=0.0)
selected_scores = tf.gather(scores_inp, selected_inds) selected_scores = tf.gather(scores_inp, selected_inds)
padded_scores = tf.pad(selected_scores, padded_scores = tf.pad(selected_scores,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]], paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
mode="CONSTANT", constant_values=-1.0) mode="CONSTANT", constant_values=-1.0)
selected_classes = tf.gather(class_inds, selected_inds) selected_classes = tf.gather(class_inds, selected_inds)
padded_classes = tf.pad(selected_classes, padded_classes = tf.pad(selected_classes,
paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]], paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
mode="CONSTANT", constant_values=-1.0) mode="CONSTANT", constant_values=-1.0)
valid_detections = tf.shape(selected_inds)[0] valid_detections = tf.shape(selected_inds)[0]
return padded_boxes, padded_scores, padded_classes, valid_detections return padded_boxes, padded_scores, padded_classes, valid_detections
def xywh2xyxy(xywh): def representative_dataset_gen(dataset, ncalib=100):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1) for n, (path, img, im0s, vid_cap) in enumerate(dataset):
return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
def representative_dataset_gen():
# Representative dataset for use with converter.representative_dataset
n = 0
for path, img, im0s, vid_cap in dataset:
# Get sample input data as a numpy array in a method of your choosing.
n += 1
input = np.transpose(img, [1, 2, 0]) input = np.transpose(img, [1, 2, 0])
input = np.expand_dims(input, axis=0).astype(np.float32) input = np.expand_dims(input, axis=0).astype(np.float32)
input /= 255.0 input /= 255.0
yield [input] yield [input]
if n >= opt.ncalib: if n >= ncalib:
break break
if __name__ == "__main__": def run(weights=ROOT / 'yolov5s.pt', # weights path
imgsz=(640, 640), # inference size h,w
batch_size=1, # batch size
dynamic=False, # dynamic batch size
):
# PyTorch model
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
y = model(im) # inference
model.info()
# TensorFlow model
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
y = tf_model.predict(im) # inference
# Keras model
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
keras_model.summary()
def parse_opt():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='cfg path') parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='weights path') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic-batch-size', action='store_true', help='dynamic batch size') parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
parser.add_argument('--source', type=str, default='../data/coco128.yaml', help='dir of images or data.yaml file')
parser.add_argument('--ncalib', type=int, default=100, help='number of calibration images')
parser.add_argument('--tfl-int8', action='store_true', dest='tfl_int8', help='export TFLite int8 model')
parser.add_argument('--tf-nms', action='store_true', dest='tf_nms', help='TF NMS (without TFLite export)')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--tf-raw-resize', action='store_true', dest='tf_raw_resize',
help='use tf.raw_ops.ResizeNearestNeighbor for resize')
parser.add_argument('--topk-per-class', type=int, default=100, help='topk per class to keep in NMS')
parser.add_argument('--topk-all', type=int, default=100, help='topk for all classes to keep in NMS')
parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
opt = parser.parse_args() opt = parser.parse_args()
opt.cfg = check_yaml(opt.cfg) # check YAML opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand return opt
print(opt)
# Input
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection
# Load PyTorch model def main(opt):
model = attempt_load(opt.weights, map_location=torch.device('cpu'), inplace=True, fuse=False) set_logging()
model.model[-1].export = False # set Detect() layer export=True print(colorstr('tf.py: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
y = model(img) # dry run run(**vars(opt))
nc = y[0].shape[-1] - 5
# TensorFlow saved_model export
try:
print('\nStarting TensorFlow saved_model export with TensorFlow %s...' % tf.__version__)
tf_model = tf_Model(opt.cfg, model=model, nc=nc)
img = tf.zeros((opt.batch_size, *opt.img_size, 3)) # NHWC Input for TensorFlow
m = tf_model.model.layers[-1] if __name__ == "__main__":
assert isinstance(m, tf_Detect), "the last layer must be Detect" opt = parse_opt()
m.training = False main(opt)
y = tf_model.predict(img)
inputs = keras.Input(shape=(*opt.img_size, 3), batch_size=None if opt.dynamic_batch_size else opt.batch_size)
keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))
keras_model.summary()
path = opt.weights.replace('.pt', '_saved_model') # filename
keras_model.save(path, save_format='tf')
print('TensorFlow saved_model export success, saved as %s' % path)
except Exception as e:
print('TensorFlow saved_model export failure: %s' % e)
traceback.print_exc(file=sys.stdout)
# TensorFlow GraphDef export
try:
print('\nStarting TensorFlow GraphDef export with TensorFlow %s...' % tf.__version__)
# https://github.com/leimao/Frozen_Graph_TensorFlow
full_model = tf.function(lambda x: keras_model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
f = opt.weights.replace('.pt', '.pb') # filename
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir=os.path.dirname(f),
name=os.path.basename(f),
as_text=False)
print('TensorFlow GraphDef export success, saved as %s' % f)
except Exception as e:
print('TensorFlow GraphDef export failure: %s' % e)
traceback.print_exc(file=sys.stdout)
# TFLite model export
if not opt.tf_nms:
try:
print('\nStarting TFLite export with TensorFlow %s...' % tf.__version__)
# fp32 TFLite model export ---------------------------------------------------------------------------------
# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
# converter.allow_custom_ops = False
# converter.experimental_new_converter = True
# tflite_model = converter.convert()
# f = opt.weights.replace('.pt', '.tflite') # filename
# open(f, "wb").write(tflite_model)
# fp16 TFLite model export ---------------------------------------------------------------------------------
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_types = [tf.float16]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.allow_custom_ops = False
converter.experimental_new_converter = True
tflite_model = converter.convert()
f = opt.weights.replace('.pt', '-fp16.tflite') # filename
open(f, "wb").write(tflite_model)
print('\nTFLite export success, saved as %s' % f)
# int8 TFLite model export ---------------------------------------------------------------------------------
if opt.tfl_int8:
# Representative Dataset
if opt.source.endswith('.yaml'):
with open(check_yaml(opt.source)) as f:
data = yaml.load(f, Loader=yaml.FullLoader) # data dict
check_dataset(data) # check
opt.source = data['train']
dataset = LoadImages(opt.source, img_size=opt.img_size, auto=False)
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
converter.allow_custom_ops = False
converter.experimental_new_converter = True
converter.experimental_new_quantizer = False
tflite_model = converter.convert()
f = opt.weights.replace('.pt', '-int8.tflite') # filename
open(f, "wb").write(tflite_model)
print('\nTFLite (int8) export success, saved as %s' % f)
except Exception as e:
print('\nTFLite export failure: %s' % e)
traceback.print_exc(file=sys.stdout)

View File

@ -1,6 +1,6 @@
# pip install -r requirements.txt # pip install -r requirements.txt
# base ---------------------------------------- # Base ----------------------------------------
matplotlib>=3.2.2 matplotlib>=3.2.2
numpy>=1.18.5 numpy>=1.18.5
opencv-python>=4.1.2 opencv-python>=4.1.2
@ -11,21 +11,23 @@ torch>=1.7.0
torchvision>=0.8.1 torchvision>=0.8.1
tqdm>=4.41.0 tqdm>=4.41.0
# logging ------------------------------------- # Logging -------------------------------------
tensorboard>=2.4.1 tensorboard>=2.4.1
# wandb # wandb
# plotting ------------------------------------ # Plotting ------------------------------------
seaborn>=0.11.0 seaborn>=0.11.0
pandas pandas
# export -------------------------------------- # Export --------------------------------------
# coremltools>=4.1 # coremltools>=4.1 # CoreML export
# onnx>=1.9.0 # onnx>=1.9.0 # ONNX export
# scikit-learn==0.19.2 # for coreml quantization # onnx-simplifier>=0.3.6 # ONNX simplifier
# tensorflow==2.4.1 # for TFLite export # scikit-learn==0.19.2 # CoreML quantization
# tensorflow>=2.4.1 # TFLite export
# tensorflowjs>=3.9.0 # TF.js export
# extras -------------------------------------- # Extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
# pycocotools>=2.0 # COCO mAP # pycocotools>=2.0 # COCO mAP
# albumentations>=1.0.3 # albumentations>=1.0.3

View File

@ -161,9 +161,15 @@ def emojis(str=''):
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
def file_size(file): def file_size(path):
# Return file size in MB # Return file/dir size (MB)
return Path(file).stat().st_size / 1e6 path = Path(path)
if path.is_file():
return path.stat().st_size / 1E6
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
else:
return 0.0
def check_online(): def check_online():