# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import jsonplus

from easycv.file import io
from easycv.utils.config_tools import Config

# Prefix prepended to the source config's base name when `to_ms_config`
# derives the output file name automatically.
MODELSCOPE_PREFIX = 'modelscope'
# Key in the converted model config under which the original EasyCV model
# type is stored (the top-level `type` is replaced by the modelscope name).
EASYCV_ARCH = '__easycv_arch__'


def _build_ms_hooks(easycv_cfg):
    """Assemble the ModelScope hook list from an EasyCV config (read-only)."""
    log_config = easycv_cfg.log_config
    log_interval = log_config.interval

    hooks = [{
        'type': 'CheckpointHook',
        'interval': easycv_cfg.checkpoint_config.interval
    }, {
        'type': 'EvaluationHook',
        'interval': easycv_cfg.eval_config.interval
    }, {
        'type': 'AddLrLogHook'
    }, {
        'type': 'IterTimerHook'
    }]

    # `custom_hooks` may be absent or explicitly set to None.
    hooks.extend(easycv_cfg.get('custom_hooks', None) or [])

    for log_hook_i in log_config.hooks:
        hook_type = log_hook_i['type']
        if hook_type == 'TensorboardLoggerHook':
            # replace with modelscope api
            hooks.append({'type': 'TensorboardHook', 'interval': log_interval})
        elif hook_type == 'TextLoggerHook':
            # use modelscope api
            hooks.append({'type': 'TextLoggerHook', 'interval': log_interval})
        else:
            # Build a new dict instead of updating the hook config in place,
            # so the caller's config is left untouched.
            hooks.append({**log_hook_i, 'interval': log_interval})

    return hooks


def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
    """Convert EasyCV config to ModelScope style.

    The input config is not modified: every section that has to be altered
    (optimizer options, validation dataset, model, log hooks) is copied first.

    Args:
        cfg (str | Config): Easycv config file or Config object.
        task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
        ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
        save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
            If omitted and `cfg` is a file path, it defaults to
            `modelscope_<config base name>.json` next to the config file.
        dump (bool): Whether dump the converted config to `save_path`.

    Returns:
        Config: The converted ModelScope style configuration.

    Raises:
        ValueError: If `dump` is True, `cfg` is not a file path and no
            `save_path` is given.
        AssertionError: If `dump` is True and `save_path` does not end with
            'json'.
    """
    # TODO: support multi eval_pipelines
    # TODO: support for adding customized required keys to the configuration file

    if isinstance(cfg, str):
        easycv_cfg = Config.fromfile(cfg)
        if dump and save_path is None:
            save_dir = os.path.dirname(cfg)
            save_name = MODELSCOPE_PREFIX + '_' + os.path.splitext(
                os.path.basename(cfg))[0] + '.json'
            save_path = os.path.join(save_dir, save_name)
    else:
        easycv_cfg = cfg
        if dump and save_path is None:
            raise ValueError('Please provide `save_path`!')

    # `save_path` is only used when dumping; it may legitimately be None
    # otherwise, so validate it only in that case.
    if dump:
        assert save_path.endswith('json'), 'Only support json file!'

    # Copy instead of `update` so the source optimizer_config stays untouched.
    optimizer_options = {
        **(easycv_cfg.optimizer_config or {}), 'loss_keys': 'total_loss'
    }

    data_cfg = easycv_cfg.data
    # Shallow copy: the per-gpu settings are moved out of the val dataset
    # config into the evaluation dataloader without altering the source.
    val_dataset_cfg = dict(data_cfg.val)
    val_imgs_per_gpu = val_dataset_cfg.pop('imgs_per_gpu',
                                           data_cfg.imgs_per_gpu)
    val_workers_per_gpu = val_dataset_cfg.pop('workers_per_gpu',
                                              data_cfg.workers_per_gpu)

    predict_config = easycv_cfg.get('predict', None)
    hooks = _build_ms_hooks(easycv_cfg)

    # The original model type moves under EASYCV_ARCH; pop it from a copy so
    # the same config object can be converted more than once.
    model_cfg = dict(easycv_cfg.model)
    ori_model_type = model_cfg.pop('type')

    ms_cfg = Config(
        dict(
            task=task,
            framework='pytorch',
            model={
                'type': ms_model_name,
                **model_cfg, EASYCV_ARCH: {
                    'type': ori_model_type
                }
            },
            dataset=dict(train=data_cfg.train, val=val_dataset_cfg),
            train=dict(
                work_dir=easycv_cfg.get('work_dir', None),
                max_epochs=easycv_cfg.total_epochs,
                dataloader=dict(
                    batch_size_per_gpu=data_cfg.imgs_per_gpu,
                    workers_per_gpu=data_cfg.workers_per_gpu,
                ),
                optimizer=dict(
                    **easycv_cfg.optimizer, options=optimizer_options),
                lr_scheduler=easycv_cfg.lr_config,
                hooks=hooks),
            evaluation=dict(
                dataloader=dict(
                    batch_size_per_gpu=val_imgs_per_gpu,
                    workers_per_gpu=val_workers_per_gpu,
                ),
                metrics={
                    'type': 'EasyCVMetric',
                    'evaluators': easycv_cfg.eval_pipelines[0].evaluators
                }),
            pipeline=dict(predictor_config=predict_config),
        ))

    if dump:
        with io.open(save_path, 'w') as f:
            res = jsonplus.dumps(
                ms_cfg._cfg_dict.to_dict(), indent=4, sort_keys=False)
            f.write(res)

    return ms_cfg