#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
# This tool is used to update model-index.yml, which is required by MIM, and
# will be automatically called as a pre-commit hook. The update is triggered
# whenever a change to the model information (.md files in configs/) is
# detected before a commit.
import os
import os.path as osp
import re
import sys
from typing import List, Tuple

import yaml

MMSEG_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '..'))
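
# A rough sketch of the README fragment the parsers below assume
# (illustrative; the title, paper link and table cells mirror the checks in
# get_collection_name_list() and get_model_info()):
#
#   # FCN
#
#   > [Fully Convolutional Networks for Semantic Segmentation](https://arxiv.org/abs/1411.4038)
#
#   <!-- [BACKBONE] --> or <!-- [DATASET] -->  (backbone/dataset READMEs only)
#
#   ### Cityscapes
#
#   | Method | Backbone | Crop Size | Mem (GB) | Device | mIoU | config | download |

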
def get_collection_name_list(md_file_list: List[str]) -> List[str]:
    """Get the list of collection names.

    The collection name is taken from the first line of each md file, which
    is assumed to be a first-level title such as `# FCN`.
    """
    collection_name_list: List[str] = []
    for md_file in md_file_list:
        with open(md_file) as f:
            lines = f.readlines()
            collection_name = lines[0].split('#')[1].strip()
            collection_name_list.append(collection_name)
    return collection_name_list


def get_md_file_list() -> Tuple[List[str], List[str]]:
    """Get the lists of md files and their directories under configs/.

    Only the first md file found in each directory is collected, so every
    config directory is assumed to hold a single README.
    """
    md_file_list: List[str] = []
    md_dir_list: List[str] = []
    for root, _, files in os.walk(osp.join(MMSEG_ROOT, 'configs')):
        for file in files:
            if file.endswith('.md'):
                md_file_list.append(osp.join(root, file))
                md_dir_list.append(root)
                break
    return md_file_list, md_dir_list


def get_model_info(md_file: str, config_dir: str,
                   collection_name_list: List[str]) -> Tuple[dict, str]:
    """Parse model information from an md file.

    Args:
        md_file (str): Path of the md file.
        config_dir (str): Path of the config directory containing the file.
        collection_name_list (List[str]): Names of the known collections.

    Returns:
        Tuple[dict, str]: The metafile content and the name of the newly
        created collection ('' if no collection is created).
    """
    datasets: List[str] = []
    models: List[dict] = []
    current_dataset: str = ''
    paper_name: str = ''
    paper_url: str = ''
    code_url: str = ''
    is_backbone: bool = False
    is_dataset: bool = False
    collection_name: str = ''
    with open(md_file) as f:
        lines: List[str] = f.readlines()
        i: int = 0

        while i < len(lines):
            line: str = lines[i].strip()
            if len(line) == 0:
                i += 1
                continue

            # get paper name and url from a line like `> [Title](https://...)`
            if re.match(r'> \[.*\]+\([a-zA-Z]+://[^\s]*\)', line):
                paper_info = line.split('](')
                paper_name = paper_info[0][paper_info[0].index('[') + 1:]
                # drop the trailing ')'
                paper_url = paper_info[1][:-1]

            # get code url from the `Code Snippet` line
            if 'Code Snippet' in line:
                code_url = line.split('"')[1]

            if line.startswith('<!-- [BACKBONE]'):
                is_backbone = True

            if line.startswith('<!-- [DATASET]'):
                is_dataset = True

            # third-level titles (`### <Dataset>`) name the datasets
            if line.startswith('###'):
                current_dataset = line.split('###')[1].strip()
                datasets.append(current_dataset)
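
            # The results table is assumed to look like this (illustrative;
            # only the columns checked below are required):
            #
            #   | Method | Backbone | Crop Size | Mem (GB) | Device | mIoU | mIoU(ms+flip) | config | download |
            #   | ------ | -------- | --------- | -------- | ------ | ---- | ------------- | ------ | -------- |
            #   | FCN    | R-50-D8  | 512x1024  | 5.7      | V100   | 72.25 | 73.36        | [config](...) | [model](...pth) \| [log](...json) |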
            # locate the header row of a results table and record the column
            # index of every field of interest
            if (line[0] == '|' and (i + 1) < len(lines)
                    and lines[i + 1][:3] == '| -' and 'Method' in line
                    and 'Crop Size' in line and 'Mem (GB)' in line):
                keys: List[str] = [key.strip() for key in line.split('|')]
                crop_size_idx: int = keys.index('Crop Size')
                mem_idx: int = keys.index('Mem (GB)')
                assert 'Device' in keys, f'No Device in {md_file}'
                device_idx: int = keys.index('Device')

                # single-scale metric column (required)
                if 'mIoU' in keys:
                    ss_idx = keys.index('mIoU')
                elif 'mDice' in keys:
                    ss_idx = keys.index('mDice')
                else:
                    raise ValueError(f'No mIoU or mDice in {md_file}')
                # multi-scale (ms+flip) metric column (optional)
                if 'mIoU(ms+flip)' in keys:
                    ms_idx = keys.index('mIoU(ms+flip)')
                elif 'Dice' in keys:
                    ms_idx = keys.index('Dice')
                else:
                    ms_idx = -1
                config_idx = keys.index('config')
                download_idx = keys.index('download')
                j: int = i + 2
                while j < len(lines) and lines[j][0] == '|':
                    values = [value.strip() for value in lines[j].split('|')]
                    # get the config url and derive the model name from it
                    try:
                        config_url = re.findall(r'[a-zA-Z]+://[^\s]*py',
                                                values[config_idx])[0]
                        config_name = config_url.split('/')[-1]
                        model_name = config_name.replace('.py', '')
                    except IndexError:
                        raise ValueError(
                            f'config url is not found in {md_file}')

                    # get the checkpoint and training-log urls
                    try:
                        weight_url = re.findall(r'[a-zA-Z]+://[^\s]*pth',
                                                values[download_idx])[0]
                        log_url = re.findall(r'[a-zA-Z]+://[^\s]*\.json',
                                             values[download_idx + 1])[0]
                    except IndexError:
                        raise ValueError(
                            f'url is not found in {values[download_idx]}')
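
                    # The config name is assumed to encode the training batch
                    # size as `<num-gpus>xb<samples-per-gpu>`: e.g. a config
                    # named `fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py`
                    # (illustrative) gives 4 GPUs x 2 samples per GPU = 8.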
                    # get batch size
                    bs = re.findall(r'[0-9]*xb[0-9]*',
                                    config_name)[0].split('xb')
                    batch_size = int(bs[0]) * int(bs[1])

                    # get crop size
                    crop_size = values[crop_size_idx].split('x')
                    crop_size = [int(crop_size[0]), int(crop_size[1])]

                    # the memory cell may be empty or '-'; anything after a
                    # backslash (e.g. an escaped footnote marker) is dropped
                    mem = values[mem_idx].split('\\')[0] if values[
                        mem_idx] != '-' and values[mem_idx] != '' else -1

                    # split the method cell into individual method names
                    method = values[keys.index('Method')].strip()
                    if ' + ' in method:
                        method = [m.strip() for m in method.split(' + ')]
                    elif ' ' in method:
                        method = method.split(' ')
                    else:
                        method = [method]
                    backbone: str = re.findall(
                        r'[^\s]*', values[keys.index('Backbone')].strip())[0]
                    archs = [backbone] + method
                    collection_name = method[0]
                    config_path = osp.join('configs',
                                           config_dir.split('/')[-1],
                                           config_name)
                    model = {
                        'Name': model_name,
                        'In Collection': collection_name,
                        'Results': {
                            'Task': 'Semantic Segmentation',
                            'Dataset': current_dataset,
                            'Metrics': {
                                keys[ss_idx]: float(values[ss_idx])
                            }
                        },
                        'Config': config_path,
                        'Metadata': {
                            'Training Data': current_dataset,
                            'Batch Size': batch_size,
                            'Architecture': archs,
                            'Training Resources':
                            f'{bs[0]}x {values[device_idx]} GPUS',
                        },
                        'Weights': weight_url,
                        'Training log': log_url,
                        'Paper': {
                            'Title': paper_name,
                            'URL': paper_url
                        },
                        'Code': code_url,
                        'Framework': 'PyTorch'
                    }
                    # the multi-scale metric and the memory are optional
                    if (ms_idx != -1 and values[ms_idx] != '-'
                            and values[ms_idx] != ''):
                        model['Results']['Metrics'].update(
                            {keys[ms_idx]: float(values[ms_idx])})
                    if mem != -1:
                        model['Metadata']['Memory (GB)'] = float(mem)
                    models.append(model)
                    j += 1
                i = j
            i += 1
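
    # A model entry above serializes to YAML roughly as follows (values are
    # illustrative, the keys are exactly those set in the dict above):
    #
    #   - Name: fcn_r50-d8_4xb2-40k_cityscapes-512x1024
    #     In Collection: FCN
    #     Results:
    #       Task: Semantic Segmentation
    #       Dataset: Cityscapes
    #       Metrics:
    #         mIoU: 72.25
    #     Config: configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
    #     Weights: https://...pth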

    # a collection entry is created when this README is neither a backbone
    # nor a dataset README, or when its collection has not been seen before
    if not (is_dataset
            or is_backbone) or collection_name not in collection_name_list:
        collection = {
            'Name': collection_name,
            'License': 'Apache License 2.0',
            'Metadata': {
                'Training Data': datasets
            },
            'Paper': {
                'Title': paper_name,
                'URL': paper_url,
            },
            'README': osp.join('configs',
                               config_dir.split('/')[-1], 'README.md'),
            'Frameworks': ['PyTorch'],
        }
        results = {
            'Collections': [collection],
            'Models': models
        }, collection_name
    else:
        results = {'Models': models}, ''

    return results


def dump_yaml_and_check_difference(model_info: dict, filename: str) -> bool:
    """Dump a yaml file and check whether it differs from the original.

    Args:
        model_info (dict): Model info dict.
        filename (str): Filename to save.

    Returns:
        bool: Whether the file was created or its content was changed.
    """
    str_dump = yaml.dump(model_info, sort_keys=False)
    if osp.isfile(filename):
        file_exist = True
        with open(filename, encoding='utf-8') as f:
            str_orig = f.read()
    else:
        str_orig = None
        file_exist = False

    if file_exist and str_orig == str_dump:
        is_different = False
    else:
        is_different = True
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(str_dump)

    return is_different


def update_model_index(config_dir_list: List[str]) -> bool:
    """Update model-index.yml from the per-config metafiles."""
    yml_files = [
        osp.join('configs',
                 dir_name.split('/')[-1], 'metafile.yaml')
        for dir_name in config_dir_list
    ]
    yml_files.sort()

    model_index = {
        'Import': [
            osp.relpath(yml_file, MMSEG_ROOT).replace('\\', '/')
            for yml_file in yml_files
        ]
    }
    model_index_file = osp.join(MMSEG_ROOT, 'model-index.yml')
    return dump_yaml_and_check_difference(model_index, model_index_file)


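# The generated model-index.yml simply imports every per-config metafile,
# e.g. (illustrative):
#
#   Import:
#   - configs/fcn/metafile.yaml
#   - configs/pspnet/metafile.yaml
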
if __name__ == '__main__':
    # get md file list
    md_file_list, config_dir_list = get_md_file_list()
    file_modified = False
    collection_name_list: List[str] = get_collection_name_list(md_file_list)
    # hard code to add 'FPN'
    collection_name_list.append('FPN')
    # parse md files and dump one metafile.yaml per config directory
    for md_file, config_dir in zip(md_file_list, config_dir_list):
        results, collection_name = get_model_info(md_file, config_dir,
                                                  collection_name_list)
        filename = osp.join(config_dir, 'metafile.yaml')
        file_modified |= dump_yaml_and_check_difference(results, filename)
        if collection_name != '':
            collection_name_list.append(collection_name)

    file_modified |= update_model_index(config_dir_list)
    # exit non-zero so the pre-commit hook fails whenever a file was rewritten
    sys.exit(1 if file_modified else 0)