# flake8: noqa
import argparse
import re
import warnings
from pathlib import Path

from modelindex.load_model_index import load
from modelindex.models.ModelIndex import ModelIndex
from tabulate import tabulate

MMPT_ROOT = Path(__file__).absolute().parents[1]

prog_description = """\
Use metafile to generate a README.md.

Notice that the tool may fail in some corner cases, and you still need to check and fill some contents manually in the generated README.
"""
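
# Illustrative invocations (a sketch; the script name and config paths below
# are hypothetical):
#
#   # Generate a README from a metafile and print it to stdout:
#   python generate_readme.py configs/some_algo/metafile.yml
#
#   # Refresh only the model tables of an existing README and save the result:
#   python generate_readme.py configs/some_algo/metafile.yml \
#       --update configs/some_algo/README.md --update-items models --out README.md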

PREDICT_TEMPLATE = """\
**Predict image**

```python
from mmpretrain import inference_model

predict = inference_model('{model_name}', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
"""

RETRIEVE_TEMPLATE = """\
**Retrieve image**

```python
from mmpretrain import ImageRetrievalInferencer

inferencer = ImageRetrievalInferencer('{model_name}', prototype='demo/')
predict = inferencer('demo/dog.jpg', topk=2)[0]
print(predict[0])
print(predict[1])
```
"""

USAGE_TEMPLATE = """\
**Use the model**

```python
import torch
from mmpretrain import get_model

model = get_model('{model_name}', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
"""

TRAIN_TEST_TEMPLATE = """\
**Train/Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Train:

```shell
python tools/train.py {train_config}
```

Test:

```shell
python tools/test.py {test_config} {test_weights}
```
"""

TEST_ONLY_TEMPLATE = """\
**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py {test_config} {test_weights}
```
"""

METRIC_MAPPING = {
    'Top 1 Accuracy': 'Top-1 (%)',
    'Top 5 Accuracy': 'Top-5 (%)',
}

DATASET_PRIORITY = {
    'ImageNet-1k': 0,
    'CIFAR-10': 10,
    'CIFAR-100': 20,
}
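
# Note: datasets missing from DATASET_PRIORITY fall back to priority 50 in
# add_models() below, so they sort after the entries listed here.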


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('metafile', type=Path, help='The path of metafile')
    parser.add_argument(
        '--table', action='store_true', help='Only generate summary tables')
    parser.add_argument(
        '--update', type=str, help='Update the specified readme file.')
    parser.add_argument('--out', type=str, help='Output to the file.')
    parser.add_argument(
        '--update-items',
        type=str,
        nargs='+',
        default=['models'],
        help='The sections to regenerate when `--update` is specified.')
    args = parser.parse_args()
    return args


def filter_models_by_task(models, task):
    model_list = []
    for model in models:
        if model.results is None and task is None:
            model_list.append(model)
        elif model.results is None:
            continue
        elif model.results[0].task == task or task == 'any':
            model_list.append(model)
    return model_list
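
# Note on the `task` sentinel values above: task=None keeps only models
# without benchmark results (bare pretrained checkpoints), while task='any'
# keeps every model that has results, whatever its task.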


def add_title(metafile: ModelIndex):
    paper = metafile.collections[0].paper
    title = paper['Title']
    url = paper['URL']
    abbr = metafile.collections[0].name
    papertype = metafile.collections[0].data.get('type', 'Algorithm')

    return f'# {abbr}\n> [{title}]({url})\n<!-- [{papertype.upper()}] -->\n'


def add_abstract(metafile: ModelIndex):
    paper = metafile.collections[0].paper
    url = paper['URL']
    if 'arxiv' in url:
        try:
            import arxiv
            search = arxiv.Search(id_list=[url.split('/')[-1]])
            info = next(search.results())
            abstract = info.summary.replace('\n', ' ')
        except ImportError:
            warnings.warn('Install arxiv parser by `pip install arxiv` '
                          'to automatically generate abstract.')
            abstract = None
    else:
        abstract = None

    content = '## Abstract\n'
    if abstract is not None:
        content += f'\n{abstract}\n'
    return content


def add_usage(metafile):
    models = metafile.models
    if len(models) == 0:
        return ''

    content = []
    content.append('## How to use it?\n\n<!-- [TABS-BEGIN] -->\n')

    # Predict image
    cls_models = filter_models_by_task(models, 'Image Classification')
    if cls_models:
        model_name = cls_models[0].name
        content.append(PREDICT_TEMPLATE.format(model_name=model_name))

    # Retrieve image
    retrieval_models = filter_models_by_task(models, 'Image Retrieval')
    if retrieval_models:
        model_name = retrieval_models[0].name
        content.append(RETRIEVE_TEMPLATE.format(model_name=model_name))

    # Use the model
    model_name = models[0].name
    content.append(USAGE_TEMPLATE.format(model_name=model_name))

    # Train/Test Command
    inputs = {}
    train_model = [
        model for model in models
        if 'headless' not in model.name and '3rdparty' not in model.name
    ]
    if train_model:
        template = TRAIN_TEST_TEMPLATE
        inputs['train_config'] = train_model[0].config
    else:
        template = TEST_ONLY_TEMPLATE
    test_model = filter_models_by_task(models, task='any')[0]
    inputs['test_config'] = test_model.config
    inputs['test_weights'] = test_model.weights
    content.append(template.format(**inputs))

    content.append('\n<!-- [TABS-END] -->\n')
    return '\n'.join(content)


def format_pretrain(pretrain_field):
    pretrain_infos = pretrain_field.split('-')[:-1]
    infos = []
    for info in pretrain_infos:
        if re.match(r'^\d+e$', info):
            info = f'{info[:-1]}-Epochs'
        elif re.match(r'^in\d+k$', info):
            info = f'ImageNet-{info[2:-1]}k'
        else:
            info = info.upper()
        infos.append(info)
    return ' '.join(infos)
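
# For illustration: format_pretrain('in21k-pre') -> 'ImageNet-21k', and
# format_pretrain('in1k-300e-pre') -> 'ImageNet-1k 300-Epochs'.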


def generate_model_table(models,
                         folder,
                         with_pretrain=True,
                         with_metric=True,
                         pretrained_models=()):
    """Generate a markdown (pipe) table for a list of models.

    Columns: Model, optional Pretrain, Params (M), Flops (G), optional
    metric columns, Config and Download.
    """
    header = ['Model']
    if with_pretrain:
        header.append('Pretrain')
    header.extend(['Params (M)', 'Flops (G)'])
    if with_metric:
        metrics = set()
        for model in models:
            metrics.update(model.results[0].metrics.keys())
        metrics = sorted(metrics)
        for metric in metrics:
            header.append(METRIC_MAPPING.get(metric, metric))
    header.extend(['Config', 'Download'])

    rows = []
    for model in models:
        model_name = f'`{model.name}`'
        config = (MMPT_ROOT / model.config).relative_to(folder)
        if model.weights is not None:
            download = f'[model]({model.weights})'
        else:
            download = 'N/A'

        if 'Converted From' in model.data:
            model_name += '\\*'
            converted_from = model.data['Converted From']
        elif model.weights is not None:
            log = re.sub(r'\.pth$', '.json', model.weights)
            download += f' \\| [log]({log})'

        row = [model_name]
        if with_pretrain:
            pretrain_field = [
                field for field in model.name.split('_')
                if field.endswith('-pre')
            ]
            if pretrain_field:
                pretrain = format_pretrain(pretrain_field[0])
                upstream = [
                    pretrain_model for pretrain_model in pretrained_models
                    if model.name in pretrain_model.data.get('Downstream', [])
                ]
                if upstream:
                    pretrain = f'[{pretrain}]({upstream[0].weights})'
            else:
                pretrain = 'From scratch'
            row.append(pretrain)

        if model.metadata.parameters is not None:
            row.append(f'{model.metadata.parameters / 1e6:.2f}')  # Params
        else:
            row.append('N/A')
        if model.metadata.flops is not None:
            row.append(f'{model.metadata.flops / 1e9:.2f}')  # Flops
        else:
            row.append('N/A')

        if with_metric:
            for metric in metrics:
                row.append(model.results[0].metrics.get(metric, 'N/A'))
        row.append(f'[config]({config})')
        row.append(download)

        rows.append(row)

    table_cfg = dict(
        tablefmt='pipe',
        floatfmt='.2f',
        colalign=['left'] + ['center'] * (len(header) - 1))
    table_string = tabulate(rows, header, **table_cfg) + '\n'
    if any('Converted From' in model.data for model in models):
        table_string += (
            f"\n*Models with \\* are converted from the [official repo]({converted_from['Code']}). "
            "The config files of these models are only for inference. We haven't reproduced the training results.*\n"
        )

    return table_string
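
# tabulate(..., tablefmt='pipe') emits GitHub-flavored markdown tables; the
# header row is followed by an alignment row, roughly:
#
#   | Model | Params (M) | Flops (G) | Config | Download |
#   |:------|:----------:|:---------:|:------:|:--------:|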


def add_models(metafile):
    models = metafile.models
    if len(models) == 0:
        return ''

    content = ['## Models and results\n']
    algo_folder = Path(metafile.filepath).parent.absolute().resolve()

    # Pretrained models
    pretrain_models = filter_models_by_task(models, task=None)
    if pretrain_models:
        content.append('### Pretrained models\n')
        content.append(
            generate_model_table(
                pretrain_models,
                algo_folder,
                with_pretrain=False,
                with_metric=False))

    # Models with results, grouped by task and dataset
    tasks = [
        'Image Classification',
        'Image Retrieval',
        'Multi-Label Classification',
    ]

    for task in tasks:
        task_models = filter_models_by_task(models, task=task)
        if task_models:
            datasets = {model.results[0].dataset for model in task_models}
            datasets = sorted(
                datasets, key=lambda x: DATASET_PRIORITY.get(x, 50))
            for dataset in datasets:
                content.append(f'### {task} on {dataset}\n')
                dataset_models = [
                    model for model in task_models
                    if model.results[0].dataset == dataset
                ]
                content.append(
                    generate_model_table(
                        dataset_models,
                        algo_folder,
                        pretrained_models=pretrain_models))
    return '\n'.join(content)


def parse_readme(readme):
    with open(readme, 'r') as f:
        file = f.read()

    content = {}

    for img_match in re.finditer(
            '^<div.*\n.*\n</div>\n', file, flags=re.MULTILINE):
        content['image'] = img_match.group()
        start, end = img_match.span()
        file = file[:start] + file[end:]
        break

    sections = re.split('^## ', file, flags=re.MULTILINE)
    for section in sections:
        if section.startswith('# '):
            content['title'] = section.strip() + '\n'
        elif section.startswith('Introduction'):
            content['intro'] = '## ' + section.strip() + '\n'
        elif section.startswith('Abstract'):
            content['abs'] = '## ' + section.strip() + '\n'
        elif section.startswith('How to use it'):
            content['usage'] = '## ' + section.strip() + '\n'
        elif section.startswith('Models and results'):
            content['models'] = '## ' + section.strip() + '\n'
        elif section.startswith('Citation'):
            content['citation'] = '## ' + section.strip() + '\n'
        else:
            section_title = section.split('\n', maxsplit=1)[0]
            content[section_title] = '## ' + section.strip() + '\n'
    return content
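
# The dict returned above maps section keys ('title', 'intro', 'image', 'abs',
# 'usage', 'models', 'citation', plus any custom section titles) to their
# markdown text; combine_readme() below stitches them back together in order.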


def combine_readme(content: dict):
    content = content.copy()
    readme = content.pop('title')
    if 'intro' in content:
        readme += f"\n{content.pop('intro')}"
        readme += f"\n{content.pop('image')}"
        readme += f"\n{content.pop('abs')}"
    else:
        readme += f"\n{content.pop('abs')}"
        readme += f"\n{content.pop('image')}"

    readme += f"\n{content.pop('usage')}"
    readme += f"\n{content.pop('models')}"

    citation = content.pop('citation')
    if content:
        # Custom sections
        for v in content.values():
            readme += f'\n{v}'
    readme += f'\n{citation}'
    return readme


def main():
    args = parse_args()
    metafile = load(str(args.metafile))
    if args.table:
        print(add_models(metafile))
        return

    if args.update is not None:
        content = parse_readme(args.update)
    else:
        content = {}

    if 'title' not in content or 'title' in args.update_items:
        content['title'] = add_title(metafile)
    if 'abs' not in content or 'abs' in args.update_items:
        content['abs'] = add_abstract(metafile)
    if 'image' not in content or 'image' in args.update_items:
        img = '<div align=center>\n<img src="" width="50%"/>\n</div>\n'
        content['image'] = img
    if 'usage' not in content or 'usage' in args.update_items:
        content['usage'] = add_usage(metafile)
    if 'models' not in content or 'models' in args.update_items:
        content['models'] = add_models(metafile)
    if 'citation' not in content:
        content['citation'] = '## Citation\n```bibtex\n```\n'

    content = combine_readme(content)
    if args.out is not None:
        with open(args.out, 'w') as f:
            f.write(content)
    else:
        print(content)


if __name__ == '__main__':
    main()