import copy
import os.path as osp
import platform
import sys
import tempfile
from importlib import import_module

from mmcv import Config, import_modules_from_strings

from easycv.framework.errors import IOError, KeyError, ValueError
from .user_config_params_utils import check_value_type

if platform.system() == 'Windows':
    import regex as re
else:
    import re


def traverse_replace(d, key, value):
    """Recursively overwrite every entry named ``key`` with ``value``, in place.

    Mappings (``dict`` / ``Config``) are scanned entry by entry: a matching
    entry is overwritten and not descended into, any other entry is searched
    recursively. Lists, tuples and sets are only traversed. Anything else is
    ignored.

    Args:
        d: Container (or leaf) to search.
        key: Entry name to look for.
        value: Replacement assigned to each matching entry.
    """
    if isinstance(d, (dict, Config)):
        for name, child in d.items():
            if name != key:
                traverse_replace(child, key, value)
                continue
            d[name] = value
        return

    if isinstance(d, (list, tuple, set)):
        for child in d:
            traverse_replace(child, key, value)


# Config key under which a config file names the base config file(s) it
# extends; it is read and removed by `mmcv_file2dict_base`.
BASE_KEY = '_base_'


# To find base cfg in 'easycv/configs/', base_cfg_name should be 'configs/xx/xx.py'
# TODO: reset the api, keep the same way as mmcv `Config.fromfile`
def check_base_cfg_path(base_cfg_name='configs/base.py', ori_filename=None):
    """Resolve a base config name to the path of an existing file.

    Candidate locations are tried in this order and the first one that
    exists is returned:

    1. ``<package root>/<base_cfg_name>``, where the package root is the
       parent directory of the directory containing this file;
    2. ``<parent of package root>/<base_cfg_name>``;
    3. ``<directory of ori_filename>/<base_cfg_name>`` (only when
       ``ori_filename`` is given).

    Args:
        base_cfg_name: Base config name, e.g. ``'configs/xx/xx.py'``. The
            legacy value ``'../../base.py'`` is mapped to
            ``'configs/base.py'``.
        ori_filename: Path of the config file that references the base
            config. May be ``None``, in which case the third lookup is
            skipped.

    Returns:
        Path of the first existing candidate.

    Raises:
        ValueError: If none of the candidates exists.
    """
    if base_cfg_name == '../../base.py':
        # To be compatible with previous config
        base_cfg_name = 'configs/base.py'

    base_cfg_dir_1 = osp.abspath(osp.dirname(
        osp.dirname(__file__)))  # easycv_package_root_path
    base_cfg_path_1 = osp.join(base_cfg_dir_1, base_cfg_name)
    print('Read base config from', base_cfg_path_1)
    if osp.exists(base_cfg_path_1):
        return base_cfg_path_1

    base_cfg_dir_2 = osp.dirname(base_cfg_dir_1)  # upper level dir
    base_cfg_path_2 = osp.join(base_cfg_dir_2, base_cfg_name)
    print('Read base config from', base_cfg_path_2)
    if osp.exists(base_cfg_path_2):
        return base_cfg_path_2

    # relative to ori_filename; skipped when the caller gave no file, so a
    # missing config ends in the 'not Found' error below instead of a
    # TypeError from osp.dirname(None)
    if ori_filename is not None:
        ori_cfg_dir = osp.dirname(ori_filename)
        base_cfg_path_3 = osp.join(ori_cfg_dir, base_cfg_name)
        base_cfg_path_3 = osp.abspath(osp.expanduser(base_cfg_path_3))
        if osp.exists(base_cfg_path_3):
            return base_cfg_path_3

    raise ValueError('%s not Found' % base_cfg_name)

