add batch_size to rknn export

pull/13506/head
Edward Yang 2025-02-03 09:48:49 +11:00
parent 30bc65fe02
commit a9f7979c83
1 changed files with 8 additions and 4 deletions

View File

@ -251,7 +251,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
return f, None
@try_export
def export_rknn(model, int8, data, prefix=colorstr('RKNN:')):
def export_rknn(model, batch_size, int8, data, prefix=colorstr('RKNN:')):
# YOLOv5 RKNN export
check_requirements('rknn-toolkit2')
from rknn.api import RKNN
@ -261,7 +261,7 @@ def export_rknn(model, int8, data, prefix=colorstr('RKNN:')):
[255, 255, 255]], target_platform=os.getenv("RKNN_PLATFORM", "rk3588").lower())
rknn.load_onnx(model=str(model.with_suffix(".onnx")))
rknn.build(do_quantization=int8, dataset=data)
rknn.build(do_quantization=int8, dataset=data, rknn_batch_size=batch_size)
f = model.with_suffix('.rknn')
rknn.export_rknn(str(f))
rknn.release()
@ -309,7 +309,11 @@ def run(
# Input
gs = int(max(model.stride)) # grid size (max stride)
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
if rknpu:
if batch_size != 1: LOGGER.info(f'Ignoring batch size in ONNX export for RKNN export')
im = torch.zeros(1, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
else:
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
# Update model
model.eval()
@ -402,7 +406,7 @@ def run(
if xml: # OpenVINO
f[2], _ = export_openvino(file, metadata, half)
if rknpu:
f[3], _ = export_rknn(file, int8, data)
f[3], _ = export_rknn(file, batch_size, int8, data)
# Finish
f = [str(x) for x in f if x] # filter out '' and None