#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
# This tool is used to update model-index.yml which is required by MIM, and
# will be automatically called as a pre-commit hook. The update is triggered
# when any change to model information (.md files in configs/) is detected
# before a commit.
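#
# A minimal sketch of a pre-commit entry that could invoke this script
# (hypothetical .pre-commit-config.yaml snippet; the hook id, entry and file
# pattern used in the actual repository may differ):
#
#   - repo: local
#     hooks:
#       - id: update-model-index
#         name: update-model-index
#         entry: python .dev/update_model_index.py
#         language: python
#         additional_dependencies: [pyyaml]
#         files: ^configs/.*\.md$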
import os
import os.path as osp
import re
import sys
from typing import List, Tuple

import yaml

MMSEG_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '..'))


def get_collection_name_list(md_file_list: List[str]) -> List[str]:
    """Get the list of collection names.

    The collection name is taken from the level-1 heading on the first line
    of each md file.

    Args:
        md_file_list (List[str]): Paths of the README md files.

    Returns:
        List[str]: Collection names, one per md file.
    """
    collection_name_list: List[str] = []
    for md_file in md_file_list:
        with open(md_file) as f:
            lines = f.readlines()
            collection_name = lines[0].split('#')[1].strip()
            collection_name_list.append(collection_name)
    return collection_name_list
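
# Illustration (hypothetical input): if configs/ann/README.md starts with
# the line '# ANN', then
# get_collection_name_list(['configs/ann/README.md']) returns ['ANN'].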


def get_md_file_list() -> Tuple[List[str], List[str]]:
    """Get the list of md files under configs/.

    Returns:
        Tuple[List[str], List[str]]: The md file paths and the directories
        that contain them. Only the first md file found in each directory
        is collected.
    """
    md_file_list: List[str] = []
    md_dir_list: List[str] = []
    for root, _, files in os.walk(osp.join(MMSEG_ROOT, 'configs')):
        for file in files:
            if file.endswith('.md'):
                md_file_list.append(osp.join(root, file))
                md_dir_list.append(root)
                break
    return md_file_list, md_dir_list
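
# Illustration (hypothetical layout): with configs/ann/README.md and
# configs/apcnet/README.md on disk, this returns the two md paths together
# with ['.../configs/ann', '.../configs/apcnet'].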


def get_model_info(md_file: str, config_dir: str,
                   collection_name_list: List[str]) -> Tuple[dict, str]:
    """Parse model information from an md file.

    Args:
        md_file (str): Path of the README md file.
        config_dir (str): Config directory that contains the md file.
        collection_name_list (List[str]): Names of the collections that
            have been recorded so far.

    Returns:
        Tuple[dict, str]: The metafile content and the name of the newly
        created collection ('' if no new collection is created).
    """
    datasets: List[str] = []
    models: List[dict] = []
    current_dataset: str = ''
    paper_name: str = ''
    paper_url: str = ''
    code_url: str = ''
    is_backbone: bool = False
    is_dataset: bool = False
    collection_name: str = ''
    with open(md_file) as f:
        lines: List[str] = f.readlines()
        i: int = 0
        while i < len(lines):
            line: str = lines[i].strip()
            if len(line) == 0:
                i += 1
                continue
            # get paper name and url
            if re.match(r'> \[.*\]+\([a-zA-Z]+://[^\s]*\)', line):
                paper_info = line.split('](')
                paper_name = paper_info[0][paper_info[0].index('[') + 1:]
                paper_url = paper_info[1][:len(paper_info[1]) - 1]
            # get code url from the 'Code Snippet' link, which is quoted
            # inside an HTML comment
            if 'Code Snippet' in line:
                code_url = line.split('"')[1]
            if line.startswith('<!-- [BACKBONE]'):
                is_backbone = True
            if line.startswith('<!-- [DATASET]'):
                is_dataset = True
            # get dataset names from the '###' headings
            if line.startswith('###'):
                current_dataset = line.split('###')[1].strip()
                datasets.append(current_dataset)
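            # Illustration (hypothetical md content): a heading line
            # '### Cityscapes' sets current_dataset to 'Cityscapes'; every
            # table row below is attributed to that dataset until the next
            # '###' heading.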
            # detect the header row of a results table and resolve the
            # column indices used to parse the rows below
            if (line[0] == '|' and (i + 1) < len(lines)
                    and lines[i + 1][:3] == '| -' and 'Method' in line
                    and 'Crop Size' in line and 'Mem (GB)' in line):
                keys: List[str] = [key.strip() for key in line.split('|')]
                crop_size_idx: int = keys.index('Crop Size')
                mem_idx: int = keys.index('Mem (GB)')
                assert 'Device' in keys, f'No Device in {md_file}'
                device_idx: int = keys.index('Device')
                # single-scale metric column: mIoU or mDice
                if 'mIoU' in keys:
                    ss_idx = keys.index('mIoU')
                elif 'mDice' in keys:
                    ss_idx = keys.index('mDice')
                else:
                    raise ValueError(f'No mIoU or mDice in {md_file}')
                # multi-scale + flip metric column, if present
                if 'mIoU(ms+flip)' in keys:
                    ms_idx = keys.index('mIoU(ms+flip)')
                elif 'Dice' in keys:
                    ms_idx = keys.index('Dice')
                else:
                    ms_idx = -1
                config_idx = keys.index('config')
                download_idx = keys.index('download')
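                # Illustration (hypothetical header row):
                # | Method | Backbone | Crop Size | ... | Mem (GB) | Device
                # | mIoU | mIoU(ms+flip) | config | download |
                # keys holds the stripped cell texts and the *_idx variables
                # point at the matching columns.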
                # walk the table rows until a non-table line is reached
                j: int = i + 2
                while j < len(lines) and lines[j][0] == '|':
                    values = [value.strip() for value in lines[j].split('|')]
                    # get config name
                    try:
                        config_url = re.findall(r'[a-zA-Z]+://[^\s]*py',
                                                values[config_idx])[0]
                        config_name = config_url.split('/')[-1]
                        model_name = config_name.replace('.py', '')
                    except IndexError:
                        raise ValueError(
                            f'config url is not found in {md_file}')
                    # get weight and training log urls
                    try:
                        weight_url = re.findall(r'[a-zA-Z]+://[^\s]*pth',
                                                values[download_idx])[0]
                        log_url = re.findall(r'[a-zA-Z]+://[^\s]*.json',
                                             values[download_idx + 1])[0]
                    except IndexError:
                        raise ValueError(
                            f'url is not found in {values[download_idx]}')
                    # get batch size from the 'NxbM' part of the config name
                    # (N GPUs, M samples per GPU)
                    bs = re.findall(r'[0-9]*xb[0-9]*',
                                    config_name)[0].split('xb')
                    batch_size = int(bs[0]) * int(bs[1])
                    # get crop size
                    crop_size = values[crop_size_idx].split('x')
                    crop_size = [int(crop_size[0]), int(crop_size[1])]
                    # the memory column may be '-' or empty when unknown
                    mem = values[mem_idx].split('\\')[0] if values[
                        mem_idx] != '-' and values[mem_idx] != '' else -1
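                    # Illustration (hypothetical config name):
                    # 'ann_r50-d8_4xb2-40k_cityscapes-512x1024.py' yields
                    # bs = ['4', '2'] and batch_size = 8, while a Crop Size
                    # cell of '512x1024' yields crop_size = [512, 1024].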
                    method = values[keys.index('Method')].strip()
                    # split a composite method name into its components
                    if ' + ' in method:
                        method = [m.strip() for m in method.split(' + ')]
                    elif ' ' in method:
                        method = method.split(' ')
                    else:
                        method = [method]
                    backbone: str = re.findall(
                        r'[^\s]*', values[keys.index('Backbone')].strip())[0]
                    archs = [backbone] + method
                    collection_name = method[0]
                    config_path = osp.join('configs',
                                           config_dir.split('/')[-1],
                                           config_name)
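                    # Illustration (hypothetical cells): Method 'FCN' with
                    # Backbone 'R-50-D8' gives archs = ['R-50-D8', 'FCN']
                    # and collection_name = 'FCN'; a composite method such
                    # as 'X + Y' is split on ' + ' into ['X', 'Y'].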
                    model = {
                        'Name': model_name,
                        'In Collection': collection_name,
                        'Results': {
                            'Task': 'Semantic Segmentation',
                            'Dataset': current_dataset,
                            'Metrics': {
                                keys[ss_idx]: float(values[ss_idx])
                            }
                        },
                        'Config': config_path,
                        'Metadata': {
                            'Training Data': current_dataset,
                            'Batch Size': batch_size,
                            'Architecture': archs,
                            'Training Resources':
                            f'{bs[0]}x {values[device_idx]} GPUS',
                        },
                        'Weights': weight_url,
                        'Training log': log_url,
                        'Paper': {
                            'Title': paper_name,
                            'URL': paper_url
                        },
                        'Code': code_url,
                        'Framework': 'PyTorch'
                    }
                    # the multi-scale metric and memory are optional
                    if ms_idx != -1 and values[ms_idx] != '-' and values[
                            ms_idx] != '':
                        model['Results']['Metrics'].update(
                            {keys[ms_idx]: float(values[ms_idx])})
                    if mem != -1:
                        model['Metadata']['Memory (GB)'] = float(mem)
                    models.append(model)
                    j += 1
                i = j
            i += 1
    # create a new collection unless this md file only documents a backbone
    # or dataset whose collection has already been recorded
    if not (is_dataset
            or is_backbone) or collection_name not in collection_name_list:
        collection = {
            'Name': collection_name,
            'License': 'Apache License 2.0',
            'Metadata': {
                'Training Data': datasets
            },
            'Paper': {
                'Title': paper_name,
                'URL': paper_url,
            },
            'README': osp.join('configs',
                               config_dir.split('/')[-1], 'README.md'),
            'Frameworks': ['PyTorch'],
        }
        results = {
            'Collections': [collection],
            'Models': models
        }, collection_name
    else:
        results = {'Models': models}, ''
    return results
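
# Illustration (abridged): for a new method, get_model_info returns
# ({'Collections': [{'Name': 'ANN', ...}], 'Models': [...]}, 'ANN'); for an
# md file whose collection is already recorded, it returns
# ({'Models': [...]}, '').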


def dump_yaml_and_check_difference(model_info: dict, filename: str) -> bool:
    """Dump model info to a yaml file and check whether it changed.

    Args:
        model_info (dict): Model info dict.
        filename (str): Filename to save.

    Returns:
        bool: Whether the target file is modified.
    """
    str_dump = yaml.dump(model_info, sort_keys=False)
    if osp.isfile(filename):
        file_exist = True
        with open(filename, encoding='utf-8') as f:
            str_orig = f.read()
    else:
        str_orig = None
        file_exist = False
    if file_exist and str_orig == str_dump:
        is_different = False
    else:
        is_different = True
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(str_dump)
    return is_different
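
# Illustration: dump_yaml_and_check_difference(results, 'metafile.yaml')
# returns True (and rewrites the file) only when the dumped yaml differs
# from what is already on disk; a missing file counts as different.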


def update_model_index(config_dir_list: List[str]) -> bool:
    """Update model-index.yml with the metafiles of all config dirs.

    Args:
        config_dir_list (List[str]): List of config directories.

    Returns:
        bool: Whether model-index.yml is modified.
    """
    yml_files = [
        osp.join('configs',
                 dir_name.split('/')[-1], 'metafile.yaml')
        for dir_name in config_dir_list
    ]
    yml_files.sort()
    model_index = {
        'Import': [
            osp.relpath(yml_file, MMSEG_ROOT).replace('\\', '/')
            for yml_file in yml_files
        ]
    }
    model_index_file = osp.join(MMSEG_ROOT, 'model-index.yml')
    return dump_yaml_and_check_difference(model_index, model_index_file)
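
# Illustration (abridged model-index.yml produced by the call above):
#
#   Import:
#   - configs/ann/metafile.yaml
#   - configs/apcnet/metafile.yaml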


if __name__ == '__main__':
    # get md file list
    md_file_list, config_dir_list = get_md_file_list()
    file_modified = False
    collection_name_list: List[str] = get_collection_name_list(md_file_list)
    # hard code to add 'FPN'
    collection_name_list.append('FPN')
    # parse md files and regenerate the per-config metafiles
    for md_file, config_dir in zip(md_file_list, config_dir_list):
        results, collection_name = get_model_info(md_file, config_dir,
                                                  collection_name_list)
        filename = osp.join(config_dir, 'metafile.yaml')
        file_modified |= dump_yaml_and_check_difference(results, filename)
        if collection_name != '':
            collection_name_list.append(collection_name)
    file_modified |= update_model_index(config_dir_list)
    sys.exit(1 if file_modified else 0)
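# Exiting with code 1 whenever a file was regenerated makes the pre-commit
# hook fail, prompting the committer to stage the refreshed metafile.yaml
# files and model-index.yml before committing again.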