mirror of https://github.com/open-mmlab/mmocr.git
116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
import os.path as osp
|
|
|
|
from mmengine.fileio import load
|
|
from tabulate import tabulate
|
|
|
|
|
|
class BaseWeightList:
    """Base class for generating a model list table in markdown format.

    The constructor takes no arguments. Subclasses configure the table by
    overriding the class attributes listed below.

    Attributes:
        base_url (str): URL prefix prepended to each collection's README path.
        table_cfg (dict): Keyword arguments forwarded to ``tabulate``.
        dataset_list (list[str]): List of dataset names, one result column
            per entry.
        table_header (list[str]): List of table header. It is rebuilt in
            ``__init__`` from ``dataset_list`` and ``metric_name``.
        msg (str): Message to be displayed above the table.
        task_abbr (str): Abbreviation of task name. A model is listed only
            when this string occurs in its ``Config`` entry.
        metric_name (str): Metric name looked up in each result's
            ``Metrics`` mapping.
    """

    base_url: str = 'https://github.com/open-mmlab/mmocr/blob/1.x/'
    table_cfg: dict = dict(
        tablefmt='pipe', floatfmt='.2f', numalign='right', stralign='center')
    dataset_list: list
    table_header: list
    msg: str
    task_abbr: str
    metric_name: str

    def __init__(self):
        # One column per dataset, labelled "<dataset> (<metric>)".
        data = (f'{d} ({self.metric_name})' for d in self.dataset_list)
        self.table_header = ['Model', 'README', *data]

    def _get_model_info(self, task_name: str):
        """Yield one table row per model whose config matches ``task_name``.

        The metafiles listed under ``Import`` in ``../../model-index.yml``
        are loaded one by one; the paths are relative to the current working
        directory.

        Args:
            task_name (str): Substring that must occur in a model's
                ``Config`` entry for the model to be listed.

        Yields:
            tuple: ``(name, readme_link, *results)`` where ``results`` holds
            one value per entry of ``dataset_list``.
        """
        meta_indexes = load('../../model-index.yml')
        for meta_path in meta_indexes['Import']:
            metainfo = load(osp.join('../../', meta_path))

            # Map every collection of this metafile to a markdown link
            # pointing at its README.
            collection2md = {}
            for collection in metainfo['Collections']:
                url = self.base_url + collection['README']
                collection2md[collection['Name']] = f'[link]({url})'

            for item in metainfo['Models']:
                if task_name not in item['Config']:
                    continue
                name = f'`{item["Name"]}`'
                aliases = item.get('Alias')
                if aliases:
                    # A single alias may be given as a plain string. Work on
                    # a local value so the loaded metadata is not modified.
                    if isinstance(aliases, str):
                        aliases = [aliases]
                    names = [f'`{alias}`' for alias in aliases]
                    names.append(name)
                    name = ' / '.join(names)
                readme = collection2md[item['In Collection']]
                eval_res = self._get_eval_res(item)
                yield (name, readme, *eval_res)

    def _get_eval_res(self, item):
        """Collect the metric value of ``item`` for every listed dataset.

        Args:
            item (dict): A model entry. Its optional ``Results`` list holds
                dicts with a ``Dataset`` name and a ``Metrics`` mapping.

        Returns:
            Generator yielding one value per entry of ``dataset_list``, in
            that order. ``'-'`` is used when the dataset has no result or
            the result does not contain ``metric_name``.
        """
        eval_res = {k: '-' for k in self.dataset_list}
        for res in item.get('Results') or []:
            dataset = res['Dataset']
            if dataset in self.dataset_list:
                metrics = res.get('Metrics') or {}
                # A result without the requested metric is shown like a
                # missing dataset instead of aborting the whole table.
                eval_res[dataset] = metrics.get(self.metric_name, '-')
        return (eval_res[k] for k in self.dataset_list)

    def gen_model_list(self):
        """Render the model table of this task as a markdown snippet.

        Returns:
            str: ``msg`` followed by a ``{table}`` directive block that
            wraps the table produced by ``tabulate``.
        """
        table = tabulate(
            self._get_model_info(self.task_abbr), self.table_header,
            **self.table_cfg)
        content = f'\n{self.msg}\n'
        content += '```{table}\n:class: model-summary nowrap field-list '
        content += 'table table-hover\n'
        content += table
        content += '\n```\n'
        return content
