mmsegmentation/.dev/md2yml.py

318 lines
12 KiB
Python

#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
# This tool is used to update model-index.yml which is required by MIM, and
# will be automatically called as a pre-commit hook. The updating will be
# triggered if any change of model information (.md files in configs/) has been
# detected before a commit.
import glob
import os
import os.path as osp
import re
import sys
from lxml import etree
from mmengine.fileio import dump
# Root used to locate `configs/` and `model-index.yml`; obtained by applying
# `dirname` three times to this script's `__file__`.
MMSEG_ROOT = osp.dirname(osp.dirname(osp.dirname(__file__)))

# Names of the collections that are already known. `parse_md` appends the
# name of every non-backbone README it processes and consults this list to
# decide whether a backbone README still needs its own collection entry.
COLLECTIONS = [
    'ANN', 'APCNet', 'BiSeNetV1', 'BiSeNetV2', 'CCNet',
    'CGNet', 'DANet', 'DeepLabV3', 'DeepLabV3+', 'DMNet',
    'DNLNet', 'DPT', 'EMANet', 'EncNet', 'ERFNet',
    'FastFCN', 'FastSCNN', 'FCN', 'GCNet', 'ICNet',
    'ISANet', 'KNet', 'NonLocalNet', 'OCRNet', 'PointRend',
    'PSANet', 'PSPNet', 'Segformer', 'Segmenter', 'FPN',
    'SETR', 'STDC', 'UNet', 'UPerNet',
]

