# flake8: noqa
import argparse
import re
import warnings
from pathlib import Path

from modelindex.load_model_index import load
from modelindex.models.ModelIndex import ModelIndex
from tabulate import tabulate

MMPT_ROOT = Path(__file__).absolute().parents[1]

prog_description = """\
Use metafile to generate a README.md.

Notice that the tool may fail in some corner cases, and you still need to check and fill some contents manually in the generated README.
"""

PREDICT_TEMPLATE = """\
**Predict image**

```python
from mmpretrain import inference_model

predict = inference_model('{model_name}', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
"""

RETRIEVE_TEMPLATE = """\
**Retrieve image**

```python
from mmpretrain import ImageRetrievalInferencer

inferencer = ImageRetrievalInferencer('{model_name}', prototype='demo/')
predict = inferencer('demo/dog.jpg', topk=2)[0]
print(predict[0])
print(predict[1])
```
"""

USAGE_TEMPLATE = """\
**Use the model**

```python
import torch
from mmpretrain import get_model

model = get_model('{model_name}', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
"""

TRAIN_TEST_TEMPLATE = """\
**Train/Test Command**

Prepare your dataset according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset).

Train:

```shell
python tools/train.py {train_config}
```

Test:

```shell
python tools/test.py {test_config} {test_weights}
```
"""

TEST_ONLY_TEMPLATE = """\
**Test Command**

Prepare your dataset according to the [docs](https://mmclassification.readthedocs.io/en/1.x/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py {test_config} {test_weights}
```
"""

METRIC_MAPPING = {
    'Top 1 Accuracy': 'Top-1 (%)',
    'Top 5 Accuracy': 'Top-5 (%)',
}

DATASET_PRIORITY = {
    'ImageNet-1k': 0,
    'CIFAR-10': 10,
    'CIFAR-100': 20,
}


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('metafile', type=Path, help='The path of metafile')
    parser.add_argument(
        '--table', action='store_true', help='Only generate summary tables')
    parser.add_argument(
        '--update', type=str, help='Update the specified readme file.')
    parser.add_argument('--out', type=str, help='Output to the file.')
    parser.add_argument(
        '--update-items',
        type=str,
        nargs='+',
        default=['models'],
        help='Update the specified readme file.')
    args = parser.parse_args()
    return args


def filter_models_by_task(models, task):
    model_list = []
    for model in models:
        if model.results is None and task is None:
            model_list.append(model)
        elif model.results is None:
            continue
        elif model.results[0].task == task or task == 'any':
            model_list.append(model)
    return model_list


def add_title(metafile: ModelIndex):
    paper = metafile.collections[0].paper
    title = paper['Title']
    url = paper['URL']
    abbr = metafile.collections[0].name
    papertype = metafile.collections[0].data.get('type', 'Algorithm')

    return f'# {abbr}\n> [{title}]({url})\n<!-- [{papertype.upper()}] -->\n'


def add_abstract(metafile: ModelIndex):
    paper = metafile.collections[0].paper
    url = paper['URL']
    if 'arxiv' in url:
        try:
            import arxiv
            search = arxiv.Search(id_list=[url.split('/')[-1]])
            info = next(search.results())
            abstract = info.summary.replace('\n', ' ')
        except ImportError:
            warnings.warn('Install arxiv parser by `pip install arxiv` '
                          'to automatically generate abstract.')
            abstract = None
    else:
        abstract = None

    content = '## Abstract\n'
    if abstract is not None:
        content += f'\n{abstract}\n'
    return content