|
|
|
|
|
|
class TextDetWeightList(BaseWeightList):
    """Model list configuration for the text detection table.

    Only class-level settings are defined here; all behavior comes from
    ``BaseWeightList``.
    """

    # Section heading placed above the table.
    msg = '### Text Detection'
    # Substring used to pick the models belonging to this task.
    task_abbr = 'textdet'
    # Metric reported in every dataset column.
    metric_name = 'hmean-iou'
    # Datasets shown as result columns, in display order.
    dataset_list = ['ICDAR2015', 'CTW1500', 'Totaltext']
|
|
|
|
|
|
class TextRecWeightList(BaseWeightList):
    """Model list configuration for the text recognition table.

    In addition to the per-dataset columns, an ``Avg`` column holds the mean
    of the metric over the listed datasets that the model reports.
    """

    dataset_list = [
        'Avg', 'IIIT5K', 'SVT', 'ICDAR2013', 'ICDAR2015', 'SVTP', 'CT80'
    ]
    msg = ('### Text Recognition\n'
           '```{note}\n'
           'Avg is the average on IIIT5K, SVT, ICDAR2013, ICDAR2015, SVTP,'
           ' CT80.\n```\n')
    task_abbr = 'textrecog'
    metric_name = 'word_acc'

    def _get_eval_res(self, item):
        """Collect per-dataset results of ``item`` and their average.

        Args:
            item (dict): A model entry. Its optional ``Results`` list holds
                dicts with a ``Dataset`` name and a ``Metrics`` mapping.

        Returns:
            Generator yielding one value per entry of ``dataset_list``, in
            that order. ``'-'`` is used for datasets without a value, and
            for ``Avg`` when the model reports none of the listed datasets.
        """
        eval_res = {k: '-' for k in self.dataset_list}
        scores = []
        for res in item.get('Results') or []:
            dataset = res['Dataset']
            # 'Avg' is derived below; a result carrying that name must not
            # be mixed into the mean of the real benchmarks.
            if dataset == 'Avg' or dataset not in self.dataset_list:
                continue
            metric = (res.get('Metrics') or {}).get(self.metric_name)
            if metric is None:
                continue
            eval_res[dataset] = metric
            scores.append(metric)
        # Guard against dividing by zero when nothing matched.
        if scores:
            eval_res['Avg'] = sum(scores) / len(scores)
        return (eval_res[k] for k in self.dataset_list)
|
|
|
|
|
|
class KIEWeightList(BaseWeightList):
    """Model list configuration for the key information extraction table.

    Only class-level settings are defined here; all behavior comes from
    ``BaseWeightList``.
    """

    # Section heading placed above the table.
    msg = '### Key Information Extraction'
    # Substring used to pick the models belonging to this task.
    task_abbr = 'kie'
    # Metric reported in every dataset column.
    metric_name = 'macro_f1'
    # Datasets shown as result columns, in display order.
    dataset_list = ['wildreceipt']
|
|
|
|
|
|
def gen_weight_list():
    """Build the combined markdown model list for all tasks.

    Returns:
        str: The text detection, text recognition and key information
        extraction sections concatenated in that order.
    """
    sections = (TextDetWeightList, TextRecWeightList, KIEWeightList)
    # Each class is instantiated and rendered before the next one, so the
    # order of file loads matches the order of the output sections.
    return ''.join(
        weight_list_cls().gen_model_list() for weight_list_cls in sections)
|