# Results of backbone READMEs whose collection was not in `COLLECTIONS` at
# parse time; they are written out by the script entry point once every
# README has been parsed.
COLLECTIONS_TEMP = []
def dump_yaml_and_check_difference(obj, filename, sort_keys=False):
    """Write ``obj`` to a YAML file and report whether the file changed.

    The object is first serialized to a string. The file is only
    (re)written when it does not exist yet or when its current content
    differs from that string.

    Args:
        obj (any): The python object to be dumped.
        filename (str): YAML filename to dump the object to.
        sort_keys (bool): Forwarded to ``dump``; sort keys by dictionary
            order.

    Returns:
        bool: True if the file was missing or its content differed (and
            was therefore written), False if it was already up to date.
    """
    new_text = dump(obj, None, file_format='yaml', sort_keys=sort_keys)

    # `None` marks "no previous file"; an existing file always yields a str.
    old_text = None
    if osp.isfile(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            old_text = f.read()

    if old_text is not None and old_text == new_text:
        return False

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(new_text)
    return True
def parse_md(md_file):
    """Parse .md file and convert it to a .yml file which can be used for MIM.

    The README is scanned line by line for:

    - the paper link ``[title](url)`` on its third line,
    - ``<a ...>`` anchors whose text is 'Code Snippet' or 'Official Repo',
    - ``### `` headings, which are recorded as dataset names,
    - the ``<!-- [BACKBONE]`` marker,
    - result tables whose header row contains 'Method', 'Crop Size' and
      'Mem (GB)'; every following ``|`` row becomes one model entry.

    Args:
        md_file (str): Path to .md file. The output path is built by
            replacing the last 9 characters of this path (the length of
            ``README.md``) with ``<directory name>.yml``.

    Returns:
        bool: If the target YAML file is different from the original.
            Always False when the README is marked as a backbone whose
            collection name is not in ``COLLECTIONS``; in that case the
            result is queued in ``COLLECTIONS_TEMP`` instead of written.

    Raises:
        AssertionError: If a required link (code snippet, official repo,
            paper) is missing, the code version cannot be extracted from
            the code url, or a crop size is not of the form ``AxB``.
        ValueError: If no table row could be read, so the collection name
            is unknown. Also propagated from ``list.index`` / ``str.index``
            / ``int`` / ``float`` when a table lacks an expected column,
            download link or numeric value.
    """
    readme_dir = osp.dirname(md_file)
    collection_name = osp.split(readme_dir)[1]
    configs = os.listdir(readme_dir)

    collection = dict(
        Name=collection_name,
        Metadata={'Training Data': []},
        Paper={
            'URL': '',
            'Title': ''
        },
        README=md_file,
        Code={
            'URL': '',
            'Version': ''
        })
    collection.update({'Converted From': {'Weights': '', 'Code': ''}})
    models = []
    datasets = []
    paper_url = None
    paper_title = None
    code_url = None
    code_version = None
    repo_url = None
    # Taken from the 'Method' column of the most recently read table row and
    # later used as the collection name; stays None if no row is ever read.
    method = None

    # To avoid re-counting number of backbone model in OpenMMLab,
    # if certain model in configs folder is backbone whose name is already
    # recorded in MMClassification, then the `COLLECTION` dict of this model
    # in MMSegmentation should be deleted, and `In Collection` in `Models`
    # should be set with head or neck of this config file.
    is_backbone = None

    # Only the raw lines are needed, so the file is closed before parsing.
    with open(md_file, 'r', encoding='UTF-8') as md:
        lines = md.readlines()

    i = 0
    current_dataset = ''
    while i < len(lines):
        line = lines[i].strip()
        # In latest README.md the title and url are in the third line.
        if i == 2:
            paper_url = lines[i].split('](')[1].split(')')[0]
            paper_title = lines[i].split('](')[0].split('[')[1]
        if len(line) == 0:
            i += 1
            continue
        elif line[:3] == '<a ':
            content = etree.HTML(line)
            node = content.xpath('//a')[0]
            if node.text == 'Code Snippet':
                code_url = node.get('href', None)
                assert code_url is not None, (
                    f'{collection_name} hasn\'t code snippet url.')
                # version extraction
                filter_str = r'blob/(.*)/mm'
                pattern = re.compile(filter_str)
                code_version = pattern.findall(code_url)
                assert len(code_version) == 1, (
                    f'false regular expression ({filter_str}) use.')
                code_version = code_version[0]
            elif node.text == 'Official Repo':
                repo_url = node.get('href', None)
                assert repo_url is not None, (
                    f'{collection_name} hasn\'t official repo url.')
            i += 1
        elif line[:4] == '### ':
            datasets.append(line[4:])
            current_dataset = line[4:]
            # The line right after a dataset heading is skipped as well.
            i += 2
        elif line[:15] == '<!-- [BACKBONE]':
            is_backbone = True
            i += 1
        elif (line[0] == '|' and (i + 1) < len(lines)
              and lines[i + 1][:3] == '| -' and 'Method' in line
              and 'Crop Size' in line and 'Mem (GB)' in line):
            # Header row of a result table: map column titles to indices.
            cols = [col.strip() for col in line.split('|')]
            method_id = cols.index('Method')
            backbone_id = cols.index('Backbone')
            crop_size_id = cols.index('Crop Size')
            lr_schd_id = cols.index('Lr schd')
            mem_id = cols.index('Mem (GB)')
            fps_id = cols.index('Inf time (fps)')
            # The single-scale metric column is either 'mIoU' or 'Dice'.
            if 'mIoU' in cols:
                ss_id = cols.index('mIoU')
            else:
                ss_id = cols.index('Dice')
            # The multi-scale column is optional; None marks its absence.
            if 'mIoU(ms+flip)' in cols:
                ms_id = cols.index('mIoU(ms+flip)')
            else:
                ms_id = None
            config_id = cols.index('config')
            download_id = cols.index('download')

            # Data rows start two lines below the header (after '| -').
            j = i + 2
            while j < len(lines) and lines[j][0] == '|':
                els = [el.strip() for el in lines[j].split('|')]
                config = ''
                model_name = ''
                weight = ''
                # Match the row against the files in the README's directory
                # by substring; the last matching file name wins.
                for fn in configs:
                    if fn in els[config_id]:
                        left = els[download_id].index(
                            'https://download.openmmlab.com')
                        right = els[download_id].index('.pth') + 4
                        weight = els[download_id][left:right]
                        config = f'configs/{collection_name}/{fn}'
                        model_name = fn[:-3]

                # '-' or an empty cell means "value not reported".
                fps_text = els[fps_id]
                has_fps = fps_text != '-' and fps_text != ''
                mem_text = els[mem_id]
                has_mem = mem_text != '-' and mem_text != ''

                crop_size = els[crop_size_id].split('x')
                assert len(crop_size) == 2
                method = els[method_id].split()[0].split('-')[-1]
                model = {
                    'Name':
                    model_name,
                    'In Collection':
                    method,
                    'Metadata': {
                        'backbone': els[backbone_id],
                        'crop size': f'({crop_size[0]},{crop_size[1]})',
                        'lr schd': int(els[lr_schd_id]),
                    },
                    'Results': [
                        {
                            'Task': 'Semantic Segmentation',
                            'Dataset': current_dataset,
                            'Metrics': {
                                cols[ss_id]: float(els[ss_id]),
                            },
                        },
                    ],
                    'Config':
                    config,
                    'Weights':
                    weight,
                }
                if has_fps:
                    try:
                        fps = float(fps_text)
                    except ValueError:
                        # A non-numeric fps cell drops the whole row; the
                        # model is not added to `models`.
                        j += 1
                        continue
                    model['Metadata']['inference time (ms/im)'] = [{
                        'value':
                        round(1000 / fps, 2),
                        'hardware':
                        'V100',
                        'backend':
                        'PyTorch',
                        'batch size':
                        1,
                        'mode':
                        'FP32' if 'amp' not in config else 'AMP',
                        'resolution':
                        f'({crop_size[0]},{crop_size[1]})'
                    }]
                if has_mem:
                    # Only the part before the first backslash is numeric.
                    model['Metadata']['Training Memory (GB)'] = float(
                        mem_text.split('\\')[0])
                # Only have semantic segmentation now
                if (ms_id is not None and els[ms_id] != '-'
                        and els[ms_id] != ''):
                    model['Results'][0]['Metrics'][
                        'mIoU(ms+flip)'] = float(els[ms_id])
                models.append(model)
                j += 1
            i = j
        else:
            i += 1

    flag = (code_url is not None) and (paper_url is not None) and (repo_url
                                                                   is not None)
    assert flag, f'{collection_name} readme error'
    if method is None:
        # Without any table row there is no 'Method' value to name the
        # collection after; fail with a clear message instead of NameError.
        raise ValueError(
            f'{collection_name} readme has no parsable model table row.')
    collection['Name'] = method
    collection['Metadata']['Training Data'] = datasets
    collection['Code']['URL'] = code_url
    collection['Code']['Version'] = code_version
    collection['Paper']['URL'] = paper_url
    collection['Paper']['Title'] = paper_title
    collection['Converted From']['Code'] = repo_url
    # ['Converted From']['Weights] miss
    # remove empty attribute: drop every '' field, and drop the whole
    # section once it has no fields left.
    for check_key in ('Code', 'Paper', 'Converted From'):
        section = collection[check_key]
        for key in [k for k, v in section.items() if v == '']:
            section.pop(key)
        if not section:
            collection.pop(check_key)

    yml_file = f'{md_file[:-9]}{collection_name}.yml'
    if is_backbone:
        if collection['Name'] not in COLLECTIONS:
            # Defer: whether this backbone keeps its own collection entry is
            # decided by the caller after all READMEs have been parsed.
            result = {
                'Collections': [collection],
                'Models': models,
                'Yml': yml_file
            }
            COLLECTIONS_TEMP.append(result)
            return False
        else:
            result = {'Models': models}
    else:
        COLLECTIONS.append(collection['Name'])
        result = {'Collections': [collection], 'Models': models}
    return dump_yaml_and_check_difference(result, yml_file)
def update_model_index():
    """Regenerate model-index.yml from the .yml files found under configs/.

    Every ``*.yml`` file below ``<MMSEG_ROOT>/configs`` (searched
    recursively) is listed, in sorted order and relative to ``MMSEG_ROOT``,
    under the ``Import`` key of ``<MMSEG_ROOT>/model-index.yml``.

    Returns:
        bool: If the updated model-index.yml is different from the original.
    """
    search_pattern = osp.join(MMSEG_ROOT, 'configs', '**', '*.yml')

    import_entries = []
    for yml_path in sorted(glob.glob(search_pattern, recursive=True)):
        relative = osp.relpath(yml_path, MMSEG_ROOT)
        # add .replace('\\', '/') to avoid Windows Style path
        import_entries.append(relative.replace('\\', '/'))

    index_path = osp.join(MMSEG_ROOT, 'model-index.yml')
    return dump_yaml_and_check_difference({'Import': import_entries},
                                          index_path)
if __name__ == '__main__':
    # Only arguments whose base name is README.md are processed.
    readme_paths = [
        path for path in sys.argv[1:] if osp.basename(path) == 'README.md'
    ]
    if not readme_paths:
        sys.exit(0)

    changed = False
    for readme_path in readme_paths:
        if parse_md(readme_path):
            changed = True

    # Backbone results deferred by `parse_md`: now that every README has been
    # seen, drop the collection entry if its name is in COLLECTIONS, then
    # write the metafile.
    for deferred in COLLECTIONS_TEMP:
        target_yml = deferred.pop('Yml', None)
        deferred_name = deferred['Collections'][0]['Name']
        if deferred_name in COLLECTIONS:
            deferred.pop('Collections')
        if dump_yaml_and_check_difference(deferred, target_yml):
            changed = True

    if update_model_index():
        changed = True

    # Exit status 1 signals that at least one file was (re)written.
    sys.exit(1 if changed else 0)