# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter


def override(dl, ks, v):
    """
    Recursively replace a value inside a nested dict/list structure.

    The structure ``dl`` is modified in place.

    Args:
        dl(dict or list): dict or list to be modified
        ks(list): list of keys; when the current level is a list the key
            is converted with ``str2num`` (e.g. '1' -> 1) and used as index
        v(str): value to assign; converted with ``str2num`` before storing

    Raises:
        AssertionError: if ``dl`` is not a list/dict, ``ks`` is empty, a
            final list index is out of range, or an intermediate dict key
            does not exist.
    """

    def str2num(v):
        # NOTE(review): ``eval`` executes arbitrary Python taken from the
        # ``-o key=value`` command-line options; only use with trusted input.
        # ``ast.literal_eval`` would be safer but rejects expressions such as
        # ``1/255`` that ``eval`` currently accepts.
        try:
            return eval(v)
        except Exception:
            # Not a valid Python expression: keep the raw string.
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            if ks[0] not in dl:
                # No module-level ``logger`` exists in this file, so use the
                # stdlib logger instead of failing with a NameError.
                import logging
                logging.getLogger(__name__).warning(
                    'A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            # Only leaf keys may be created; intermediate dicts must exist.
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)


def override_config(config, options=None):
    """
    Recursively override the config with command-line style options.

    Args:
        config(dict): dict to be modified in place
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]
            ``None`` or an empty list leaves ``config`` unchanged.

    Returns:
        config(dict): the same (modified) config object

    Raises:
        AssertionError: if an option is not a str or does not contain
            exactly one ``=``.
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), (
                "option({}) should be a str".format(opt))
            assert "=" in opt, (
                "option({}) should contain a = "
                "to distinguish between key and value".format(opt))
            pair = opt.split('=')
            assert len(pair) == 2, ("there can be only one = in the option")
            key, value = pair
            # Dotted path: each segment is a dict key or a list index.
            keys = key.split('.')
            override(config, keys, value)

    return config


class ArgsParser(ArgumentParser):
    """
    Command-line parser for config-driven runs.

    Options:
        -c/--config: yaml configuration file (required by ``parse_args``)
        -t/--tag: tag for marking worker (default "0")
        -o/--override: repeatable ``key=value`` config overrides
        --style_image, --text_corpus, --language: extra input options
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style image")
        self.add_argument(
            "--text_corpus", default="PaddleOCR", help="text corpus to use")
        self.add_argument(
            "--language", default="en", help="language of the text corpus")

    def parse_args(self, argv=None, namespace=None):
        """Parse ``argv`` and require that ``--config`` was given.

        Raises:
            AssertionError: if ``--config`` is missing.
        """
        args = super(ArgsParser, self).parse_args(argv, namespace)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args


def load_config(file_path):
    """
    Read a configuration from a .yml/.yaml file.

    Args:
        file_path (str): location of the configuration file.
    Returns:
        The object obtained by parsing the file with ``yaml.Loader``.
    Raises:
        AssertionError: if the extension is neither .yml nor .yaml.
    """
    _, suffix = os.path.splitext(file_path)
    assert suffix in ('.yml', '.yaml'), "only support yaml files for now"
    # NOTE(review): ``yaml.Loader`` can construct arbitrary Python objects;
    # only load config files from trusted sources.
    with open(file_path, 'rb') as stream:
        return yaml.load(stream, Loader=yaml.Loader)


def gen_config(save_path="config.yml"):
    """
    Build the default SRNet configuration and dump it as yaml.

    Args:
        save_path (str): destination file; defaults to ``config.yml`` in
            the current working directory (the previous hard-coded path).
    Returns:
        dict: the configuration that was written.
    """
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # Reference discriminator signature (kept from the original notes):
            # input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    # Explicit encoding so the output does not depend on the platform locale.
    with open(save_path, "w", encoding="utf-8") as f:
        yaml.dump(base_config, f)
    return base_config


if __name__ == "__main__":
    # Running this module directly writes the default config via gen_config().
    gen_config()