# Read a single config file without resolving its `_base_` entry
def mmcv_file2dict_raw(ori_filename):
    """Load one config file into a dict, without resolving base configs.

    The file is copied into a temporary directory with mmcv's predefined
    variables substituted, then either imported as a Python module
    (``.py``) or parsed with ``mmcv.load`` (``.yml`` / ``.yaml`` /
    ``.json``).

    Args:
        ori_filename: Path of the config file. When it is not an existing
            file and starts with ``'configs/'``, it is resolved through
            `check_base_cfg_path`.

    Returns:
        tuple: ``(cfg_dict, cfg_text)``. ``cfg_dict`` holds the module-level
        names not starting with ``'__'`` (for ``.py``) or the parsed
        content; ``cfg_text`` is the resolved file path, a newline, and the
        file's text.

    Raises:
        ValueError: If the file cannot be found.
        IOError: If the file extension is not supported.
    """
    filename = osp.abspath(osp.expanduser(ori_filename))
    if not osp.isfile(filename):
        if ori_filename.startswith('configs/'):
            # read configs/config_templates/detection_oss.py
            filename = check_base_cfg_path(ori_filename)
        else:
            raise ValueError('%s and %s not Found' % (ori_filename, filename))

    fileExtname = osp.splitext(filename)[1]
    if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')

    with tempfile.TemporaryDirectory() as temp_config_dir:
        temp_config_file = tempfile.NamedTemporaryFile(
            dir=temp_config_dir, suffix=fileExtname)
        try:
            if platform.system() == 'Windows':
                # release the handle before the path is written to below
                temp_config_file.close()
            temp_config_name = osp.basename(temp_config_file.name)
            Config._substitute_predefined_vars(filename,
                                               temp_config_file.name)
            if filename.endswith('.py'):
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                try:
                    Config._validate_py_syntax(filename)
                    mod = import_module(temp_module_name)
                finally:
                    # Always undo the sys.path / sys.modules changes, even
                    # when validation or the import fails. Remove by value:
                    # the imported config may itself have prepended entries
                    # to sys.path, so index 0 is not guaranteed to be ours.
                    if temp_config_dir in sys.path:
                        sys.path.remove(temp_config_dir)
                    sys.modules.pop(temp_module_name, None)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
            elif filename.endswith(('.yml', '.yaml', '.json')):
                import mmcv
                cfg_dict = mmcv.load(temp_config_file.name)
        finally:
            # close temp file before its directory is cleaned up
            temp_config_file.close()

    cfg_text = filename + '\n'
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        cfg_text += f.read()

    return cfg_dict, cfg_text


# Read config with its `_base_` configs loaded recursively and merged in
def mmcv_file2dict_base(ori_filename):
    """Load a config file and merge it on top of its base configs.

    The file is read with `mmcv_file2dict_raw`. If it has a ``BASE_KEY``
    entry, every listed base is located with `check_base_cfg_path`, loaded
    recursively, and combined into one dict; the file's own settings are
    then merged into that dict with ``Config._merge_a_into_b``.

    Args:
        ori_filename: Path of the config file.

    Returns:
        tuple: ``(cfg_dict, cfg_text)``, where ``cfg_text`` is the text of
        all bases followed by this file's text, joined with newlines.

    Raises:
        KeyError: If two bases define the same top-level key.
    """
    cfg_dict, cfg_text = mmcv_file2dict_raw(ori_filename)
    if BASE_KEY not in cfg_dict:
        return cfg_dict, cfg_text

    bases = cfg_dict.pop(BASE_KEY)
    if not isinstance(bases, list):
        bases = [bases]

    # each entry is the (dict, text) pair of one fully resolved base
    loaded_bases = [
        mmcv_file2dict_base(check_base_cfg_path(base_name, ori_filename))
        for base_name in bases
    ]

    merged_base = {}
    for base_dict, _ in loaded_bases:
        if merged_base.keys() & base_dict.keys():
            raise KeyError('Duplicate key is not allowed among bases')
        merged_base.update(base_dict)

    # settings of the current file override those inherited from the bases
    cfg_dict = Config._merge_a_into_b(cfg_dict, merged_base)

    # merge cfg_text
    texts = [base_text for _, base_text in loaded_bases]
    texts.append(cfg_text)
    return cfg_dict, '\n'.join(texts)


