import argparse
import fnmatch
import logging
import os
import os.path as osp
import pickle
from collections import OrderedDict, defaultdict
from datetime import datetime
from pathlib import Path

from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table
from utils import METRICS_MAP, MMCLS_ROOT, substitute_weights

# Avoid importing MMPretrain here to keep the summary command fast

console = Console()
logger = logging.getLogger('test')
logger.setLevel(logging.INFO)  # default WARNING level would hide info logs
logger.addHandler(logging.StreamHandler())
logger.addHandler(logging.FileHandler('benchmark_test.log', mode='w'))
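
# Example invocations (the file name `bench_test.py` is hypothetical; use the
# actual path of this script):
#   python bench_test.py /path/to/checkpoints --models <model-name> --run
#   python bench_test.py /path/to/checkpoints --summary --save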


def parse_args():
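    """Parse the command line arguments."""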
    parser = argparse.ArgumentParser(
        description="Test all models' accuracy in model-index.yml")
    parser.add_argument('checkpoint_root', help='Checkpoint file root path.')
    parser.add_argument(
        '--local',
        action='store_true',
        help='Run on the local machine instead of slurm.')
    parser.add_argument(
        '--models',
        nargs='+',
        type=str,
        help='Specify model names (or name prefixes) to test.')
    parser.add_argument(
        '--run',
        action='store_true',
        help='Run the job scripts directly instead of only previewing them.')
    parser.add_argument(
        '--summary',
        action='store_true',
        help='Summarize benchmark test results.')
    parser.add_argument('--save', action='store_true', help='Save the summary')
    parser.add_argument(
        '--gpus', type=int, default=1, help='How many GPUs to use.')
    parser.add_argument(
        '--no-skip',
        action='store_true',
        help='Do not skip models without results recorded in the metafile.')
    parser.add_argument(
        '--work-dir',
        default='work_dirs/benchmark_test',
        help='The directory to save job scripts and test results.')
    parser.add_argument('--port', type=int, default=29666, help='dist port')
    parser.add_argument(
        '--partition',
        type=str,
        default='mm_model',
        help='(for slurm) Cluster partition to use.')
    parser.add_argument(
        '--job-name',
        type=str,
        default='cls-test-benchmark',
        help='(for slurm) Slurm job name prefix')
    parser.add_argument(
        '--quotatype',
        default=None,
        choices=['reserved', 'auto', 'spot'],
        help='(for slurm) Quota type, only available for phoenix-slurm>=0.2')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        type=str,
        default=[],
        help='Config options for all config files.')

    args = parser.parse_args()
    return args


def create_test_job_batch(commands, model_info, args, port, script_name):
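    """Generate a ``job.sh`` script for one model and append the command used
    to submit (sbatch) or run (bash) it to ``commands``.

    Returns the path of the generated script, or None if the model has no
    weights or its checkpoint cannot be found under ``args.checkpoint_root``.
    """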
    model_name = model_info.name
    config = Path(model_info.config)

    if model_info.weights is not None:
        checkpoint = substitute_weights(model_info.weights,
                                        args.checkpoint_root)
        if checkpoint is None:
            logger.warning(f'{model_name}: {model_info.weights} not found.')
            return None
    else:
        return None

    job_name = f'{args.job_name}_{model_name}'
    work_dir = Path(args.work_dir) / model_name
    work_dir.mkdir(parents=True, exist_ok=True)
    result_file = work_dir / 'result.pkl'

    if args.quotatype is not None:
        quota_cfg = f'#SBATCH --quotatype {args.quotatype}'
    else:
        quota_cfg = ''

    # Choose how to launch tools/test.py: via slurm (srun), via torchrun for
    # local multi-GPU runs, or as a plain Python process.
    if not args.local:
        launcher = 'slurm'
        runner = 'srun python'
    elif args.gpus > 1:
        launcher = 'pytorch'
        runner = ('torchrun --master_addr="127.0.0.1" '
                  f'--master_port={port} --nproc_per_node={args.gpus}')
    else:
        launcher = 'none'
        runner = 'python -u'

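    # The #SBATCH directives below only take effect when the script is
    # submitted through sbatch; when it is run locally with bash they are
    # treated as plain comments.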
    job_script = (f'#!/bin/bash\n'
                  f'#SBATCH --output {work_dir}/job.%j.out\n'
                  f'#SBATCH --partition={args.partition}\n'
                  f'#SBATCH --job-name {job_name}\n'
                  f'#SBATCH --gres=gpu:{min(8, args.gpus)}\n'
                  f'{quota_cfg}\n'
                  f'#SBATCH --ntasks-per-node={min(8, args.gpus)}\n'
                  f'#SBATCH --ntasks={args.gpus}\n'
                  f'#SBATCH --cpus-per-task=5\n\n'
                  f'{runner} {script_name} {config} {checkpoint} '
                  f'--work-dir={work_dir} --cfg-option '
                  f'env_cfg.dist_cfg.port={port} '
                  f'{" ".join(args.cfg_options)} '
                  f'--out={result_file} --out-item="metrics" '
                  f'--launcher={launcher}\n')

    with open(work_dir / 'job.sh', 'w') as f:
        f.write(job_script)

    commands.append(f'echo "{config}"')
    if args.local:
        commands.append(f'bash {work_dir}/job.sh')
    else:
        commands.append(f'sbatch {work_dir}/job.sh')

    return work_dir / 'job.sh'


def test(models, args):
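    """Create a test job script for every selected model and preview the
    generated shell commands; run them only when ``--run`` is given."""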
    script_name = osp.join('tools', 'test.py')
    port = args.port

    commands = []

    preview_script = ''
    for model_info in models.values():

        if model_info.results is None:
            # Skip models without test results (e.g. pre-train only weights)
            continue

        script_path = create_test_job_batch(commands, model_info, args, port,
                                            script_name)
        preview_script = script_path or preview_script
        port += 1

    command_str = '\n'.join(commands)

    preview = Table()
    preview.add_column(str(preview_script))
    preview.add_column('Shell command preview')
    preview.add_row(
        Syntax.from_path(
            preview_script,
            background_color='default',
            line_numbers=True,
            word_wrap=True),
        Syntax(
            command_str,
            'bash',
            background_color='default',
            line_numbers=True,
            word_wrap=True))
    console.print(preview)

    if args.run:
        os.system(command_str)
    else:
        console.print('Please set "--run" to start the job')


def save_summary(summary_data, work_dir):
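    """Save the summary as a CSV file in the work directory, dropping metric
    columns that are completely empty."""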
    summary_path = work_dir / 'test_benchmark_summary.csv'
    columns = defaultdict(list)
    for model_name, summary in summary_data.items():
        if len(summary) == 0:
            # Skip models without results
            continue
        columns['Name'].append(model_name)

        for metric_key in METRICS_MAP:
            if metric_key in summary:
                metric = summary[metric_key]
                expect = round(metric['expect'], 2)
                result = round(metric['result'], 2)
                columns[f'{metric_key} (expect)'].append(str(expect))
                columns[f'{metric_key}'].append(str(result))
            else:
                columns[f'{metric_key} (expect)'].append('')
                columns[f'{metric_key}'].append('')

    columns = {
        field: column
        for field, column in columns.items() if ''.join(column)
    }
    with open(summary_path, 'w') as file:
        file.write(','.join(columns.keys()) + '\n')
        for row in zip(*columns.values()):
            file.write(','.join(row) + '\n')
    logger.info('Summary file saved at ' + str(summary_path))


def show_summary(summary_data):
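    """Print a table comparing the tested metrics against the expected values
    from the metafile (green/red marks differences beyond a 0.01 margin)."""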
    table = Table(title='Test Benchmark Regression Summary')
    table.add_column('Name')
    for metric in METRICS_MAP:
        table.add_column(f'{metric} (expect)')
        table.add_column(f'{metric}')
    table.add_column('Date')

    def set_color(value, expect):
        if value > expect + 0.01:
            return 'green'
        elif value >= expect - 0.01:
            return 'white'
        else:
            return 'red'

    for model_name, summary in summary_data.items():
        row = [model_name]
        for metric_key in METRICS_MAP:
            if metric_key in summary:
                metric = summary[metric_key]
                expect = round(metric['expect'], 2)
                result = round(metric['result'], 2)
                color = set_color(result, expect)
                row.append(f'{expect:.2f}')
                row.append(f'[{color}]{result:.2f}[/{color}]')
            else:
                row.extend([''] * 2)
        if 'date' in summary:
            row.append(summary['date'])
        else:
            row.append('')
        table.add_row(*row)

    # Remove empty columns
    table.columns = [
        column for column in table.columns if ''.join(column._cells)
    ]
    console.print(table)


def summary(models, args):
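    """Collect each model's ``result.pkl`` from the work directory and
    compare the measured metrics with the metafile records."""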
    work_dir = Path(args.work_dir)

    summary_data = {}
    for model_name, model_info in models.items():

        if model_info.results is None and not args.no_skip:
            continue

        # Skip if the result file is not found.
        result_file = work_dir / model_name / 'result.pkl'
        if not result_file.exists():
            summary_data[model_name] = {}
            continue

        with open(result_file, 'rb') as file:
            results = pickle.load(file)
        date = datetime.fromtimestamp(result_file.lstat().st_mtime)

        if model_info.results is not None:
            expect_metrics = model_info.results[0].metrics
        else:
            # With --no-skip, models without metafile results have no
            # expected metrics to compare against.
            expect_metrics = {}

        # extract metrics
        summary = {'date': date.strftime('%Y-%m-%d')}
        for key_yml, key_res in METRICS_MAP.items():
            if key_yml in expect_metrics and key_res in results:
                expect_result = float(expect_metrics[key_yml])
                result = float(results[key_res])
                summary[key_yml] = dict(expect=expect_result, result=result)

        summary_data[model_name] = summary

    show_summary(summary_data)
    if args.save:
        save_summary(summary_data, work_dir)


def main():
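    """Parse model-index.yml, filter the requested models and dispatch to the
    test or summary workflow."""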
    args = parse_args()

    # parse model-index.yml
    model_index_file = MMCLS_ROOT / 'model-index.yml'
    model_index = load(str(model_index_file))
    model_index.build_models_with_collections()
    models = OrderedDict({model.name: model for model in model_index.models})

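    # Filter models by the given name patterns; a trailing '*' is appended so
    # that a name prefix is enough to match.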
    if args.models:
        filter_models = {}
        for pattern in args.models:
            filter_models.update({
                name: models[name]
                for name in fnmatch.filter(models, pattern + '*')
            })
        if len(filter_models) == 0:
            logger.error('No model found, please specify models in:\n' +
                         '\n'.join(models.keys()))
            return
        models = filter_models

    if args.summary:
        summary(models, args)
    else:
        test(models, args)


if __name__ == '__main__':
    main()