# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging

from mmdeploy.backend.tensorrt import from_onnx
from mmdeploy.backend.tensorrt.utils import get_trt_log_level
from mmdeploy.utils import (get_common_config, get_model_inputs,
                            get_root_logger, load_config)


def parse_args():
    """Parse the command-line arguments of the ONNX to TensorRT converter.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``deploy_cfg``, ``onnx_path``, ``output_prefix``, ``device_id``
            (int), ``calib_file`` (``None`` when not given) and
            ``log_level``.
    """
    parser = argparse.ArgumentParser(description='Convert ONNX to TensorRT.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('onnx_path', help='ONNX model path')
    parser.add_argument('output_prefix', help='output TensorRT engine prefix')
    # Parse as int so a value given on the command line has the same type as
    # the integer default instead of arriving as a string.
    parser.add_argument(
        '--device-id', type=int, help='the CUDA device id', default=0)
    parser.add_argument(
        '--calib-file',
        help='the calibration data used to calibrate engine to int8',
        default=None)
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        # NOTE: relies on a private mapping of the logging module to list
        # every registered level name.
        choices=list(logging._nameToLevel.keys()))
    args = parser.parse_args()

    return args


def main():
    """Convert an ONNX model to TensorRT as described by a deploy config.

    Reads the command-line arguments, merges the common config of the deploy
    config with the inputs of its first model and forwards the resulting
    parameters to ``from_onnx``.

    Raises:
        SystemExit: With exit status 1 when the conversion raises, so that
            the calling shell or pipeline can detect the failure.
    """
    args = parse_args()
    logger = get_root_logger(log_level=args.log_level)

    deploy_cfg_path = args.deploy_cfg
    deploy_cfg = load_config(deploy_cfg_path)[0]
    onnx_path = args.onnx_path
    output_prefix = args.output_prefix
    device_id = args.device_id
    calib_file = args.calib_file

    # Only the first model entry of the deploy config is converted.
    model_id = 0
    common_params = get_common_config(deploy_cfg)
    model_params = get_model_inputs(deploy_cfg)[model_id]

    # Model-specific settings override the common ones on key collisions.
    final_params = common_params
    final_params.update(model_params)

    int8_param = final_params.get('int8_param', dict())

    if calib_file is not None:
        int8_param['calib_file'] = calib_file
        # do not support partition model calibration for now
        int8_param['model_type'] = 'end2end'

    logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
                f'\n\tdeploy_cfg: {deploy_cfg_path}')
    try:
        from_onnx(
            onnx_path,
            output_prefix,
            input_shapes=final_params['input_shapes'],
            log_level=get_trt_log_level(),
            fp16_mode=final_params.get('fp16_mode', False),
            int8_mode=final_params.get('int8_mode', False),
            int8_param=int8_param,
            max_workspace_size=final_params.get('max_workspace_size', 0),
            device_id=device_id)
    except Exception as e:
        # Keep the traceback in the log; the bare message alone rarely
        # identifies where the conversion went wrong.
        logger.exception(e)
        logger.error('onnx2tensorrt failed.')
        # Report the failure through the exit status instead of returning
        # normally, which would make the script exit with status 0.
        raise SystemExit(1) from e
    else:
        logger.info('onnx2tensorrt success.')


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()