YOLOv5 Export Benchmarks for GPU (#6963)
* Add benchmarks.py GPU support * Updates * Updates * Updates * Updates * Add --half * Add TRT requirements * Cleanup * Add TF to warmup types * Update export.py * Update export.py * Update benchmarks.pypull/6988/head
parent
99de551f97
commit
932dc78496
24
export.py
24
export.py
|
@ -75,18 +75,18 @@ from utils.torch_utils import select_device
|
|||
|
||||
def export_formats():
|
||||
# YOLOv5 export formats
|
||||
x = [['PyTorch', '-', '.pt'],
|
||||
['TorchScript', 'torchscript', '.torchscript'],
|
||||
['ONNX', 'onnx', '.onnx'],
|
||||
['OpenVINO', 'openvino', '_openvino_model'],
|
||||
['TensorRT', 'engine', '.engine'],
|
||||
['CoreML', 'coreml', '.mlmodel'],
|
||||
['TensorFlow SavedModel', 'saved_model', '_saved_model'],
|
||||
['TensorFlow GraphDef', 'pb', '.pb'],
|
||||
['TensorFlow Lite', 'tflite', '.tflite'],
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
|
||||
['TensorFlow.js', 'tfjs', '_web_model']]
|
||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])
|
||||
x = [['PyTorch', '-', '.pt', True],
|
||||
['TorchScript', 'torchscript', '.torchscript', True],
|
||||
['ONNX', 'onnx', '.onnx', True],
|
||||
['OpenVINO', 'openvino', '_openvino_model', False],
|
||||
['TensorRT', 'engine', '.engine', True],
|
||||
['CoreML', 'coreml', '.mlmodel', False],
|
||||
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
|
||||
['TensorFlow GraphDef', 'pb', '.pb', True],
|
||||
['TensorFlow Lite', 'tflite', '.tflite', False],
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
|
||||
['TensorFlow.js', 'tfjs', '_web_model', False]]
|
||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])
|
||||
|
||||
|
||||
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
|
||||
|
|
|
@ -464,10 +464,11 @@ class DetectMultiBackend(nn.Module):
|
|||
|
||||
def warmup(self, imgsz=(1, 3, 640, 640)):
|
||||
# Warmup model by running inference once
|
||||
if self.pt or self.jit or self.onnx or self.engine: # warmup types
|
||||
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
|
||||
if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
|
||||
if self.device.type != 'cpu': # only warmup GPU models
|
||||
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||
self.forward(im) # warmup
|
||||
for _ in range(2 if self.jit else 1): #
|
||||
self.forward(im) # warmup
|
||||
|
||||
@staticmethod
|
||||
def model_type(p='path/to/model.pt'):
|
||||
|
|
|
@ -19,6 +19,7 @@ TensorFlow.js | `tfjs` | yolov5s_web_model/
|
|||
Requirements:
|
||||
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
|
||||
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
||||
$ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT
|
||||
|
||||
Usage:
|
||||
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
|
||||
|
@ -41,20 +42,29 @@ import export
|
|||
import val
|
||||
from utils import notebook_init
|
||||
from utils.general import LOGGER, print_args
|
||||
from utils.torch_utils import select_device
|
||||
|
||||
|
||||
def run(weights=ROOT / 'yolov5s.pt', # weights path
|
||||
imgsz=640, # inference size (pixels)
|
||||
batch_size=1, # batch size
|
||||
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
||||
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
||||
half=False, # use FP16 half-precision inference
|
||||
):
|
||||
y, t = [], time.time()
|
||||
formats = export.export_formats()
|
||||
for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix)
|
||||
device = select_device(device)
|
||||
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
|
||||
try:
|
||||
w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1]
|
||||
if device.type != 'cpu':
|
||||
assert gpu, f'{name} inference not supported on GPU'
|
||||
if f == '-':
|
||||
w = weights # PyTorch format
|
||||
else:
|
||||
w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
|
||||
assert suffix in str(w), 'export failed'
|
||||
result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark')
|
||||
result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half)
|
||||
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
|
||||
speeds = result[2] # times (preprocess, inference, postprocess)
|
||||
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
|
||||
|
@ -78,6 +88,8 @@ def parse_opt():
|
|||
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
|
||||
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
||||
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
||||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
|
||||
opt = parser.parse_args()
|
||||
print_args(FILE.stem, opt)
|
||||
return opt
|
||||
|
|
Loading…
Reference in New Issue