add batch_size to rknn export
parent
30bc65fe02
commit
a9f7979c83
12
export.py
12
export.py
|
@ -251,7 +251,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
|
|||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_rknn(model, int8, data, prefix=colorstr('RKNN:')):
|
||||
def export_rknn(model, batch_size, int8, data, prefix=colorstr('RKNN:')):
|
||||
# YOLOv5 RKNN export
|
||||
check_requirements('rknn-toolkit2')
|
||||
from rknn.api import RKNN
|
||||
|
@ -261,7 +261,7 @@ def export_rknn(model, int8, data, prefix=colorstr('RKNN:')):
|
|||
[255, 255, 255]], target_platform=os.getenv("RKNN_PLATFORM", "rk3588").lower())
|
||||
|
||||
rknn.load_onnx(model=str(model.with_suffix(".onnx")))
|
||||
rknn.build(do_quantization=int8, dataset=data)
|
||||
rknn.build(do_quantization=int8, dataset=data, rknn_batch_size=batch_size)
|
||||
f = model.with_suffix('.rknn')
|
||||
rknn.export_rknn(str(f))
|
||||
rknn.release()
|
||||
|
@ -309,7 +309,11 @@ def run(
|
|||
# Input
|
||||
gs = int(max(model.stride)) # grid size (max stride)
|
||||
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
|
||||
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
||||
if rknpu:
|
||||
if batch_size != 1: LOGGER.info(f'Ignoring batch size in ONNX export for RKNN export')
|
||||
im = torch.zeros(1, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
||||
else:
|
||||
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
||||
|
||||
# Update model
|
||||
model.eval()
|
||||
|
@ -402,7 +406,7 @@ def run(
|
|||
if xml: # OpenVINO
|
||||
f[2], _ = export_openvino(file, metadata, half)
|
||||
if rknpu:
|
||||
f[3], _ = export_rknn(file, int8, data)
|
||||
f[3], _ = export_rknn(file, batch_size, int8, data)
|
||||
|
||||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
|
|
Loading…
Reference in New Issue