# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys

import tensorrt as trt

from trt_calibrator import FeatEntropyCalibrator

sys.path.append('.')

from fastreid.utils.logger import setup_logger, PathManager

logger = setup_logger(name="trt_export")


def get_parser():
    """Build the command-line parser for the ONNX -> TensorRT export script."""
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument(
        '--name',
        default='baseline',
        help="name for converted model"
    )
    parser.add_argument(
        '--output',
        default='outputs/trt_model',
        help="path to save converted trt model"
    )
    parser.add_argument(
        '--mode',
        default='fp32',
        help="which mode is used in tensorRT engine, mode can be ['fp32', 'fp16', 'int8']"
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of trt module"
    )
    parser.add_argument(
        '--height',
        default=256,
        type=int,
        help="input image height"
    )
    parser.add_argument(
        '--width',
        default=128,
        type=int,
        help="input image width"
    )
    parser.add_argument(
        '--channel',
        default=3,
        type=int,
        help="input image channel"
    )
    parser.add_argument(
        '--calib-data',
        default='Market1501',
        help="int8 calibrator dataset name"
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model'
    )
    return parser


def onnx2trt(
        onnx_file_path,
        save_path,
        mode,
        log_level='ERROR',
        max_workspace_size=1,
        strict_type_constraints=False,
        int8_calibrator=None,
):
    """Build a TensorRT engine from an ONNX model and serialize it to disk.

    Args:
        onnx_file_path (string or io object): path to the ONNX model, or an
            already opened binary file-like object exposing ``read()``.
        save_path (string): file the serialized TensorRT engine is written to.
        mode (string): precision of the engine, one of 'fp32', 'fp16', 'int8'
            (case-insensitive).
        log_level (string, default is ERROR): name of the ``trt.Logger``
            severity attribute to use, e.g. INTERNAL_ERROR, ERROR, WARNING,
            INFO, VERBOSE.
        max_workspace_size (int, default is 1): maximum temporary GPU memory
            the builder may use, expressed in GiB.
        strict_type_constraints (bool, default is False): when set, the
            STRICT_TYPES builder flag is enabled so TensorRT keeps to the
            requested precision instead of picking a faster, higher-precision
            implementation.
        int8_calibrator (default is None): calibrator object handed to
            TensorRT in int8 mode. If None, no calibrator is attached.

    Raises:
        AssertionError: if ``mode`` is not supported, or the platform has no
            fast fp16 / int8 support for the requested mode.
        RuntimeError: if the ONNX model cannot be parsed or the engine build
            fails.
    """
    mode = mode.lower()
    assert mode in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8'], " \
                                             "but got {}".format(mode)

    trt_logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(trt_logger)

    logger.info("Loading ONNX file from path {}...".format(onnx_file_path))

    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, trt_logger)

    if isinstance(onnx_file_path, str):
        with open(onnx_file_path, 'rb') as f:
            logger.info("Beginning ONNX file parsing")
            parsed = parser.parse(f.read())
    else:
        parsed = parser.parse(onnx_file_path.read())

    if not parsed:
        # Surface every parser error, then stop: building an engine from a
        # partially parsed network can only fail later with a less useful message.
        errors = [str(parser.get_error(idx)) for idx in range(parser.num_errors)]
        for message in errors:
            logger.error(message)
        raise RuntimeError(
            "Failed to parse ONNX model {}: {}".format(onnx_file_path, "; ".join(errors))
        )

    logger.info("Completed parsing of ONNX file.")

    # re-order output tensor: replace every network output with an identity
    # layer's output, named 'identity_<original name>'.
    output_tensors = [network.get_output(i) for i in range(network.num_outputs)]
    for tensor in output_tensors:
        network.unmark_output(tensor)
    for tensor in output_tensors:
        identity_out_tensor = network.add_identity(tensor).get_output(0)
        identity_out_tensor.name = 'identity_{}'.format(tensor.name)
        network.mark_output(tensor=identity_out_tensor)

    # All build options live on the builder config, and the engine is built
    # *with* that config, so workspace size, precision and strict-type flags
    # actually take effect.
    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size * (1 << 30)  # GiB -> bytes
    if mode == 'fp16':
        assert builder.platform_has_fast_fp16, "not support fp16"
        config.set_flag(trt.BuilderFlag.FP16)
    elif mode == 'int8':
        assert builder.platform_has_fast_int8, "not support int8"
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = int8_calibrator

    if strict_type_constraints:
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)

    logger.info("Building an engine from file {}; this may take a while...".format(onnx_file_path))
    engine = builder.build_engine(network, config)
    if engine is None:
        # The builder signals failure by returning None rather than raising.
        raise RuntimeError("Failed to build TensorRT engine from {}".format(onnx_file_path))
    logger.info("Create engine successfully!")

    logger.info("Saving TRT engine file to path {}".format(save_path))
    with open(save_path, 'wb') as f:
        f.write(engine.serialize())
    logger.info("Engine file has already saved to {}!".format(save_path))


def main():
    """Parse command-line arguments and export the ONNX model as a TRT engine."""
    args = get_parser().parse_args()

    onnx_file_path = args.onnx_model
    engine_file = os.path.join(args.output, args.name + '.engine')

    # A calibrator is only needed (and only constructed) for int8 engines.
    if args.mode.lower() == 'int8':
        int8_calib = FeatEntropyCalibrator(args)
    else:
        int8_calib = None

    PathManager.mkdirs(args.output)
    onnx2trt(onnx_file_path, engine_file, args.mode, int8_calibrator=int8_calib)


if __name__ == '__main__':
    main()