#!/usr/bin/env python

# Copyright (c) OpenMMLab. All rights reserved.
# This tool is used to update model-index.yml, which is required by MIM, and
# is automatically called as a pre-commit hook. The update is triggered when
# any change to model information (.md files in configs/) is detected before
# a commit.
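# Besides running as a hook, the script can be invoked manually; it rewrites
# configs/<method>/metafile.yaml and model-index.yml in place whenever the
# generated content differs from what is currently on disk.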

import os
import os.path as osp
import re
import sys
from typing import List, Optional, Tuple

import yaml

MMSEG_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '..'))


def get_collection_name_list(md_file_list: List[str]) -> List[str]:
    """Get the list of collection names."""
    collection_name_list: List[str] = []
    for md_file in md_file_list:
        with open(md_file) as f:
            lines = f.readlines()
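            # the first line of each README is a level-1 title such as
            # '# <Method Name>', which gives the collection name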
            collection_name = lines[0].split('#')[1].strip()
            collection_name_list.append(collection_name)
    return collection_name_list


def get_md_file_list() -> Tuple[List[str], List[str]]:
    """Get the list of md files."""
    md_file_list: List[str] = []
    md_dir_list: List[str] = []
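    # collect only the first .md file found in each config directory, i.e.
    # one README (and hence one metafile) per directory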
    for root, _, files in os.walk(osp.join(MMSEG_ROOT, 'configs')):
        for file in files:
            if file.endswith('.md'):
                md_file_list.append(osp.join(root, file))
                md_dir_list.append(root)
                break
    return md_file_list, md_dir_list


def get_model_info(md_file: str, config_dir: str,
                   collection_name_list: List[str]
                   ) -> Tuple[Optional[dict], Optional[str]]:
    """Get model information from a README (.md) file.

    Returns ``(None, None)`` when the README is marked with
    ``<!-- [SKIP DEV CHECK] -->``.
    """
    datasets: List[str] = []
    models: List[dict] = []
    current_dataset: str = ''
    paper_name: str = ''
    paper_url: str = ''
    code_url: str = ''
    is_backbone: bool = False
    is_dataset: bool = False
    collection_name: str = ''
    with open(md_file) as f:
        lines: List[str] = f.readlines()
        i: int = 0

        while i < len(lines):
            line: str = lines[i].strip()
            if len(line) == 0:
                i += 1
                continue
            # get paper name and url
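            # expected format: '> [<Paper Title>](<paper url>)'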
            if re.match(r'> \[.*\]+\([a-zA-Z]+://[^\s]*\)', line):
                paper_info = line.split('](')
                paper_name = paper_info[0][paper_info[0].index('[') + 1:]
                paper_url = paper_info[1][:-1]

            # get code info
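            # the code URL is taken from the first double-quoted string on
            # the line, typically the href of the 'Code Snippet' link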
            if 'Code Snippet' in line:
                code_url = line.split('"')[1]

            if line.startswith('<!-- [BACKBONE]'):
                is_backbone = True

            if line.startswith('<!-- [DATASET]'):
                is_dataset = True

            if '<!-- [SKIP DEV CHECK] -->' in line:
                return None, None

            # get dataset names
            if line.startswith('###'):
                current_dataset = line.split('###')[1].strip()
                datasets.append(current_dataset)

            # locate the results table header and resolve column indexes
            if (line[0] == '|' and (i + 1) < len(lines)
                    and lines[i + 1][:3] == '| -' and 'Method' in line
                    and 'Crop Size' in line and 'Mem (GB)' in line):
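                # the header row of a results table is expected to contain at
                # least: Method, Backbone, Crop Size, Mem (GB), Device,
                # mIoU (or mDice), config and download columns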
                keys: List[str] = [key.strip() for key in line.split('|')]
                crop_size_idx: int = keys.index('Crop Size')
                mem_idx: int = keys.index('Mem (GB)')
                assert 'Device' in keys, f'No Device in {md_file}'
                device_idx: int = keys.index('Device')

                if 'mIoU' in keys:
                    ss_idx = keys.index('mIoU')
                elif 'mDice' in keys:
                    ss_idx = keys.index('mDice')
                else:
                    raise ValueError(f'No mIoU or mDice in {md_file}')
                if 'mIoU(ms+flip)' in keys:
                    ms_idx = keys.index('mIoU(ms+flip)')
                elif 'Dice' in keys:
                    ms_idx = keys.index('Dice')
                else:
                    ms_idx = -1
                config_idx = keys.index('config')
                download_idx = keys.index('download')
                j: int = i + 2
                while j < len(lines) and lines[j][0] == '|':
                    values = [value.strip() for value in lines[j].split('|')]
                    # get config name
                    try:
                        config_url = re.findall(r'[a-zA-Z]+://[^\s]*\.py',
                                                values[config_idx])[0]
                        config_name = config_url.split('/')[-1]
                        model_name = config_name.replace('.py', '')
                    except IndexError:
                        raise ValueError(
                            f'config url is not found in {md_file}')

                    # get weight and training log urls
                    try:
                        weight_url = re.findall(r'[a-zA-Z]+://[^\s]*\.pth',
                                                values[download_idx])[0]
                        log_url = re.findall(r'[a-zA-Z]+://[^\s]*\.json',
                                             values[download_idx + 1])[0]
                    except IndexError:
                        raise ValueError(
                            f'url is not found in {values[download_idx]}')

                    # get batch size
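                    # config names encode '<num GPUs>xb<samples per GPU>',
                    # e.g. '4xb4' -> total batch size 4 * 4 = 16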
                    bs = re.findall(r'[0-9]*xb[0-9]*',
                                    config_name)[0].split('xb')
                    batch_size = int(bs[0]) * int(bs[1])

                    # get crop size
                    crop_size = values[crop_size_idx].split('x')
                    crop_size = [int(crop_size[0]), int(crop_size[1])]

                    mem = (values[mem_idx].split('\\')[0]
                           if values[mem_idx] not in ('-', '') else -1)

                    method = values[keys.index('Method')].strip()
                    # split composite method names, e.g. 'A + B' -> ['A', 'B']
                    if ' + ' in method:
                        method = [m.strip() for m in method.split(' + ')]
                    elif ' ' in method:
                        method = method.split(' ')
                    else:
                        method = [method]
                    backbone: str = re.findall(
                        r'[^\s]*', values[keys.index('Backbone')].strip())[0]
                    archs = [backbone] + method
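                    # use the leading method name as the collection name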
                    collection_name = method[0]
                    config_path = osp.join('configs',
                                           config_dir.split('/')[-1],
                                           config_name)
                    model = {
                        'Name': model_name,
                        'In Collection': collection_name,
                        'Results': {
                            'Task': 'Semantic Segmentation',
                            'Dataset': current_dataset,
                            'Metrics': {
                                keys[ss_idx]: float(values[ss_idx])
                            }
                        },
                        'Config': config_path,
                        'Metadata': {
                            'Training Data':
                            current_dataset,
                            'Batch Size':
                            batch_size,
                            'Architecture':
                            archs,
                            'Training Resources':
                            f'{bs[0]}x {values[device_idx]} GPUS',
                        },
                        'Weights': weight_url,
                        'Training log': log_url,
                        'Paper': {
                            'Title': paper_name,
                            'URL': paper_url
                        },
                        'Code': code_url,
                        'Framework': 'PyTorch'
                    }
                    if ms_idx != -1 and values[ms_idx] != '-' and values[
                            ms_idx] != '':
                        model['Results']['Metrics'].update(
                            {keys[ms_idx]: float(values[ms_idx])})
                    if mem != -1:
                        model['Metadata']['Memory (GB)'] = float(mem)
                    models.append(model)
                    j += 1
                i = j
            i += 1

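    # emit a Collection entry when this README describes a method (rather
    # than only a dataset or backbone), or when the collection name has not
    # been seen in any README title yet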
    if not (is_dataset
            or is_backbone) or collection_name not in collection_name_list:
        collection = {
            'Name': collection_name,
            'License': 'Apache License 2.0',
            'Metadata': {
                'Training Data': datasets
            },
            'Paper': {
                'Title': paper_name,
                'URL': paper_url,
            },
            'README': osp.join('configs',
                               config_dir.split('/')[-1], 'README.md'),
            'Frameworks': ['PyTorch'],
        }
        results = {
            'Collections': [collection],
            'Models': models
        }, collection_name
    else:
        results = {'Models': models}, ''

    return results


def dump_yaml_and_check_difference(model_info: dict, filename: str) -> bool:
    """dump yaml file and check difference with the original file.

    Args:
        model_info (dict): model info dict.
        filename (str): filename to save.
    """
    str_dump = yaml.dump(model_info, sort_keys=False)
    if osp.isfile(filename):
        file_exist = True
        with open(filename, encoding='utf-8') as f:
            str_orig = f.read()
    else:
        str_orig = None
        file_exist = False

    if file_exist and str_orig == str_dump:
        is_different = False
    else:
        is_different = True
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(str_dump)

    return is_different


def update_model_index(config_dir_list: List[str]) -> bool:
    """Update model-index.yml with the list of per-config metafiles."""
    yml_files = [
        osp.join(MMSEG_ROOT, 'configs',
                 dir_name.split('/')[-1], 'metafile.yaml')
        for dir_name in config_dir_list
    ]
    yml_files.sort()

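    # model-index.yml only imports the per-config metafiles, e.g.:
    #   Import:
    #     - configs/<method>/metafile.yaml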
    model_index = {
        'Import': [
            osp.relpath(yml_file, MMSEG_ROOT).replace('\\', '/')
            for yml_file in yml_files
        ]
    }
    model_index_file = osp.join(MMSEG_ROOT, 'model-index.yml')
    return dump_yaml_and_check_difference(model_index, model_index_file)


if __name__ == '__main__':
    # get md file list
    md_file_list, config_dir_list = get_md_file_list()
    file_modified = False
    collection_name_list: List[str] = get_collection_name_list(md_file_list)
    # hard-code 'FPN' into the list of known collection names
    collection_name_list.append('FPN')
    remove_config_dir_list = []
    # parse md file
    for md_file, config_dir in zip(md_file_list, config_dir_list):
        results, collection_name = get_model_info(md_file, config_dir,
                                                  collection_name_list)
        if results is None:
            remove_config_dir_list.append(config_dir)
            continue
        filename = osp.join(config_dir, 'metafile.yaml')
        file_modified |= dump_yaml_and_check_difference(results, filename)
        if collection_name != '':
            collection_name_list.append(collection_name)
    # drop config dirs whose README was skipped via SKIP DEV CHECK
    for config_dir in remove_config_dir_list:
        config_dir_list.remove(config_dir)
    file_modified |= update_model_index(config_dir_list)
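    # a non-zero exit status makes the pre-commit run fail, so the
    # regenerated files can be inspected and staged before committing again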
    sys.exit(1 if file_modified else 0)