fast-reid/tools/deploy/trt_export.py


# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys

import numpy as np

# Make the fastreid package and the local tiny-tensorrt build (which
# provides the `pytrt` module) importable before the imports below.
sys.path.append('../../')
sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
import pytrt

from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger

logger = setup_logger(name='trt_export')


def get_parser():
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")
    parser.add_argument(
        "--name",
        default="baseline",
        help="name for the converted TRT model"
    )
    parser.add_argument(
        "--output",
        default='outputs/trt_model',
        help='path to save the converted TRT model'
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to the ONNX model'
    )
    parser.add_argument(
        "--height",
        type=int,
        default=256,
        help="height of the input image"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of the input image"
    )
    return parser


def export_trt_model(onnxModel, engineFile, input_numpy_array):
    r"""
    Export an ONNX model to a serialized TensorRT engine via pytrt.
    """
    trt = pytrt.Trt()
    customOutput = []   # empty -> keep the output nodes defined in the ONNX graph
    maxBatchSize = 1
    calibratorData = []
    mode = 2            # pytrt precision flag (tiny-tensorrt: 0=FP32, 1=FP16, 2=INT8)
    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
    trt.DoInference(input_numpy_array)  # slightly different from c++
    return 0
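

# --- Alternative sketch (illustration only, not part of fast-reid) ----------
# If the local pytrt build is unavailable, the same ONNX-to-engine export can
# be done with NVIDIA's official `tensorrt` Python package. The function name
# below is hypothetical and the code is a minimal sketch assuming a TensorRT
# 7/8-era API (`build_engine` and `max_workspace_size` are deprecated in
# newer releases).
def export_trt_model_official(onnx_file, engine_file, workspace=1 << 30):
    import tensorrt as trt_api  # official NVIDIA bindings, assumed installed

    trt_logger = trt_api.Logger(trt_api.Logger.WARNING)
    builder = trt_api.Builder(trt_logger)
    # ONNX parsing requires an explicit-batch network in TensorRT >= 7.
    network = builder.create_network(
        1 << int(trt_api.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt_api.OnnxParser(network, trt_logger)
    with open(onnx_file, 'rb') as f:
        if not onnx_parser.parse(f.read()):
            raise RuntimeError(onnx_parser.get_error(0))
    config = builder.create_builder_config()
    config.max_workspace_size = workspace
    engine = builder.build_engine(network, config)
    with open(engine_file, 'wb') as f:
        f.write(engine.serialize())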


if __name__ == '__main__':
    args = get_parser().parse_args()

    # Dummy batch used to run a test inference once the engine is built.
    inputs = np.zeros(shape=(32, args.height, args.width, 3))
    onnxModel = args.onnx_model
    engineFile = os.path.join(args.output, args.name + '.engine')
    PathManager.mkdirs(args.output)
    export_trt_model(onnxModel, engineFile, inputs)
    logger.info(f"Exported TRT model to {args.output} successfully!")
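
# Example usage (the paths shown are just the script's own defaults):
#   python trt_export.py --name baseline --output outputs/trt_model \
#       --onnx-model outputs/onnx_model/baseline.onnx --height 256 --width 128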