mirror of https://github.com/JDAI-CV/fast-reid.git
83 lines
1.9 KiB
Python
83 lines
1.9 KiB
Python
|
# encoding: utf-8
|
||
|
"""
|
||
|
@author: xingyu liao
|
||
|
@contact: sherlockliao01@gmail.com
|
||
|
"""
|
||
|
|
||
|
import argparse
|
||
|
import os
|
||
|
import numpy as np
|
||
|
import sys
|
||
|
|
||
|
# Add the directory two levels above the current working directory to the
# import search path. Presumably this lets the `fastreid` imports below
# resolve when the script is launched from its own folder -- confirm.
sys.path.append('../../')
|
||
|
# NOTE(review): hard-coded, machine-specific absolute path. Presumably it is the
# build directory that contains the compiled `pytrt` module imported below --
# confirm, and point it at the local build when running elsewhere.
sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
|
||
|
|
||
|
import pytrt
|
||
|
from fastreid.utils.logger import setup_logger
|
||
|
from fastreid.utils.file_io import PathManager
|
||
|
|
||
|
|
||
|
# Module-level logger named 'trt_export', created via fastreid's setup_logger;
# the script entry point uses it to report completion.
logger = setup_logger(name='trt_export')
|
||
|
|
||
|
|
||
|
def get_parser():
    """Build the command-line parser for the ONNX -> TensorRT export script.

    Every option has a default, so parsing an empty argument list succeeds.

    Returns:
        argparse.ArgumentParser: parser exposing ``--name``, ``--output``,
        ``--onnx-model`` (stored as ``onnx_model``), ``--height`` and
        ``--width``.
    """
    parser = argparse.ArgumentParser(
        description="Convert ONNX to TRT model",
        # All options are optional and fall back to defaults, so surface those
        # defaults in ``--help`` instead of making users read the source.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--name",
        default="baseline",
        help="name for converted model",
    )
    parser.add_argument(
        "--output",
        default='outputs/trt_model',
        help='path to save converted trt model',
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model',
    )
    parser.add_argument(
        "--height",
        type=int,
        default=256,
        help="height of image",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image",
    )
    return parser
|
||
|
|
||
|
|
||
|
def export_trt_model(onnxModel, engineFile, input_numpy_array):
    r"""
    Export a model to trt format and run one inference call on it.

    Args:
        onnxModel: path of the ONNX model, forwarded to ``pytrt``'s
            ``CreateEngine`` as its first argument.
        engineFile: path for the TensorRT engine file, forwarded as the
            second argument.
        input_numpy_array: array handed to ``DoInference`` once the engine
            has been created.

    Returns:
        int: always ``0``.
    """
    # Fixed engine settings passed straight through to pytrt.
    # NOTE(review): the meaning of mode value 2 is defined by pytrt -- confirm
    # it is the intended mode given that no calibration data is supplied.
    run_mode = 2
    max_batch = 1
    extra_outputs = []
    calibration_data = []

    engine = pytrt.Trt()
    engine.CreateEngine(
        onnxModel,
        engineFile,
        extra_outputs,
        max_batch,
        run_mode,
        calibration_data,
    )
    # The Python binding takes the numpy array directly, which differs
    # slightly from the C++ interface.
    engine.DoInference(input_numpy_array)
    return 0
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point: read CLI options, build the engine, report completion.
    args = get_parser().parse_args()

    # Zero-filled input of shape (32, height, width, 3), passed to the export
    # routine, which forwards it to its inference call.
    dummy_batch = np.zeros((32, args.height, args.width, 3))

    onnx_path = args.onnx_model
    # The engine file is written to <output>/<name>.engine.
    engine_path = os.path.join(args.output, f"{args.name}.engine")

    # Ensure the output directory exists before the engine is created.
    PathManager.mkdirs(args.output)
    export_trt_model(onnx_path, engine_path, dummy_batch)

    logger.info(f"Export trt model in {args.output} successfully!")
|