# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
import sys
from functools import partial

import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import (create_calib_table, extract_model,
                           get_predefined_partition_cfg, torch2onnx,
                           torch2torchscript, visualize_model)
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
                            get_ir_config, get_model_inputs,
                            get_partition_config, get_root_logger, load_config,
                            target_wrapper)
from mmdeploy.utils.export_info import dump_info


def parse_args():
    """Parse the command line arguments of the deployment tool.

    Returns:
        argparse.Namespace: The parsed arguments. ``work_dir`` defaults to
        the current working directory so that it is never ``None``.
    """
    import os

    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument(
        '--test-img', default=None, help='image used to test model')
    # The work dir is passed straight to ``osp.abspath``/``osp.join`` later,
    # which fail on ``None``; fall back to the current directory instead.
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models '
        '(default: current working directory)')
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate in int8 mode. If not '
        'specified, it will use "val" dataset in model config instead.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='show visualization outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    args = parser.parse_args()

    return args


def create_process(name, target, args, kwargs, ret_value=None):
    """Run one conversion step in a child process and wait for it.

    ``target`` is wrapped with ``target_wrapper`` (together with the current
    root-logger level and ``ret_value``) and executed in a new
    ``torch.multiprocessing.Process``.

    Args:
        name (str): Human readable step name used in log messages.
        target (Callable): Function to execute in the child process.
        args (tuple): Positional arguments forwarded to the wrapped target.
        kwargs (dict): Keyword arguments forwarded to the wrapped target.
        ret_value (multiprocessing.Value, optional): Shared status value.
            When given, a non-zero value, or a non-zero child exit code, is
            treated as failure and terminates this program with status 1.
            Otherwise success is logged. When ``None``, no status check is
            performed. Defaults to None.
    """
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is None:
        return

    # A child that crashes hard (e.g. killed by a signal or a native
    # segfault) may never write ``ret_value``, so also check its exit code
    # to avoid reporting such a step as a success.
    if process.exitcode != 0 or ret_value.value != 0:
        logger.error(f'{name} failed.')
        sys.exit(1)
    logger.info(f'{name} success.')


def torch2ir(ir_type: IR):
    """Look up the exporter that turns a PyTorch model into ``ir_type``.

    Args:
        ir_type (IR): The requested intermediate representation.

    Returns:
        Callable: ``torch2onnx`` for ``IR.ONNX`` or ``torch2torchscript``
        for ``IR.TORCHSCRIPT``.

    Raises:
        KeyError: If ``ir_type`` matches neither supported IR.
    """
    exporters = (
        (IR.ONNX, torch2onnx),
        (IR.TORCHSCRIPT, torch2torchscript),
    )
    for candidate, exporter in exporters:
        if ir_type == candidate:
            return exporter
    raise KeyError(f'Unexpected IR type {ir_type}')


