mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
add batch_size to rknn export
This commit is contained in:
parent
30bc65fe02
commit
a9f7979c83
12
export.py
12
export.py
@ -251,7 +251,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
|
|||||||
return f, None
|
return f, None
|
||||||
|
|
||||||
@try_export
|
@try_export
|
||||||
def export_rknn(model, int8, data, prefix=colorstr('RKNN:')):
|
def export_rknn(model, batch_size, int8, data, prefix=colorstr('RKNN:')):
|
||||||
# YOLOv5 RKNN export
|
# YOLOv5 RKNN export
|
||||||
check_requirements('rknn-toolkit2')
|
check_requirements('rknn-toolkit2')
|
||||||
from rknn.api import RKNN
|
from rknn.api import RKNN
|
||||||
@ -261,7 +261,7 @@ def export_rknn(model, int8, data, prefix=colorstr('RKNN:')):
|
|||||||
[255, 255, 255]], target_platform=os.getenv("RKNN_PLATFORM", "rk3588").lower())
|
[255, 255, 255]], target_platform=os.getenv("RKNN_PLATFORM", "rk3588").lower())
|
||||||
|
|
||||||
rknn.load_onnx(model=str(model.with_suffix(".onnx")))
|
rknn.load_onnx(model=str(model.with_suffix(".onnx")))
|
||||||
rknn.build(do_quantization=int8, dataset=data)
|
rknn.build(do_quantization=int8, dataset=data, rknn_batch_size=batch_size)
|
||||||
f = model.with_suffix('.rknn')
|
f = model.with_suffix('.rknn')
|
||||||
rknn.export_rknn(str(f))
|
rknn.export_rknn(str(f))
|
||||||
rknn.release()
|
rknn.release()
|
||||||
@ -309,7 +309,11 @@ def run(
|
|||||||
# Input
|
# Input
|
||||||
gs = int(max(model.stride)) # grid size (max stride)
|
gs = int(max(model.stride)) # grid size (max stride)
|
||||||
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
|
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
|
||||||
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
if rknpu:
|
||||||
|
if batch_size != 1: LOGGER.info(f'Ignoring batch size in ONNX export for RKNN export')
|
||||||
|
im = torch.zeros(1, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
||||||
|
else:
|
||||||
|
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
||||||
|
|
||||||
# Update model
|
# Update model
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -402,7 +406,7 @@ def run(
|
|||||||
if xml: # OpenVINO
|
if xml: # OpenVINO
|
||||||
f[2], _ = export_openvino(file, metadata, half)
|
f[2], _ = export_openvino(file, metadata, half)
|
||||||
if rknpu:
|
if rknpu:
|
||||||
f[3], _ = export_rknn(file, int8, data)
|
f[3], _ = export_rknn(file, batch_size, int8, data)
|
||||||
|
|
||||||
# Finish
|
# Finish
|
||||||
f = [str(x) for x in f if x] # filter out '' and None
|
f = [str(x) for x in f if x] # filter out '' and None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user