def add_usage(metafile):
    models = metafile.models
    if len(models) == 0:
        return

    content = []
    content.append('## How to use it?\n\n<!-- [TABS-BEGIN] -->\n')

    # Predict image
    cls_models = filter_models_by_task(models, 'Image Classification')
    if cls_models:
        model_name = cls_models[0].name
        content.append(PREDICT_TEMPLATE.format(model_name=model_name))

    # Retrieve image
    retrieval_models = filter_models_by_task(models, 'Image Retrieval')
    if retrieval_models:
        model_name = retrieval_models[0].name
        content.append(RETRIEVE_TEMPLATE.format(model_name=model_name))

    # Use the model
    model_name = models[0].name
    content.append(USAGE_TEMPLATE.format(model_name=model_name))

    # Train/Test Command
    inputs = {}
    train_model = [
        model for model in models
        if 'headless' not in model.name and '3rdparty' not in model.name
    ]
    if train_model:
        template = TRAIN_TEST_TEMPLATE
        inputs['train_config'] = train_model[0].config
    else:
        template = TEST_ONLY_TEMPLATE
    test_model = filter_models_by_task(models, task='any')[0]
    inputs['test_config'] = test_model.config
    inputs['test_weights'] = test_model.weights
    content.append(template.format(**inputs))

    content.append('\n<!-- [TABS-END] -->\n')
    return '\n'.join(content)


def format_pretrain(pretrain_field):
    pretrain_infos = pretrain_field.split('-')[:-1]
    infos = []
    for info in pretrain_infos:
        if re.match('^\d+e$', info):
            info = f'{info[:-1]}-Epochs'
        elif re.match('^in\d+k$', info):
            info = f'ImageNet-{info[2:-1]}k'
        else:
            info = info.upper()
        infos.append(info)
    return ' '.join(infos)


def generate_model_table(models,
                         folder,
                         with_pretrain=True,
                         with_metric=True,
                         pretrained_models=[]):
    header = ['Model']
    if with_pretrain:
        header.append('Pretrain')
    header.extend(['Params (M)', 'Flops (G)'])
    if with_metric:
        metrics = set()
        for model in models:
            metrics.update(model.results[0].metrics.keys())
        metrics = sorted(list(set(metrics)))
        for metric in metrics:
            header.append(METRIC_MAPPING.get(metric, metric))
    header.extend(['Config', 'Download'])

    rows = []
    for model in models:
        model_name = f'`{model.name}`'
        config = (MMPT_ROOT / model.config).relative_to(folder)
        if model.weights is not None:
            download = f'[model]({model.weights})'
        else:
            download = 'N/A'

        if 'Converted From' in model.data:
            model_name += '\*'
            converted_from = model.data['Converted From']
        elif model.weights is not None:
            log = re.sub(r'.pth$', '.json', model.weights)
            download += f' \| [log]({log})'

        row = [model_name]
        if with_pretrain:
            pretrain_field = [
                field for field in model.name.split('_')
                if field.endswith('-pre')
            ]
            if pretrain_field:
                pretrain = format_pretrain(pretrain_field[0])
                upstream = [
                    pretrain_model for pretrain_model in pretrained_models
                    if model.name in pretrain_model.data.get('Downstream', [])
                ]
                if upstream:
                    pretrain = f'[{pretrain}]({upstream[0].weights})'
            else:
                pretrain = 'From scratch'
            row.append(pretrain)

        if model.metadata.parameters is not None:
            row.append(f'{model.metadata.parameters / 1e6:.2f}')  # Params
        else:
            row.append('N/A')
        if model.metadata.flops is not None:
            row.append(f'{model.metadata.flops / 1e9:.2f}')  # Params
        else:
            row.append('N/A')

        if with_metric:
            for metric in metrics:
                row.append(model.results[0].metrics.get(metric, 'N/A'))
        row.append(f'[config]({config})')
        row.append(download)

        rows.append(row)

    table_cfg = dict(
        tablefmt='pipe',
        floatfmt='.2f',
        colalign=['left'] + ['center'] * (len(row) - 1))
    table_string = tabulate(rows, header, **table_cfg) + '\n'
    if any('Converted From' in model.data for model in models):
        table_string += (
            f"\n*Models with \* are converted from the [official repo]({converted_from['Code']}). "
            "The config files of these models are only for inference. We haven't reprodcue the training results.*\n"
        )

    return table_string


