# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os.path as osp

import onnx
import onnx.helper

from mmdeploy.apis.onnx import extract_partition
from mmdeploy.utils import get_root_logger


def parse_args():
    """Parse the command-line arguments of the extraction script.

    Returns:
        argparse.Namespace: The parsed arguments. ``start`` and ``end`` are
            converted from comma-separated strings into lists of strings;
            each becomes an empty list when the option is missing or empty.
    """
    cli = argparse.ArgumentParser(
        description='Extract model based on markers.')
    cli.add_argument('input_model', help='Input ONNX model')
    cli.add_argument('output_model', help='Output ONNX model')
    cli.add_argument(
        '--start',
        help='Start markers, format: func:type, e.g. backbone:input')
    cli.add_argument('--end', help='End markers')

    # Every level name known to the logging module is an accepted choice.
    level_names = list(logging._nameToLevel)
    cli.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=level_names)
    parsed = cli.parse_args()

    # Turn the comma-separated marker options into lists.
    for option in ('start', 'end'):
        raw_value = getattr(parsed, option)
        if raw_value:
            setattr(parsed, option, raw_value.split(','))
        else:
            setattr(parsed, option, [])

    return parsed


def collect_avaiable_marks(model):
    """Collect the distinct ``func`` names of all ``Mark`` nodes in a model.

    Args:
        model: Object exposing ``graph.node``, an iterable of nodes that each
            have an ``op_type`` and an iterable ``attribute`` (presumably an
            ``onnx.ModelProto``, as loaded by the caller).

    Returns:
        list[str]: The ``func`` attribute values of the ``Mark`` nodes,
            decoded as UTF-8, without duplicates and in first-seen order.
            ``Mark`` nodes that carry no ``func`` attribute are skipped.
    """
    marks = []
    seen = set()
    for node in model.graph.node:
        if node.op_type != 'Mark':
            continue

        # Reset per node so a Mark without a 'func' attribute neither raises
        # UnboundLocalError nor reuses the value of a previous node.
        func = None
        for attr in node.attribute:
            if attr.name == 'func':
                func = str(onnx.helper.get_attribute_value(attr), 'utf-8')

        if func is None or func in seen:
            continue
        seen.add(func)
        marks.append(func)
    return marks


def main():
    """Run the script: log the available marks and save the extracted model.

    Loads the input ONNX model, logs the mark names found in it, passes the
    model together with the start/end markers to ``extract_partition`` and
    saves the result, appending ``.onnx`` to the output path when it does not
    already end with that extension.
    """
    cli_args = parse_args()
    logger = get_root_logger(log_level=cli_args.log_level)

    source_model = onnx.load(cli_args.input_model)
    mark_listing = '\n    '.join(collect_avaiable_marks(source_model))
    logger.info('Available marks:\n    {}'.format(mark_listing))

    partition = extract_partition(source_model, cli_args.start, cli_args.end)

    # Make sure the destination path carries the '.onnx' extension.
    destination = cli_args.output_model
    _, extension = osp.splitext(destination)
    if extension != '.onnx':
        destination += '.onnx'
    onnx.save(partition, destination)


# Run the extraction only when this file is executed as a script.
if __name__ == '__main__':
    main()