YOLOv5 Export Benchmarks (#6613)
* Add benchmarks.py * Update * Add requirements * Updates * Updates * Updates * Updates * Updates * Updates * dataset autodownload from root * Update * Redirect to /dev/null * sudo --help * Cleanup * Add exports pd df * Updates * Updates * Updates * Cleanup * dir handling fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup * Cleanup2 * Cleanup3 * Cleanup model_type Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/6629/head
parent
96d8f86085
commit
a45e472358
17
export.py
17
export.py
|
@ -52,6 +52,7 @@ import time
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
|
@ -72,6 +73,22 @@ from utils.general import (LOGGER, check_dataset, check_img_size, check_requirem
|
|||
from utils.torch_utils import select_device
|
||||
|
||||
|
||||
def export_formats():
|
||||
# YOLOv5 export formats
|
||||
x = [['PyTorch', '-', '.pt'],
|
||||
['TorchScript', 'torchscript', '.torchscript'],
|
||||
['ONNX', 'onnx', '.onnx'],
|
||||
['OpenVINO', 'openvino', '_openvino_model'],
|
||||
['TensorRT', 'engine', '.engine'],
|
||||
['CoreML', 'coreml', '.mlmodel'],
|
||||
['TensorFlow SavedModel', 'saved_model', '_saved_model'],
|
||||
['TensorFlow GraphDef', 'pb', '.pb'],
|
||||
['TensorFlow Lite', 'tflite', '.tflite'],
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
|
||||
['TensorFlow.js', 'tfjs', '_web_model']]
|
||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])
|
||||
|
||||
|
||||
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
|
||||
# YOLOv5 TorchScript model export
|
||||
try:
|
||||
|
|
|
@ -294,10 +294,7 @@ class DetectMultiBackend(nn.Module):
|
|||
|
||||
super().__init__()
|
||||
w = str(weights[0] if isinstance(weights, list) else weights)
|
||||
suffix = Path(w).suffix.lower()
|
||||
suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel', '.xml']
|
||||
check_suffix(w, suffixes) # check weights have acceptable suffix
|
||||
pt, jit, onnx, engine, tflite, pb, saved_model, coreml, xml = (suffix == x for x in suffixes) # backends
|
||||
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
|
||||
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
|
||||
w = attempt_download(w) # download if not local
|
||||
if data: # data.yaml path (optional)
|
||||
|
@ -332,6 +329,8 @@ class DetectMultiBackend(nn.Module):
|
|||
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
||||
import openvino.inference_engine as ie
|
||||
core = ie.IECore()
|
||||
if not Path(w).is_file(): # if not *.xml
|
||||
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
||||
network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths
|
||||
executable_network = core.load_network(network, device_name='CPU', num_requests=1)
|
||||
elif engine: # TensorRT
|
||||
|
@ -459,6 +458,18 @@ class DetectMultiBackend(nn.Module):
|
|||
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
|
||||
self.forward(im) # warmup
|
||||
|
||||
@staticmethod
|
||||
def model_type(p='path/to/model.pt'):
|
||||
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
||||
from export import export_formats
|
||||
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
|
||||
check_suffix(p, suffixes) # checks
|
||||
p = Path(p).name # eliminate trailing separators
|
||||
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
|
||||
xml |= xml2 # *_openvino_model or *.xml
|
||||
tflite &= not edgetpu # *.tflite
|
||||
return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
|
||||
|
||||
|
||||
class AutoShape(nn.Module):
|
||||
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||
"""
|
||||
Run YOLOv5 benchmarks on all supported export formats
|
||||
|
||||
Format | `export.py --include` | Model
|
||||
--- | --- | ---
|
||||
PyTorch | - | yolov5s.pt
|
||||
TorchScript | `torchscript` | yolov5s.torchscript
|
||||
ONNX | `onnx` | yolov5s.onnx
|
||||
OpenVINO | `openvino` | yolov5s_openvino_model/
|
||||
TensorRT | `engine` | yolov5s.engine
|
||||
CoreML | `coreml` | yolov5s.mlmodel
|
||||
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
|
||||
TensorFlow GraphDef | `pb` | yolov5s.pb
|
||||
TensorFlow Lite | `tflite` | yolov5s.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov5s_web_model/
|
||||
|
||||
Requirements:
|
||||
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
|
||||
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
||||
|
||||
Usage:
|
||||
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[1] # YOLOv5 root directory
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.append(str(ROOT)) # add ROOT to PATH
|
||||
# ROOT = ROOT.relative_to(Path.cwd()) # relative
|
||||
|
||||
import export
|
||||
import val
|
||||
from utils import notebook_init
|
||||
from utils.general import LOGGER, print_args
|
||||
|
||||
|
||||
def run(weights=ROOT / 'yolov5s.pt', # weights path
|
||||
imgsz=640, # inference size (pixels)
|
||||
batch_size=1, # batch size
|
||||
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
||||
):
|
||||
y, t = [], time.time()
|
||||
formats = export.export_formats()
|
||||
for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix)
|
||||
try:
|
||||
w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1]
|
||||
assert suffix in str(w), 'export failed'
|
||||
result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark')
|
||||
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
|
||||
speeds = result[2] # times (preprocess, inference, postprocess)
|
||||
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING: Benchmark failure for {name}: {e}')
|
||||
y.append([name, None, None]) # mAP, t_inference
|
||||
|
||||
# Print results
|
||||
LOGGER.info('\n')
|
||||
parse_opt()
|
||||
notebook_init() # print system info
|
||||
py = pd.DataFrame(y, columns=['Format', 'mAP@0.5:0.95', 'Inference time (ms)'])
|
||||
LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
|
||||
LOGGER.info(str(py))
|
||||
return py
|
||||
|
||||
|
||||
def parse_opt():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
|
||||
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
|
||||
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
||||
opt = parser.parse_args()
|
||||
print_args(FILE.stem, opt)
|
||||
return opt
|
||||
|
||||
|
||||
def main(opt):
|
||||
run(**vars(opt))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = parse_opt()
|
||||
main(opt)
|
5
val.py
5
val.py
|
@ -163,9 +163,10 @@ def run(data,
|
|||
# Dataloader
|
||||
if not training:
|
||||
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz), half=half) # warmup
|
||||
pad = 0.0 if task == 'speed' else 0.5
|
||||
pad = 0.0 if task in ('speed', 'benchmark') else 0.5
|
||||
rect = False if task == 'benchmark' else pt # square inference for benchmarks
|
||||
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
||||
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
|
||||
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=rect,
|
||||
workers=workers, prefix=colorstr(f'{task}: '))[0]
|
||||
|
||||
seen = 0
|
||||
|
|
Loading…
Reference in New Issue