# gen mmcv.Config
def mmcv_config_fromfile(ori_filename):
    """Build an ``mmcv.Config`` from a config file, resolving base configs.

    If the loaded dict has a non-empty ``custom_imports`` entry, it is
    passed as keyword arguments to ``import_modules_from_strings`` before
    the ``Config`` is constructed.

    Args:
        ori_filename: Path of the config file.

    Returns:
        Config: Config built from the merged dict and text, with
        ``filename`` set to ``ori_filename``.
    """
    cfg_dict, cfg_text = mmcv_file2dict_base(ori_filename)

    custom_imports = cfg_dict.get('custom_imports')
    if custom_imports:
        import_modules_from_strings(**custom_imports)

    return Config(cfg_dict, cfg_text=cfg_text, filename=ori_filename)


# get the true value for ori_key in cfg_dict
def get_config_class_value(cfg_dict, ori_key, dict_mem_helper):
    """Look up a dotted key such as ``'a.b.0.c'`` in a nested config.

    Each dot-separated part is used as an integer index when the current
    node is a list or tuple, and as a key otherwise. Results are memoized
    in ``dict_mem_helper`` under ``ori_key``; an existing entry is returned
    without consulting ``cfg_dict``.

    Args:
        cfg_dict: Nested container to search.
        ori_key: Dotted path of the value.
        dict_mem_helper: Cache mapping dotted paths to resolved values;
            updated in place.

    Returns:
        The value found at ``ori_key``.
    """
    if ori_key not in dict_mem_helper:
        node = cfg_dict
        for part in ori_key.split('.'):
            step = int(part) if isinstance(node, (list, tuple)) else part
            node = node[step]
        dict_mem_helper[ori_key] = node
    return dict_mem_helper[ori_key]


def config_dict_edit(ori_cfg_dict, cfg_dict, reg, dict_mem_helper):
    """
    edit ${configs.variables} in config dict to solve dependencies in config

    Walks ``cfg_dict`` recursively through dicts and lists and rewrites, in
    place, every string leaf that contains at least one match of ``reg``.
    Other leaves are left untouched.

    ori_cfg_dict: to find the true value of ${configs.variables}
    cfg_dict: for find leafs of dict by recursive
    reg: Regular expression pattern for find all ${configs.variables} in leafs of dict
    dict_mem_helper: to store the true value of ${configs.variables} which have been found
    """
    if isinstance(cfg_dict, dict):
        for key, value in cfg_dict.items():
            if isinstance(value, str):
                found, new_value = _substitute_config_vars(
                    ori_cfg_dict, value, reg, dict_mem_helper)
                if found:
                    # only the value of an existing key changes, so mutating
                    # while iterating is safe here
                    cfg_dict[key] = new_value
            else:
                # recursively find the leafs of dict
                config_dict_edit(ori_cfg_dict, value, reg, dict_mem_helper)

    elif isinstance(cfg_dict, list):
        for i, value in enumerate(cfg_dict):
            if isinstance(value, (dict, list)):
                # recursively find the leafs of dict / list
                config_dict_edit(ori_cfg_dict, value, reg, dict_mem_helper)
            elif isinstance(value, str):
                found, new_value = _substitute_config_vars(
                    ori_cfg_dict, value, reg, dict_mem_helper)
                if found:
                    cfg_dict[i] = new_value


