# fast-reid/tools/deploy/trt_export.py
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import argparse
import os
import sys

import tensorrt as trt

sys.path.append('.')

from trt_calibrator import FeatEntropyCalibrator

from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger

logger = setup_logger(name="trt_export")


def get_parser():
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument(
        '--name',
        default='baseline',
        help="name for the converted model"
    )
    parser.add_argument(
        '--output',
        default='outputs/trt_model',
        help="path to save the converted TRT model"
    )
    parser.add_argument(
        '--mode',
        default='fp32',
        help="precision mode used to build the TensorRT engine, one of ['fp32', 'fp16', 'int8']"
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of the TRT module"
    )
    parser.add_argument(
        '--height',
        default=256,
        type=int,
        help="input image height"
    )
    parser.add_argument(
        '--width',
        default=128,
        type=int,
        help="input image width"
    )
    parser.add_argument(
        '--channel',
        default=3,
        type=int,
        help="input image channel"
    )
    parser.add_argument(
        '--calib-data',
        default='Market1501',
        help="int8 calibrator dataset name"
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model'
    )
    return parser
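

# Example invocation (paths are the argparse defaults above; --mode and --batch-size
# are illustrative). Run from the fast-reid repo root so the `fastreid` package is
# importable:
#   python tools/deploy/trt_export.py --name baseline --mode fp16 \
#       --batch-size 8 --onnx-model outputs/onnx_model/baseline.onnx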


def onnx2trt(
        onnx_file_path,
        save_path,
        mode,
        log_level='ERROR',
        max_workspace_size=1,
        strict_type_constraints=False,
        int8_calibrator=None,
):
    """Build a TensorRT engine from an ONNX model.

    Args:
        onnx_file_path (string or io object): onnx model name
        save_path (string): TensorRT serialization save path
        mode (string): one of 'fp32', 'fp16' or 'int8'; whether FP16 or INT8 kernels
            are permitted during the engine build.
        log_level (string, default is ERROR): TensorRT logger level; INTERNAL_ERROR,
            ERROR, WARNING, INFO and VERBOSE are supported.
        max_workspace_size (int, default is 1): the maximum GPU temporary memory, in GiB,
            that the ICudaEngine can use at execution time (default 1 GiB).
        strict_type_constraints (bool, default is False): when set, TensorRT chooses layer
            implementations that conform to the requested type constraints even if a
            higher-precision implementation would be faster; otherwise a higher-precision
            implementation may be chosen if it yields better performance.
        int8_calibrator (trt.IInt8Calibrator, default is None): calibrator used to collect
            activation statistics in int8 mode; required when mode is 'int8'.
    """
    mode = mode.lower()
    assert mode in ['fp32', 'fp16', 'int8'], \
        "mode should be in ['fp32', 'fp16', 'int8'], but got {}".format(mode)

    trt_logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(trt_logger)

    logger.info("Loading ONNX file from path {}...".format(onnx_file_path))

    # ONNX models carry an explicit batch dimension, so the network must be
    # created in explicit-batch mode.
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)

    parser = trt.OnnxParser(network, trt_logger)
    if isinstance(onnx_file_path, str):
        with open(onnx_file_path, 'rb') as f:
            logger.info("Beginning ONNX file parsing")
            flag = parser.parse(f.read())
    else:
        flag = parser.parse(onnx_file_path.read())
    if not flag:
        for i in range(parser.num_errors):
            logger.error(parser.get_error(i))
        raise RuntimeError("Failed to parse ONNX file: {}".format(onnx_file_path))
    logger.info("Completed parsing of ONNX file.")
    # Re-order output tensors: re-mark each network output through an Identity
    # layer named 'identity_<name>'.
    output_tensors = [network.get_output(i) for i in range(network.num_outputs)]
    for tensor in output_tensors:
        network.unmark_output(tensor)
    for tensor in output_tensors:
        identity_out_tensor = network.add_identity(tensor).get_output(0)
        identity_out_tensor.name = 'identity_{}'.format(tensor.name)
        network.mark_output(tensor=identity_out_tensor)

    config = builder.create_builder_config()
    # max_workspace_size is given in GiB (1 GiB by default, per the docstring),
    # so convert it to bytes here.
    config.max_workspace_size = max_workspace_size * (1 << 30)
    if mode == 'fp16':
        assert builder.platform_has_fast_fp16, "platform does not support fp16"
        config.set_flag(trt.BuilderFlag.FP16)
    if mode == 'int8':
        assert builder.platform_has_fast_int8, "platform does not support int8"
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = int8_calibrator

    if strict_type_constraints:
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)

    logger.info("Building an engine from file {}; this may take a while...".format(onnx_file_path))
    # Build with the config so the precision flags and workspace limit set above
    # actually take effect (build_cuda_engine would ignore the builder config).
    engine = builder.build_engine(network, config)
    if engine is None:
        raise RuntimeError("Failed to build TensorRT engine from {}".format(onnx_file_path))
    logger.info("Created engine successfully!")

    logger.info("Saving TRT engine file to path {}".format(save_path))
    with open(save_path, 'wb') as f:
        f.write(engine.serialize())
    logger.info("Engine file has been saved to {}!".format(save_path))


if __name__ == '__main__':
    args = get_parser().parse_args()

    onnx_file_path = args.onnx_model
    engine_file = os.path.join(args.output, args.name + '.engine')

    if args.mode.lower() == 'int8':
        int8_calib = FeatEntropyCalibrator(args)
    else:
        int8_calib = None

    PathManager.mkdirs(args.output)
    onnx2trt(onnx_file_path, engine_file, args.mode, int8_calibrator=int8_calib)
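

# Hypothetical end-to-end usage of the sketches above (requires numpy as np; the
# input shape follows the --batch-size/--channel/--height/--width defaults, and
# feat_dim depends on the exported model):
#
#   engine = load_trt_engine('outputs/trt_model/baseline.engine')
#   dummy = np.random.rand(1, 3, 256, 128).astype(np.float32)
#   feat = run_inference(engine, dummy)  # (1, feat_dim) re-ID feature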