def add_models(metafile):
    models = metafile.models
    if len(models) == 0:
        return ''

    content = ['## Models and results\n']
    algo_folder = Path(metafile.filepath).parent.absolute().resolve()

    # Pretrained models
    pretrain_models = filter_models_by_task(models, task=None)
    if pretrain_models:
        content.append('### Pretrained models\n')
        content.append(
            generate_model_table(
                pretrain_models,
                algo_folder,
                with_pretrain=False,
                with_metric=False))

    # Classification models
    tasks = [
        'Image Classification',
        'Image Retrieval',
        'Multi-Label Classification',
    ]

    for task in tasks:
        task_models = filter_models_by_task(models, task=task)
        if task_models:
            datasets = {model.results[0].dataset for model in task_models}
            datasets = sorted(
                list(datasets), key=lambda x: DATASET_PRIORITY.get(x, 50))
            for dataset in datasets:
                content.append(f'### {task} on {dataset}\n')
                dataset_models = [
                    model for model in task_models
                    if model.results[0].dataset == dataset
                ]
                content.append(
                    generate_model_table(
                        dataset_models,
                        algo_folder,
                        pretrained_models=pretrain_models))
    return '\n'.join(content)


def parse_readme(readme):
    with open(readme, 'r') as f:
        file = f.read()

    content = {}

    for img_match in re.finditer(
            '^<div.*\n.*\n</div>\n', file, flags=re.MULTILINE):
        content['image'] = img_match.group()
        start, end = img_match.span()
        file = file[:start] + file[end:]
        break

    sections = re.split('^## ', file, flags=re.MULTILINE)
    for section in sections:
        if section.startswith('# '):
            content['title'] = section.strip() + '\n'
        elif section.startswith('Introduction'):
            content['intro'] = '## ' + section.strip() + '\n'
        elif section.startswith('Abstract'):
            content['abs'] = '## ' + section.strip() + '\n'
        elif section.startswith('How to use it'):
            content['usage'] = '## ' + section.strip() + '\n'
        elif section.startswith('Models and results'):
            content['models'] = '## ' + section.strip() + '\n'
        elif section.startswith('Citation'):
            content['citation'] = '## ' + section.strip() + '\n'
        else:
            section_title = section.split('\n', maxsplit=1)[0]
            content[section_title] = '## ' + section.strip() + '\n'
    return content


def combine_readme(content: dict):
    content = content.copy()
    readme = content.pop('title')
    if 'intro' in content:
        readme += f"\n{content.pop('intro')}"
        readme += f"\n{content.pop('image')}"
        readme += f"\n{content.pop('abs')}"
    else:
        readme += f"\n{content.pop('abs')}"
        readme += f"\n{content.pop('image')}"

    readme += f"\n{content.pop('usage')}"
    readme += f"\n{content.pop('models')}"

    citation = content.pop('citation')
    if content:
        # Custom sections
        for v in content.values():
            readme += f'\n{v}'
    readme += f'\n{citation}'
    return readme


def main():
    args = parse_args()
    metafile = load(str(args.metafile))
    if args.table:
        print(add_models(metafile))
        return

    if args.update is not None:
        content = parse_readme(args.update)
    else:
        content = {}

    if 'title' not in content or 'title' in args.update_items:
        content['title'] = add_title(metafile)
    if 'abs' not in content or 'abs' in args.update_items:
        content['abs'] = add_abstract(metafile)
    if 'image' not in content or 'image' in args.update_items:
        img = '<div align=center>\n<img src="" width="50%"/>\n</div>\n'
        content['image'] = img
    if 'usage' not in content or 'usage' in args.update_items:
        content['usage'] = add_usage(metafile)
    if 'models' not in content or 'models' in args.update_items:
        content['models'] = add_models(metafile)
    if 'citation' not in content:
        content['citation'] = '## Citation\n```bibtex\n```\n'

    content = combine_readme(content)
    if args.out is not None:
        with open(args.out, 'w') as f:
            f.write(content)
    else:
        print(content)


if __name__ == '__main__':
    main()