def _substitute_config_vars(ori_cfg_dict, value, reg, dict_mem_helper):
    """Resolve the ${configs.variables} of one string leaf.

    Returns ``(found, new_value)``. ``found`` is False when ``value``
    contains no match of ``reg``; ``new_value`` is then ``value`` itself.
    Otherwise ``new_value`` is:

    * the referenced value itself when the whole string is one variable;
    * the string with every variable replaced by ``str(true value)`` when
      at least one referenced value is a string, like
      '${data_root_path}/images/';
    * the result of evaluating the replaced string when no referenced
      value is a string, like '-${img_size}//2'.
    """
    var_set = set(reg.findall(value))
    if not var_set:
        return False, value

    n_value = value
    is_arithmetic_exp = True
    for var in var_set:
        # strip the leading '${' and trailing '}' to get the dotted key
        var_value = get_config_class_value(ori_cfg_dict, var[2:-1],
                                           dict_mem_helper)

        # while '${variables}', replace by true value directly
        if var == value:
            return True, var_value

        if isinstance(var_value, str):
            # str expression, like '${data_root_path}/images/'
            is_arithmetic_exp = False
        n_value = n_value.replace(var, str(var_value))

    if is_arithmetic_exp:
        # NOTE(security): eval executes arbitrary Python taken from the
        # config text; only use this with trusted config files and trusted
        # user overrides.
        return True, eval(n_value)
    return True, n_value


def rebuild_config(cfg, user_config_params):
    """
    # rebuild config by user config params,
    modify config by user config params & replace ${configs.variables} by true value

    Each ``--a.b.0.c value`` pair addresses one existing entry of ``cfg``:
    dotted parts are used as integer indexes on lists/tuples and as
    attribute names otherwise. ``value`` is converted with
    ``check_value_type`` against the entry's current value before being
    assigned. ``cfg`` is modified in place.

    cfg: Config to modify
    user_config_params: flat list like ['--key1', 'value1', '--key2', 'value2']

    return: Config

    raises:
        AssertionError: if the list has odd length or a key lacks the '--' prefix
        IndexError: if a list index in a key cannot be used
        KeyError: if a name in a key does not exist or cannot be set
    """
    print(user_config_params)
    assert len(user_config_params
               ) % 2 == 0, 'user_config_params must be setted as --key value'

    new_text = cfg.text + '\n\n#user config params\n'

    for ori_args_k, args_v in zip(user_config_params[0::2],
                                  user_config_params[1::2]):
        assert ori_args_k.startswith('--')
        args_k = ori_args_k[2:]
        t = cfg
        attr_list = args_k.split('.')
        # find the location of key
        for at_k in attr_list[:-1]:
            if isinstance(t, (list, tuple)):
                try:
                    t = t[int(at_k)]
                except Exception as err:
                    raise IndexError('%s is out of index' % at_k) from err
            else:
                try:
                    t = getattr(t, at_k)
                except Exception as err:
                    raise KeyError('%s is not in dict' % at_k) from err

        last_k = attr_list[-1]
        # replace the value of key by user config params
        if isinstance(t, (list, tuple)):  # set value of list[index]
            # the lookup of the current value is guarded as well, so a bad
            # index is reported with the same message as a bad assignment
            try:
                current_v = t[int(last_k)]
            except Exception as err:
                raise IndexError('%s is out of index' % last_k) from err
            # convert str_type to value_type
            formatted_v = check_value_type(args_v, current_v)
            try:
                t[int(last_k)] = formatted_v
            except Exception as err:
                raise IndexError('%s is out of index' % last_k) from err
        else:  # set value of dict[map]
            try:
                current_v = getattr(t, last_k)
            except Exception as err:
                raise KeyError('%s is not in dict' % last_k) from err
            # convert str_type to value_type
            formatted_v = check_value_type(args_v, current_v)
            try:
                setattr(t, last_k, formatted_v)
            except Exception as err:
                raise KeyError('set %s is error' % last_k) from err

        new_text += '%s %s\n' % (ori_args_k, args_v)
    # reg example: '-${a.b.ca}+${a.b.ca}+${d.c}'
    # raw string: '\$', '\{', '\.' and '\}' are not valid str escapes
    reg = re.compile(r'\$\{[a-zA-Z_][a-zA-Z0-9_\.]*\}')
    dict_mem_helper = {}

    # edit ${configs.variables} in config dict to solve dependencies in config
    config_dict_edit(cfg._cfg_dict, cfg._cfg_dict, reg, dict_mem_helper)

    return Config(
        cfg_dict=cfg._cfg_dict, cfg_text=new_text, filename=cfg._filename)


