import os
import tempfile

import mmcv
import pytest

from mmdeploy.utils import get_onnx_config, get_task_type, load_config
from mmdeploy.utils.constants import Task
from mmdeploy.utils.export_info import dump_info

correct_model_path = 'tests/data/srgan.py'
correct_model_cfg = mmcv.Config.fromfile(correct_model_path)
correct_deploy_path = 'tests/data/super-resolution.py'
correct_deploy_cfg = mmcv.Config.fromfile(correct_deploy_path)
# Reserve a unique *.py path; the temporary file itself is deleted when
# the context exits, and the autouse fixture below recreates it empty.
with tempfile.NamedTemporaryFile(suffix='.py') as tmp:
    empty_file_path = tmp.name
# A path that points at no file on disk.
nonexistent_path = './a.py'


@pytest.fixture(autouse=True, scope='module')
def create_empty_file():
    # Touch the file via open() rather than os.mknod, which is not
    # available on every platform, and remove it once the module's
    # tests have finished.
    open(empty_file_path, 'w').close()
    yield
    os.remove(empty_file_path)


def test_load_config_none():
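    """Calling load_config with no arguments should be rejected."""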
    with pytest.raises(AssertionError):
        load_config()


def test_load_config_type_error():
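    """Arguments must be file paths or mmcv.Config objects."""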
    with pytest.raises(TypeError):
        load_config(1)


def test_load_config_file_error():
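    """A path that exists nowhere on disk raises FileNotFoundError."""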
    with pytest.raises(FileNotFoundError):
        load_config(nonexistent_path)


@pytest.mark.parametrize('args', [
    [empty_file_path],
    [correct_model_path],
    [correct_model_cfg],
    (correct_model_path, correct_deploy_path),
    (correct_model_path, correct_deploy_cfg),
    (correct_model_cfg, correct_deploy_cfg),
])
def test_load_config(args):
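    """load_config accepts paths, Config objects, or any mix of the two."""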
    configs = load_config(*args)
    for loaded, arg in zip(configs, args):
        expected = mmcv.Config.fromfile(arg) if isinstance(arg, str) else arg
        assert loaded._cfg_dict == expected._cfg_dict


@pytest.mark.parametrize('deploy_cfg, default',
                         [(empty_file_path, None),
                          (empty_file_path, Task.SUPER_RESOLUTION)])
def test_get_task_type_default(deploy_cfg, default):
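    """An empty config yields the given default, or None when omitted."""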
    if default is None:
        res = get_task_type(deploy_cfg)
    else:
        res = get_task_type(deploy_cfg, default)
    assert res == default


@pytest.mark.parametrize('deploy_cfg, default',
                         [(correct_deploy_path, None),
                          (correct_deploy_path, Task.TEXT_DETECTION),
                          (correct_deploy_cfg, None)])
def test_get_task_type(deploy_cfg, default):
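    """The task in the deploy config takes precedence over any default."""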
    if default is None:
        res = get_task_type(deploy_cfg)
    else:
        res = get_task_type(deploy_cfg, default)
    assert res == Task.SUPER_RESOLUTION


def test_get_onnx_config_error():
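    """An empty deploy config defines no onnx_config section."""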
    with pytest.raises(Exception):
        get_onnx_config(empty_file_path)


@pytest.mark.parametrize('deploy_cfg',
                         [correct_deploy_path, correct_deploy_cfg])
def test_get_onnx_config(deploy_cfg):
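    """The result matches the onnx_config section of the deploy file."""
    # Expected values below mirror tests/data/super-resolution.py.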
    onnx_config = dict(
        dynamic_axes={
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        },
        type='onnx',
        export_params=True,
        keep_initializers_as_inputs=False,
        opset_version=11,
        save_file='end2end.onnx',
        input_names=['input'],
        output_names=['output'],
        input_shape=None)
    res = get_onnx_config(deploy_cfg)
    assert res == onnx_config


def test_AdvancedEnum():
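    """Task values round-trip through the AdvancedEnum.get lookup."""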
    keys = [
        Task.TEXT_DETECTION, Task.TEXT_RECOGNITION, Task.SEGMENTATION,
        Task.SUPER_RESOLUTION, Task.CLASSIFICATION, Task.OBJECT_DETECTION
    ]
    vals = [
        'TextDetection', 'TextRecognition', 'Segmentation', 'SuperResolution',
        'Classification', 'ObjectDetection'
    ]
    for k, v in zip(keys, vals):
        assert Task.get(v, None) == k
        assert k.value == v
    assert Task.get('a', Task.TEXT_DETECTION) == Task.TEXT_DETECTION


def test_export_info():
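    """dump_info should write preprocess.json and deploy_cfg.json."""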
    with tempfile.TemporaryDirectory() as tmp_dir:
        dump_info(correct_deploy_cfg, correct_model_cfg, tmp_dir)
        preprocess_json = os.path.join(tmp_dir, 'preprocess.json')
        deploy_json = os.path.join(tmp_dir, 'deploy_cfg.json')
        assert os.path.exists(preprocess_json)
        assert os.path.exists(deploy_json)