# flake8: noqa import argparse import re import warnings from pathlib import Path from modelindex.load_model_index import load from modelindex.models.ModelIndex import ModelIndex from tabulate import tabulate MMPT_ROOT = Path(__file__).absolute().parents[1] prog_description = """\ Use metafile to generate a README.md. Notice that the tool may fail in some corner cases, and you still need to check and fill some contents manually in the generated README. """ PREDICT_TEMPLATE = """\ **Predict image** ```python from mmpretrain import inference_model predict = inference_model('{model_name}', 'demo/bird.JPEG') print(predict['pred_class']) print(predict['pred_score']) ``` """ RETRIEVE_TEMPLATE = """\ **Retrieve image** ```python from mmpretrain import ImageRetrievalInferencer inferencer = ImageRetrievalInferencer('{model_name}', prototype='demo/') predict = inferencer('demo/dog.jpg', topk=2)[0] print(predict[0]) print(predict[1]) ``` """ USAGE_TEMPLATE = """\ **Use the model** ```python import torch from mmpretrain import get_model model = get_model('{model_name}', pretrained=True) inputs = torch.rand(1, 3, 224, 224) out = model(inputs) print(type(out)) # To extract features. feats = model.extract_feat(inputs) print(type(feats)) ``` """ TRAIN_TEST_TEMPLATE = """\ **Train/Test Command** Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). Train: ```shell python tools/train.py {train_config} ``` Test: ```shell python tools/test.py {test_config} {test_weights} ``` """ TEST_ONLY_TEMPLATE = """\ **Test Command** Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). Test: ```shell python tools/test.py {test_config} {test_weights} ``` """ METRIC_MAPPING = { 'Top 1 Accuracy': 'Top-1 (%)', 'Top 5 Accuracy': 'Top-5 (%)', } DATASET_PRIORITY = { 'ImageNet-1k': 0, 'CIFAR-10': 10, 'CIFAR-100': 20, } def parse_args(): parser = argparse.ArgumentParser(description=prog_description) parser.add_argument('metafile', type=Path, help='The path of metafile') parser.add_argument( '--table', action='store_true', help='Only generate summary tables') parser.add_argument( '--update', type=str, help='Update the specified readme file.') parser.add_argument('--out', type=str, help='Output to the file.') parser.add_argument( '--update-items', type=str, nargs='+', default=['models'], help='Update the specified readme file.') args = parser.parse_args() return args def filter_models_by_task(models, task): model_list = [] for model in models: if model.results is None and task is None: model_list.append(model) elif model.results is None: continue elif model.results[0].task == task or task == 'any': model_list.append(model) return model_list def add_title(metafile: ModelIndex): paper = metafile.collections[0].paper title = paper['Title'] url = paper['URL'] abbr = metafile.collections[0].name papertype = metafile.collections[0].data.get('type', 'Algorithm') return f'# {abbr}\n> [{title}]({url})\n\n' def add_abstract(metafile: ModelIndex): paper = metafile.collections[0].paper url = paper['URL'] if 'arxiv' in url: try: import arxiv search = arxiv.Search(id_list=[url.split('/')[-1]]) info = next(search.results()) abstract = info.summary.replace('\n', ' ') except ImportError: warnings.warn('Install arxiv parser by `pip install arxiv` ' 'to automatically generate abstract.') abstract = None else: abstract = None content = '## Abstract\n' if abstract is not None: content += f'\n{abstract}\n' return content def add_usage(metafile): models = metafile.models if len(models) == 0: return content = [] content.append('## How to use it?\n\n\n') # Predict image cls_models = filter_models_by_task(models, 'Image Classification') if cls_models: model_name = cls_models[0].name content.append(PREDICT_TEMPLATE.format(model_name=model_name)) # Retrieve image retrieval_models = filter_models_by_task(models, 'Image Retrieval') if retrieval_models: model_name = retrieval_models[0].name content.append(RETRIEVE_TEMPLATE.format(model_name=model_name)) # Use the model model_name = models[0].name content.append(USAGE_TEMPLATE.format(model_name=model_name)) # Train/Test Command inputs = {} train_model = [ model for model in models if 'headless' not in model.name and '3rdparty' not in model.name ] if train_model: template = TRAIN_TEST_TEMPLATE inputs['train_config'] = train_model[0].config else: template = TEST_ONLY_TEMPLATE test_model = filter_models_by_task(models, task='any')[0] inputs['test_config'] = test_model.config inputs['test_weights'] = test_model.weights content.append(template.format(**inputs)) content.append('\n\n') return '\n'.join(content) def format_pretrain(pretrain_field): pretrain_infos = pretrain_field.split('-')[:-1] infos = [] for info in pretrain_infos: if re.match('^\d+e$', info): info = f'{info[:-1]}-Epochs' elif re.match('^in\d+k$', info): info = f'ImageNet-{info[2:-1]}k' else: info = info.upper() infos.append(info) return ' '.join(infos) def generate_model_table(models, folder, with_pretrain=True, with_metric=True, pretrained_models=[]): header = ['Model'] if with_pretrain: header.append('Pretrain') header.extend(['Params (M)', 'Flops (G)']) if with_metric: metrics = set() for model in models: metrics.update(model.results[0].metrics.keys()) metrics = sorted(list(set(metrics))) for metric in metrics: header.append(METRIC_MAPPING.get(metric, metric)) header.extend(['Config', 'Download']) rows = [] for model in models: model_name = f'`{model.name}`' config = (MMPT_ROOT / model.config).relative_to(folder) if model.weights is not None: download = f'[model]({model.weights})' else: download = 'N/A' if 'Converted From' in model.data: model_name += '\*' converted_from = model.data['Converted From'] elif model.weights is not None: log = re.sub(r'.pth$', '.json', model.weights) download += f' \| [log]({log})' row = [model_name] if with_pretrain: pretrain_field = [ field for field in model.name.split('_') if field.endswith('-pre') ] if pretrain_field: pretrain = format_pretrain(pretrain_field[0]) upstream = [ pretrain_model for pretrain_model in pretrained_models if model.name in pretrain_model.data.get('Downstream', []) ] if upstream: pretrain = f'[{pretrain}]({upstream[0].weights})' else: pretrain = 'From scratch' row.append(pretrain) if model.metadata.parameters is not None: row.append(f'{model.metadata.parameters / 1e6:.2f}') # Params else: row.append('N/A') if model.metadata.flops is not None: row.append(f'{model.metadata.flops / 1e9:.2f}') # Params else: row.append('N/A') if with_metric: for metric in metrics: row.append(model.results[0].metrics.get(metric, 'N/A')) row.append(f'[config]({config})') row.append(download) rows.append(row) table_cfg = dict( tablefmt='pipe', floatfmt='.2f', colalign=['left'] + ['center'] * (len(row) - 1)) table_string = tabulate(rows, header, **table_cfg) + '\n' if any('Converted From' in model.data for model in models): table_string += ( f"\n*Models with \* are converted from the [official repo]({converted_from['Code']}). " "The config files of these models are only for inference. We haven't reprodcue the training results.*\n" ) return table_string def add_models(metafile): models = metafile.models if len(models) == 0: return '' content = ['## Models and results\n'] algo_folder = Path(metafile.filepath).parent.absolute().resolve() # Pretrained models pretrain_models = filter_models_by_task(models, task=None) if pretrain_models: content.append('### Pretrained models\n') content.append( generate_model_table( pretrain_models, algo_folder, with_pretrain=False, with_metric=False)) # Classification models tasks = [ 'Image Classification', 'Image Retrieval', 'Multi-Label Classification', ] for task in tasks: task_models = filter_models_by_task(models, task=task) if task_models: datasets = {model.results[0].dataset for model in task_models} datasets = sorted( list(datasets), key=lambda x: DATASET_PRIORITY.get(x, 50)) for dataset in datasets: content.append(f'### {task} on {dataset}\n') dataset_models = [ model for model in task_models if model.results[0].dataset == dataset ] content.append( generate_model_table( dataset_models, algo_folder, pretrained_models=pretrain_models)) return '\n'.join(content) def parse_readme(readme): with open(readme, 'r') as f: file = f.read() content = {} for img_match in re.finditer( '^\n', file, flags=re.MULTILINE): content['image'] = img_match.group() start, end = img_match.span() file = file[:start] + file[end:] break sections = re.split('^## ', file, flags=re.MULTILINE) for section in sections: if section.startswith('# '): content['title'] = section.strip() + '\n' elif section.startswith('Introduction'): content['intro'] = '## ' + section.strip() + '\n' elif section.startswith('Abstract'): content['abs'] = '## ' + section.strip() + '\n' elif section.startswith('How to use it'): content['usage'] = '## ' + section.strip() + '\n' elif section.startswith('Models and results'): content['models'] = '## ' + section.strip() + '\n' elif section.startswith('Citation'): content['citation'] = '## ' + section.strip() + '\n' else: section_title = section.split('\n', maxsplit=1)[0] content[section_title] = '## ' + section.strip() + '\n' return content def combine_readme(content: dict): content = content.copy() readme = content.pop('title') if 'intro' in content: readme += f"\n{content.pop('intro')}" readme += f"\n{content.pop('image')}" readme += f"\n{content.pop('abs')}" else: readme += f"\n{content.pop('abs')}" readme += f"\n{content.pop('image')}" readme += f"\n{content.pop('usage')}" readme += f"\n{content.pop('models')}" citation = content.pop('citation') if content: # Custom sections for v in content.values(): readme += f'\n{v}' readme += f'\n{citation}' return readme def main(): args = parse_args() metafile = load(str(args.metafile)) if args.table: print(add_models(metafile)) return if args.update is not None: content = parse_readme(args.update) else: content = {} if 'title' not in content or 'title' in args.update_items: content['title'] = add_title(metafile) if 'abs' not in content or 'abs' in args.update_items: content['abs'] = add_abstract(metafile) if 'image' not in content or 'image' in args.update_items: img = '
\n\n
\n' content['image'] = img if 'usage' not in content or 'usage' in args.update_items: content['usage'] = add_usage(metafile) if 'models' not in content or 'models' in args.update_items: content['models'] = add_models(metafile) if 'citation' not in content: content['citation'] = '## Citation\n```bibtex\n```\n' content = combine_readme(content) if args.out is not None: with open(args.out, 'w') as f: f.write(content) else: print(content) if __name__ == '__main__': main()