# validate config for export, which may be from training config, disable pretrained
def validate_export_config(cfg):
    """Return a deep copy of ``cfg`` with pretrained settings disabled.

    On the copy, ``model.pretrained`` is reset to ``None`` and
    ``model.backbone.pretrained`` to ``False``; each is touched only when
    it is present and not already ``None``. ``cfg`` itself is left
    unchanged.

    Args:
        cfg: Config object exposing a ``model`` attribute.

    Returns:
        The modified deep copy.
    """
    export_cfg = copy.deepcopy(cfg)
    model = export_cfg.model

    if getattr(model, 'pretrained', None) is not None:
        model.pretrained = None

    backbone = getattr(model, 'backbone', None)
    if backbone is None:
        return export_cfg

    if getattr(backbone, 'pretrained', None) is not None:
        backbone.pretrained = False
    return export_cfg


# Maps an upper-case template name to the path of its config template file.
# All paths start with 'configs/', the prefix that `mmcv_file2dict_raw`
# resolves through `check_base_cfg_path` when the path is not an existing
# file. Entries are grouped by task.
CONFIG_TEMPLATE_ZOO = {

    # detection
    'YOLOX': 'configs/config_templates/yolox.py',
    'YOLOX_ITAG': 'configs/config_templates/yolox_itag.py',

    # cls
    'CLASSIFICATION': 'configs/config_templates/classification.py',
    'CLASSIFICATION_OSS': 'configs/config_templates/classification_oss.py',
    'CLASSIFICATION_TFRECORD_OSS':
    'configs/config_templates/classification_tfrecord_oss.py',

    # metric learning
    'METRICLEARNING_TFRECORD_OSS':
    'configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py',
    'MODELPARALLEL_METRICLEARNING':
    'configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py',

    # ssl
    'MOCO_R50_TFRECORD': 'configs/config_templates/moco_r50_tfrecord.py',
    'MOCO_R50_TFRECORD_OSS':
    'configs/config_templates/moco_r50_tfrecord_oss.py',
    'MOCO_TIMM_TFRECORD': 'configs/config_templates/moco_timm_tfrecord.py',
    'MOCO_TIMM_TFRECORD_OSS':
    'configs/config_templates/moco_timm_tfrecord_oss.py',
    'SWAV_R50_TFRECORD': 'configs/config_templates/swav_r50_tfrecord.py',
    'SWAV_R50_TFRECORD_OSS':
    'configs/config_templates/swav_r50_tfrecord_oss.py',
    'MOBY_TIMM_TFRECORD_OSS':
    'configs/config_templates/moby_timm_tfrecord_oss.py',
    'DINO_TIMM': 'configs/config_templates/dino_timm.py',
    'DINO_TIMM_TFRECORD_OSS':
    'configs/config_templates/dino_timm_tfrecord_oss.py',
    'DINO_R50_TFRECORD_OSS':
    'configs/config_templates/dino_rn50_tfrecord_oss.py',
    'MAE': 'configs/config_templates/mae_vit_base_patch16.py',

    # edge models
    'YOLOX_EDGE': 'configs/config_templates/yolox_edge.py',
    'YOLOX_EDGE_ITAG': 'configs/config_templates/yolox_edge_itag.py',

    # pose
    'TOPDOWN_HRNET': 'configs/config_templates/topdown_hrnet_w48_udp.py',
    'TOPDOWN_LITEHRNET': 'configs/config_templates/topdown_litehrnet_30.py',
}