def main():
    """Convert a PyTorch model to the backend described by the deploy config.

    Steps: export to the intermediate representation (ONNX or TorchScript),
    optionally partition it, optionally build a calibration table, convert
    to the configured backend, and — when a ``DISPLAY`` is set — visualize
    both backend and PyTorch outputs. Each step runs in its own spawned
    process through ``create_process``, which terminates the script with
    status 1 when a step reports failure.
    """
    args = parse_args()
    set_start_method('spawn')
    logger = get_root_logger()
    logger.setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if not
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        dump_info(deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path)

    # Shared status slot handed to every step; create_process inspects it
    # after each child finishes.
    ret_value = mp.Value('d', 0, lock=False)

    # convert to IR
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    create_process(
        f'torch2{ir_type.value}',
        target=torch2ir(ir_type),
        args=(args.img, args.work_dir, ir_save_file, deploy_cfg_path,
              model_cfg_path, checkpoint_path),
        kwargs=dict(device=args.device),
        ret_value=ret_value)

    ir_files = [osp.join(args.work_dir, ir_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)
    if partition_cfgs is not None:
        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        # The extracted partitions replace the full IR as converter inputs.
        origin_ir_file = ir_files[0]
        ir_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            create_process(
                f'partition model {save_file} with start: {start}, end: {end}',
                extract_model,
                args=(origin_ir_file, start, end),
                kwargs=dict(dynamic_axes=dynamic_axes, save_file=save_path),
                ret_value=ret_value)

            ir_files.append(save_path)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)

        create_process(
            'calibration',
            create_calib_table,
            args=(calib_path, deploy_cfg_path, model_cfg_path,
                  checkpoint_path),
            kwargs=dict(
                dataset_cfg=args.calib_dataset_cfg,
                dataset_type='val',
                device=args.device),
            ret_value=ret_value)

    # convert backend; backends without a dedicated branch use the IR files
    backend_files = ir_files
    backend = get_backend(deploy_cfg)
    if backend == Backend.TENSORRT:
        model_params = get_model_inputs(deploy_cfg)
        assert len(model_params) == len(ir_files), \
            'Expect one model_inputs entry per IR file for TensorRT.'

        from mmdeploy.apis.tensorrt import is_available as trt_is_available
        from mmdeploy.apis.tensorrt import onnx2tensorrt
        assert trt_is_available(), (
            'TensorRT is not available, please install TensorRT and build '
            'TensorRT custom ops first.')
        backend_files = []
        for model_id, (model_param, onnx_path) in enumerate(
                zip(model_params, ir_files)):
            onnx_name = osp.splitext(osp.basename(onnx_path))[0]
            save_file = model_param.get('save_file', onnx_name + '.engine')

            partition_type = 'end2end' if partition_cfgs is None \
                else onnx_name
            create_process(
                f'onnx2tensorrt of {onnx_path}',
                target=onnx2tensorrt,
                args=(args.work_dir, save_file, model_id, deploy_cfg_path,
                      onnx_path),
                kwargs=dict(device=args.device, partition_type=partition_type),
                ret_value=ret_value)

            backend_files.append(osp.join(args.work_dir, save_file))

    elif backend == Backend.NCNN:
        from mmdeploy.apis.ncnn import is_available as is_available_ncnn

        if not is_available_ncnn():
            logger.error('ncnn support is not available.')
            sys.exit(1)

        from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn

        backend_files = []
        for onnx_path in ir_files:
            model_param_path, model_bin_path = get_output_model_file(
                onnx_path, args.work_dir)
            create_process(
                f'onnx2ncnn with {onnx_path}',
                target=onnx2ncnn,
                args=(onnx_path, model_param_path, model_bin_path),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [model_param_path, model_bin_path]

    elif backend == Backend.OPENVINO:
        from mmdeploy.apis.openvino import \
            is_available as is_available_openvino
        assert is_available_openvino(), \
            'OpenVINO is not available, please install OpenVINO first.'

        from mmdeploy.apis.openvino import (get_input_info_from_cfg,
                                            get_mo_options_from_cfg,
                                            get_output_model_file,
                                            onnx2openvino)
        # These depend only on the deploy config, not on the IR file.
        input_info = get_input_info_from_cfg(deploy_cfg)
        output_names = ir_config.output_names
        mo_options = get_mo_options_from_cfg(deploy_cfg)
        openvino_files = []
        for onnx_path in ir_files:
            model_xml_path = get_output_model_file(onnx_path, args.work_dir)
            create_process(
                f'onnx2openvino with {onnx_path}',
                target=onnx2openvino,
                args=(input_info, output_names, onnx_path, args.work_dir,
                      mo_options),
                kwargs=dict(),
                ret_value=ret_value)
            openvino_files.append(model_xml_path)
        backend_files = openvino_files

    elif backend == Backend.PPLNN:
        from mmdeploy.apis.pplnn import is_available as is_available_pplnn
        assert is_available_pplnn(), \
            'PPLNN is not available, please install PPLNN first.'

        from mmdeploy.apis.pplnn import onnx2pplnn
        model_inputs = get_model_inputs(deploy_cfg)
        # The TensorRT branch treats model inputs as a list with one entry
        # per IR file; still accept a single mapping for compatibility.
        if not isinstance(model_inputs, (list, tuple)):
            model_inputs = [model_inputs] if model_inputs else []
        pplnn_files = []
        for model_id, onnx_path in enumerate(ir_files):
            # splitext swaps only the final extension; str.replace would
            # also touch '.onnx' elsewhere in the path, and would return the
            # model path itself (overwriting it) without that suffix.
            algo_file = osp.splitext(onnx_path)[0] + '.json'
            # Use this file's entry, or the last one if fewer are given.
            entry_id = min(model_id, len(model_inputs) - 1)
            model_input = model_inputs[entry_id] if model_inputs else {}
            assert 'opt_shape' in model_input, 'Expect opt_shape ' \
                'in deploy config for PPLNN'
            # PPLNN accepts only 1 input shape for optimization,
            # may get changed in the future
            input_shapes = [model_input['opt_shape']]
            create_process(
                f'onnx2pplnn with {onnx_path}',
                target=onnx2pplnn,
                args=(algo_file, onnx_path),
                kwargs=dict(device=args.device, input_shapes=input_shapes),
                ret_value=ret_value)
            pplnn_files += [onnx_path, algo_file]
        backend_files = pplnn_files

    if args.test_img is None:
        args.test_img = args.img
    # Visualization needs a display; skip it on headless installations.
    if os.getenv('DISPLAY') is not None:
        # visualize model of the backend
        create_process(
            f'visualize {backend.value} model',
            target=visualize_model,
            args=(model_cfg_path, deploy_cfg_path, backend_files,
                  args.test_img, args.device),
            kwargs=dict(
                backend=backend,
                output_file=osp.join(args.work_dir,
                                     f'output_{backend.value}.jpg'),
                show_result=args.show),
            ret_value=ret_value)

        # visualize pytorch model
        create_process(
            'visualize pytorch model',
            target=visualize_model,
            args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
                  args.test_img, args.device),
            kwargs=dict(
                backend=Backend.PYTORCH,
                output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
                show_result=args.show),
            ret_value=ret_value)
    else:
        logger.warning(
            '"visualize_model" has been skipped, probably because it is '
            'running on a headless device.')
    logger.info('All process success.')


if __name__ == '__main__':
    # Run only when executed as a script: with the 'spawn' start method,
    # child processes re-import this module and must not re-enter main().
    main()