Add OpenVINO metadata to export (#7947)
* Write .yaml file when exporting model to openvino Write .yaml file automatically when exporting model to openvino to be used during inference * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update export.py * Update export.py * Load metadata on inference * Update common.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/7952/head
parent
541a5b72bb
commit
a3a652c933
|
@ -54,6 +54,7 @@ from pathlib import Path
|
|||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
|
@ -168,7 +169,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
|
|||
LOGGER.info(f'{prefix} export failure: {e}')
|
||||
|
||||
|
||||
def export_openvino(file, half, prefix=colorstr('OpenVINO:')):
|
||||
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
|
||||
# YOLOv5 OpenVINO export
|
||||
try:
|
||||
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
||||
|
@ -178,7 +179,9 @@ def export_openvino(file, half, prefix=colorstr('OpenVINO:')):
|
|||
f = str(file).replace('.pt', f'_openvino_model{os.sep}')
|
||||
|
||||
cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
|
||||
subprocess.check_output(cmd.split())
|
||||
subprocess.check_output(cmd.split()) # export
|
||||
with open(Path(f) / 'meta.yaml', 'w') as g:
|
||||
yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
|
||||
|
||||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
||||
return f
|
||||
|
@ -520,7 +523,7 @@ def run(
|
|||
if onnx or xml: # OpenVINO requires ONNX
|
||||
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
|
||||
if xml: # OpenVINO
|
||||
f[3] = export_openvino(file, half)
|
||||
f[3] = export_openvino(model, file, half)
|
||||
if coreml:
|
||||
_, f[4] = export_coreml(model, im, file, int8, half)
|
||||
|
||||
|
|
|
@ -326,9 +326,6 @@ class DetectMultiBackend(nn.Module):
|
|||
stride, names = 32, [f'class{i}' for i in range(1000)] # assign defaults
|
||||
w = attempt_download(w) # download if not local
|
||||
fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
|
||||
if data: # data.yaml path (optional)
|
||||
with open(data, errors='ignore') as f:
|
||||
names = yaml.safe_load(f)['names'] # class names
|
||||
|
||||
if pt: # PyTorch
|
||||
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
|
||||
|
@ -367,7 +364,8 @@ class DetectMultiBackend(nn.Module):
|
|||
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
||||
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
||||
executable_network = ie.compile_model(model=network, device_name="CPU")
|
||||
self.output_layer = next(iter(executable_network.outputs))
|
||||
output_layer = next(iter(executable_network.outputs))
|
||||
self._load_metadata(w.parent / 'meta.yaml') # load metadata
|
||||
elif engine: # TensorRT
|
||||
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||
|
@ -433,7 +431,11 @@ class DetectMultiBackend(nn.Module):
|
|||
output_details = interpreter.get_output_details() # outputs
|
||||
elif tfjs:
|
||||
raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
|
||||
|
||||
self.__dict__.update(locals()) # assign all variables to self
|
||||
if not hasattr(self, 'names') and data: # assign class names (optional)
|
||||
with open(data, errors='ignore') as f:
|
||||
names = yaml.safe_load(f)['names']
|
||||
|
||||
def forward(self, im, augment=False, visualize=False, val=False):
|
||||
# YOLOv5 MultiBackend inference
|
||||
|
@ -493,13 +495,20 @@ class DetectMultiBackend(nn.Module):
|
|||
y = torch.tensor(y, device=self.device)
|
||||
return (y, []) if val else y
|
||||
|
||||
def _load_metadata(self, f='path/to/meta.yaml'):
|
||||
# Load metadata from meta.yaml if it exists
|
||||
if Path(f).is_file():
|
||||
with open(f, errors='ignore') as f:
|
||||
for k, v in yaml.safe_load(f).items():
|
||||
setattr(self, k, v) # assign stride, names
|
||||
|
||||
def warmup(self, imgsz=(1, 3, 640, 640)):
|
||||
# Warmup model by running inference once
|
||||
if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
|
||||
if self.device.type != 'cpu': # only warmup GPU models
|
||||
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||
for _ in range(2 if self.jit else 1): #
|
||||
self.forward(im) # warmup
|
||||
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
|
||||
if any(warmup_types) and self.device.type != 'cpu':
|
||||
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||
for _ in range(2 if self.jit else 1): #
|
||||
self.forward(im) # warmup
|
||||
|
||||
@staticmethod
|
||||
def model_type(p='path/to/model.pt'):
|
||||
|
|
Loading…
Reference in New Issue