mirror of https://github.com/JDAI-CV/fast-reid.git
167 lines
5.5 KiB
Python
167 lines
5.5 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
import tensorrt as trt
|
|
|
|
from trt_calibrator import FeatEntropyCalibrator
|
|
|
|
sys.path.append('.')
|
|
|
|
from fastreid.utils.logger import setup_logger, PathManager
|
|
|
|
logger = setup_logger(name="trt_export")
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the ONNX -> TensorRT export script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--name``, ``--output``,
        ``--mode``, ``--batch-size``, ``--height``, ``--width``, ``--channel``,
        ``--calib-data`` and ``--onnx-model``.
    """
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    # Naming / location of the converted engine.
    parser.add_argument(
        '--name',
        default='baseline',
        help="name for converted model"
    )
    parser.add_argument(
        '--output',
        default='outputs/trt_model',
        help="path to save converted trt model"
    )

    # Precision of the engine; the value is lower-cased and validated by onnx2trt.
    parser.add_argument(
        '--mode',
        default='fp32',
        help="which mode is used in tensorRT engine, mode can be ['fp32', 'fp16', 'int8']"
    )

    # Input geometry options.
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of trt module"
    )
    parser.add_argument(
        '--height',
        default=256,
        type=int,
        help="input image height"
    )
    parser.add_argument(
        '--width',
        default=128,
        type=int,
        help="input image width"
    )
    parser.add_argument(
        '--channel',
        default=3,
        type=int,
        help="input image channel"
    )

    # Data source for the int8 calibrator and the model to convert.
    parser.add_argument(
        '--calib-data',
        default='Market1501',
        help="int8 calibrator dataset name"
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model'
    )
    return parser
|
|
|
|
|
|
def onnx2trt(
        onnx_file_path,
        save_path,
        mode,
        log_level='ERROR',
        max_workspace_size=1,
        strict_type_constraints=False,
        int8_calibrator=None,
):
    """Build a TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_file_path (str or file-like object): path of the ONNX model, or an
            object whose ``read()`` returns the serialized ONNX model.
        save_path (str): file the serialized TensorRT engine is written to.
        mode (str): engine precision, one of 'fp32', 'fp16' or 'int8'
            (case-insensitive).
        log_level (str, default is 'ERROR'): name of the ``trt.Logger``
            severity attribute to use, e.g. INTERNAL_ERROR, ERROR, WARNING,
            INFO or VERBOSE.
        max_workspace_size (int, default is 1): workspace limit given to the
            builder config, in GiB.
        strict_type_constraints (bool, default is False): when True, the
            ``STRICT_TYPES`` builder flag is set.
        int8_calibrator (default is None): calibrator object assigned to the
            builder config when ``mode`` is 'int8'.

    Raises:
        AssertionError: if ``mode`` is not supported, or the platform reports
            no fast fp16/int8 support for the requested mode.
        RuntimeError: if the ONNX model cannot be parsed or the engine build
            returns no engine.
    """
    mode = mode.lower()
    assert mode in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8'], " \
                                             "but got {}".format(mode)

    trt_logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(trt_logger)

    logger.info("Loading ONNX file from path {}...".format(onnx_file_path))
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, trt_logger)
    if isinstance(onnx_file_path, str):
        with open(onnx_file_path, 'rb') as f:
            logger.info("Beginning ONNX file parsing")
            parsed = parser.parse(f.read())
    else:
        parsed = parser.parse(onnx_file_path.read())
    if not parsed:
        # Report every parser error, then stop: building from a partially
        # parsed network would only fail later with a less useful message.
        for idx in range(parser.num_errors):
            logger.error(parser.get_error(idx))
        raise RuntimeError("Failed to parse ONNX model {}".format(onnx_file_path))

    logger.info("Completed parsing of ONNX file.")

    # Re-order output tensors: unmark all current outputs first, then re-mark
    # an identity copy of each one in the original order.
    output_tensors = [network.get_output(i) for i in range(network.num_outputs)]
    for tensor in output_tensors:
        network.unmark_output(tensor)
    for tensor in output_tensors:
        identity_out_tensor = network.add_identity(tensor).get_output(0)
        identity_out_tensor.name = 'identity_{}'.format(tensor.name)
        network.mark_output(tensor=identity_out_tensor)

    # All build options live on the builder config, which is passed to
    # build_engine below so that they actually take effect.
    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size * (1 << 30)  # GiB -> bytes
    if mode == 'fp16':
        assert builder.platform_has_fast_fp16, "not support fp16"
        config.set_flag(trt.BuilderFlag.FP16)
    elif mode == 'int8':
        assert builder.platform_has_fast_int8, "not support int8"
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = int8_calibrator

    if strict_type_constraints:
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)

    logger.info("Building an engine from file {}; this may take a while...".format(onnx_file_path))
    engine = builder.build_engine(network, config)
    if engine is None:
        # The builder signals failure by returning None instead of raising.
        raise RuntimeError("Failed to build TensorRT engine from {}".format(onnx_file_path))
    logger.info("Create engine successfully!")

    logger.info("Saving TRT engine file to path {}".format(save_path))
    with open(save_path, 'wb') as f:
        f.write(engine.serialize())
    logger.info("Engine file has already saved to {}!".format(save_path))
|
|
|
|
|
|
if __name__ == '__main__':
    # Read the command-line options.
    args = get_parser().parse_args()

    # The engine is written to <output>/<name>.engine.
    model_path = args.onnx_model
    engine_path = os.path.join(args.output, args.name + '.engine')

    # A calibrator is only constructed when int8 precision is requested.
    calibrator = FeatEntropyCalibrator(args) if args.mode.lower() == 'int8' else None

    # Make sure the output directory exists before the engine is saved.
    PathManager.mkdirs(args.output)
    onnx2trt(model_path, engine_path, args.mode, int8_calibrator=calibrator)
|