470 lines
16 KiB
Python
470 lines
16 KiB
Python
import argparse
|
|
import copy
|
|
import re
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from prompt_toolkit import ANSI
|
|
from prompt_toolkit import prompt as _prompt
|
|
from prompt_toolkit.completion import (FuzzyCompleter, FuzzyWordCompleter,
|
|
PathCompleter)
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.prompt import Confirm, Prompt
|
|
from rich.syntax import Syntax
|
|
|
|
# CLI description shown in `--help` output.
prog_description = """\
To display metafile or fill missing fields of the metafile.
"""

# Repository root, resolved relative to this script's location
# (assumes the script lives one directory below the project root).
MMCLS_ROOT = Path(__file__).absolute().parents[1].resolve().absolute()
# Shared rich console used for all styled terminal output.
console = Console()
# Fuzzy tab-completion suggestions for dataset prompts; not exhaustive.
dataset_completer = FuzzyWordCompleter(
    ['ImageNet-1k', 'ImageNet-21k', 'CIFAR-10', 'CIFAR-100'])
|
|
|
|
|
|
def prompt(message,
           allow_empty=True,
           default=None,
           multiple=False,
           completer=None):
    """Ask for user input with a rich-formatted message.

    Args:
        message (str): Prompt message; may contain rich markup tags.
        allow_empty (bool): If False, keep re-asking until a non-empty
            value is entered. Defaults to True.
        default (str | None): Value pre-filled in the input field.
            Defaults to None.
        multiple (bool): If True, collect values repeatedly until an
            empty input and return them as a list. Defaults to False.
        completer: Optional prompt_toolkit completer for tab-completion.

    Returns:
        list[str] | str | None: A list of stripped inputs when
        ``multiple`` is True; otherwise the stripped input string, or
        None when the input is empty and no default was given.
    """
    # Render rich markup to ANSI escape codes so prompt_toolkit can
    # display the styled message.
    with console.capture() as capture:
        console.print(message, end='')

    message = ANSI(capture.get())
    ask = partial(
        _prompt, message=message, default=default or '', completer=completer)

    out = ask()

    if multiple:
        outs = []
        while out != '':
            # Fix: strip each entry for consistency with the
            # single-value return path below.
            outs.append(out.strip())
            out = ask()
        return outs

    if not allow_empty:
        while out == '':
            out = ask()

    if default is None and out == '':
        return None
    else:
        return out.strip()
|
|
|
|
|
|
class MyDumper(yaml.Dumper):
    """A ``yaml.Dumper`` that always indents block sequences.

    The default dumper emits list items at the same indentation level as
    the parent key; forcing ``indentless=False`` indents them, which is
    the style used by model-index metafiles.
    """

    def increase_indent(self, flow=False, indentless=False):
        # Modern zero-argument super(); always indent, ignoring the
        # requested `indentless` value.
        return super().increase_indent(flow, False)
|
|
|
|
|
|
# Dump helper bound to MyDumper: block style, keep insertion order
# (sort_keys=False) so the canonical key ordering below is preserved.
yaml_dump = partial(
    yaml.dump, Dumper=MyDumper, default_flow_style=False, sort_keys=False)
|
|
|
|
|
|
def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments with ``src``, ``out``,
        ``view`` and ``csv`` attributes.
    """
    parser = argparse.ArgumentParser(description=prog_description)
    # Fix: help-text typo 'matafile' -> 'metafile'.
    parser.add_argument('--src', type=Path, help='The path of the metafile.')
    parser.add_argument('--out', '-o', type=Path, help='The output path.')
    parser.add_argument(
        '--view', action='store_true', help='Only pretty print the metafile.')
    parser.add_argument('--csv', type=str, help='Use a csv to update models.')
    args = parser.parse_args()
    return args
|
|
|
|
|
|
def get_flops(config_path):
    """Compute FLOPs and parameter count of the model in a config file.

    Heavy third-party imports are local so the rest of the script works
    without torch/fvcore/mmengine installed.

    Args:
        config_path (str): Path of an mmpretrain config file.

    Returns:
        tuple[int, int]: ``(flops, params)`` measured by fvcore on a
        single random input image.
    """
    import numpy as np
    import torch
    from fvcore.nn import FlopCountAnalysis, parameter_count
    from mmengine.config import Config
    from mmengine.dataset import Compose
    from mmengine.registry import DefaultScope

    import mmpretrain.datasets  # noqa: F401
    from mmpretrain.apis import init_model

    cfg = Config.fromfile(config_path)

    if 'test_dataloader' in cfg:
        # build the data pipeline
        test_dataset = cfg.test_dataloader.dataset
        # The random ndarray below stands in for the loaded image, so
        # drop the file-loading step if present.
        if test_dataset.pipeline[0]['type'] == 'LoadImageFromFile':
            test_dataset.pipeline.pop(0)
        if test_dataset.type in ['CIFAR10', 'CIFAR100']:
            # The image shape of CIFAR is (32, 32, 3)
            test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))

        # Run the test pipeline on a random HWC uint8 image to discover
        # the actual input resolution after resize/crop transforms.
        with DefaultScope.overwrite_default_scope('mmpretrain'):
            data = Compose(test_dataset.pipeline)({
                'img':
                np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
            })
        resolution = tuple(data['inputs'].shape[-2:])
    else:
        # For configs only for get model.
        resolution = (224, 224)

    model = init_model(cfg, device='cpu')

    with torch.no_grad():
        # Count FLOPs of feature extraction only (extract_feat), not the
        # head's loss/predict logic.
        model.forward = model.extract_feat
        model.to('cpu')
        inputs = (torch.randn((1, 3, *resolution)), )
        analyzer = FlopCountAnalysis(model, inputs)
        analyzer.unsupported_ops_warnings(False)
        analyzer.uncalled_modules_warnings(False)
        flops = analyzer.total()
        # The '' key holds the whole-model parameter total.
        params = parameter_count(model)['']
    return int(flops), int(params)
|
|
|
|
|
|
def fill_collection(collection: dict):
    """Interactively fill the missing fields of a collection dict.

    Args:
        collection (dict): The collection loaded from the metafile; may
            be partially filled or empty.

    Returns:
        dict: The completed collection with keys in canonical order.
    """
    if collection.get('Name') is None:
        name = prompt(
            'Please input the collection [red]name[/]: ', allow_empty=False)
        collection['Name'] = name

    # `or {}` guards against fields that exist but hold an explicit null
    # (dict.get's default only applies when the key is missing).
    if (collection.get('Metadata') or {}).get('Architecture') is None:
        architecture = prompt(
            'Please input the model [red]architecture[/] '
            '(input empty to finish): ',
            multiple=True)
        if len(architecture) > 0:
            collection.setdefault('Metadata', {})
            collection['Metadata']['Architecture'] = architecture

    if (collection.get('Paper') or {}).get('Title') is None:
        title = prompt('Please input the [red]paper title[/]: ')
    else:
        title = collection['Paper']['Title']
    if (collection.get('Paper') or {}).get('URL') is None:
        url = prompt('Please input the [red]paper url[/]: ')
    else:
        url = collection['Paper']['URL']
    paper = dict(Title=title, URL=url)
    collection['Paper'] = paper

    if collection.get('README') is None:
        readme = prompt(
            'Please input the [red]README[/] file path: ',
            completer=PathCompleter(file_filter=lambda name: Path(name).is_dir(
            ) or 'README.md' in name))
        if readme is not None:
            # Store the path relative to the repository root.
            collection['README'] = str(
                Path(readme).absolute().relative_to(MMCLS_ROOT))
        else:
            collection['README'] = None

    # Re-order keys canonically; unknown keys sort last instead of
    # raising ValueError from `order.index`.
    order = ['Name', 'Metadata', 'Paper', 'README', 'Code']
    collection = {
        k: collection[k]
        for k in sorted(
            collection.keys(),
            key=lambda k: order.index(k) if k in order else len(order))
    }
    return collection
|
|
|
|
|
|
def fill_model_by_prompt(model: dict, defaults: dict):
    """Interactively fill the missing fields of a model dict.

    Args:
        model (dict): The model loaded from the metafile; may be empty.
        defaults (dict): Shared defaults, e.g. ``In Collection`` and
            ``Convert From.Code``. Updated in place as a side effect so
            later models can reuse the values.

    Returns:
        dict: The completed model with keys in canonical order.
    """
    # Name
    if model.get('Name') is None:
        name = prompt(
            'Please input the model [red]name[/]: ', allow_empty=False)
        model['Name'] = name

    # In Collection
    model['In Collection'] = defaults.get('In Collection')

    # Config
    config = model.get('Config')
    if config is None:
        config = prompt(
            'Please input the [red]config[/] file path: ',
            completer=FuzzyCompleter(PathCompleter()))
        if config is not None:
            # Newly entered paths are stored relative to the repo root;
            # a pre-existing Config is assumed to be relative already.
            config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
    model['Config'] = config

    # Metadata.Flops, Metadata.Parameters
    # `or {}` guards against an explicit null Metadata field.
    flops = (model.get('Metadata') or {}).get('FLOPs')
    params = (model.get('Metadata') or {}).get('Parameters')
    if model.get('Config') is not None and (
            MMCLS_ROOT / model['Config']).exists() and (flops is None
                                                        or params is None):
        try:
            print('Automatically compute FLOPs and Parameters from config.')
            flops, params = get_flops(str(MMCLS_ROOT / model['Config']))
        except Exception:
            # Best-effort: fall back to manual input below.
            print('Failed to compute FLOPs and Parameters.')

    if flops is None:
        flops = prompt('Please specify the [red]FLOPs[/]: ')
        if flops is not None:
            flops = int(flops)
    if params is None:
        params = prompt('Please specify the [red]number of parameters[/]: ')
        if params is not None:
            params = int(params)

    model.setdefault('Metadata', {})
    # Fix: assign directly instead of `setdefault` so a pre-existing
    # explicit null key does not shadow a freshly computed/entered value
    # (this also matches `update_model_by_dict`).
    model['Metadata']['FLOPs'] = flops
    model['Metadata']['Parameters'] = params

    if model['Metadata'].get('Training Data') is None:
        training_data = prompt(
            'Please input all [red]training dataset[/], '
            'include pre-training (input empty to finish): ',
            completer=dataset_completer,
            multiple=True)
        if len(training_data) > 1:
            model['Metadata']['Training Data'] = training_data
        elif len(training_data) == 1:
            # A single dataset is stored as a plain string, not a list.
            model['Metadata']['Training Data'] = training_data[0]

    results = model.get('Results')
    if results is None:
        test_dataset = prompt(
            'Please input the [red]test dataset[/]: ',
            completer=dataset_completer)
        if test_dataset is not None:
            task = Prompt.ask(
                'Please input the [red]test task[/]',
                default='Image Classification')
            if task == 'Image Classification':
                metrics = {}
                top1 = prompt('Please input the [red]top-1 accuracy[/]: ')
                top5 = prompt('Please input the [red]top-5 accuracy[/]: ')
                if top1 is not None:
                    metrics['Top 1 Accuracy'] = round(float(top1), 2)
                if top5 is not None:
                    metrics['Top 5 Accuracy'] = round(float(top5), 2)
            else:
                metrics_list = prompt(
                    'Please input the [red]metrics[/] like "mAP=94.98" '
                    '(input empty to finish): ',
                    multiple=True)
                metrics = {}
                for metric in metrics_list:
                    k, v = metric.split('=')[:2]
                    metrics[k] = round(float(v), 2)
            if len(metrics) > 0:
                results = [{
                    'Dataset': test_dataset,
                    'Metrics': metrics,
                    'Task': task
                }]
                model['Results'] = results

    weights = model.get('Weights')
    if weights is None:
        weights = prompt('Please input the [red]checkpoint download link[/]: ')
    model['Weights'] = weights

    if model.get('Converted From') is None and model.get(
            'Weights') is not None:
        if Confirm.ask(
                'Is the checkpoint is converted from [red]other repository[/]?'
        ):
            converted_from = {}
            converted_from['Weights'] = prompt(
                'Please fill the original checkpoint download link: ')
            converted_from['Code'] = Prompt.ask(
                'Please fill the original repository link',
                default=defaults.get('Convert From.Code', None))
            # Remember the repo link so following models reuse it.
            defaults['Convert From.Code'] = converted_from['Code']
            model['Converted From'] = converted_from
    elif (model.get('Converted From') or {}).get('Code') is not None:
        defaults['Convert From.Code'] = model['Converted From']['Code']

    # Re-order keys canonically; unknown keys sort last instead of
    # raising ValueError from `order.index`.
    order = [
        'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
        'Converted From'
    ]
    model = {
        k: model[k]
        for k in sorted(
            model.keys(),
            key=lambda k: order.index(k) if k in order else len(order))
    }
    return model
|
|
|
|
|
|
def update_model_by_dict(model: dict, update_dict: dict, defaults: dict):
    """Update a model dict from one csv record (non-interactive).

    Args:
        model (dict): The model to update; may be just ``{'Name': ...}``.
        update_dict (dict): One csv row with lower-cased column names.
        defaults (dict): Shared defaults, e.g. ``In Collection``.

    Returns:
        dict: The updated model with keys in canonical order.
    """
    # Name
    if 'name override' in update_dict:
        model['Name'] = update_dict['name override']

    # In Collection
    model['In Collection'] = defaults.get('In Collection')

    # Config
    if 'config' in update_dict:
        config = update_dict['config'].strip()
        config = str(Path(config).absolute().relative_to(MMCLS_ROOT))
        config_updated = (config != model.get('Config'))
        model['Config'] = config
    else:
        config_updated = False

    # Metadata.Flops, Metadata.Parameters
    # `or {}` guards against an explicit null Metadata field.
    flops = (model.get('Metadata') or {}).get('FLOPs')
    params = (model.get('Metadata') or {}).get('Parameters')
    if config_updated and (flops is None or params is None):
        print(f'Automatically compute FLOPs and Parameters of {model["Name"]}')
        flops, params = get_flops(str(MMCLS_ROOT / model['Config']))

    model.setdefault('Metadata', {})
    model['Metadata']['FLOPs'] = flops
    model['Metadata']['Parameters'] = params

    # Metadata.Training Data
    if 'metadata.training data' in update_dict:
        train_data = update_dict['metadata.training data'].strip()
        train_data = re.split(r'\s+', train_data)
        if len(train_data) > 1:
            model['Metadata']['Training Data'] = train_data
        elif len(train_data) == 1:
            # A single dataset is stored as a plain string, not a list.
            model['Metadata']['Training Data'] = train_data[0]

    # Results.Dataset
    if 'results.dataset' in update_dict:
        test_data = update_dict['results.dataset'].strip()
        results = model.get('Results') or [{}]
        result = results[0]
        result['Dataset'] = test_data
        model['Results'] = results

    # Results.Metrics.Top 1 Accuracy
    result = None
    if 'results.metrics.top 1 accuracy' in update_dict:
        top1 = update_dict['results.metrics.top 1 accuracy']
        results = model.get('Results') or [{}]
        result = results[0]
        result.setdefault('Metrics', {})
        result['Metrics']['Top 1 Accuracy'] = round(float(top1), 2)
        task = 'Image Classification'
        model['Results'] = results

    # Results.Metrics.Top 5 Accuracy
    if 'results.metrics.top 5 accuracy' in update_dict:
        top5 = update_dict['results.metrics.top 5 accuracy']
        results = model.get('Results') or [{}]
        result = results[0]
        result.setdefault('Metrics', {})
        result['Metrics']['Top 5 Accuracy'] = round(float(top5), 2)
        task = 'Image Classification'
        model['Results'] = results

    if result is not None:
        # Fix: `Task` belongs at the result level, alongside `Dataset`
        # and `Metrics` (as built in `fill_model_by_prompt`), not inside
        # the `Metrics` dict.
        result['Task'] = task

    # Weights
    if 'weights' in update_dict:
        weights = update_dict['weights'].strip()
        model['Weights'] = weights

    # Converted From.Code
    if 'converted from.code' in update_dict:
        from_code = update_dict['converted from.code'].strip()
        model.setdefault('Converted From', {})
        model['Converted From']['Code'] = from_code

    # Converted From.Weights
    if 'converted from.weights' in update_dict:
        from_weight = update_dict['converted from.weights'].strip()
        model.setdefault('Converted From', {})
        model['Converted From']['Weights'] = from_weight

    # Re-order keys canonically; unknown keys sort last instead of
    # raising ValueError from `order.index`.
    order = [
        'Name', 'Metadata', 'In Collection', 'Results', 'Weights', 'Config',
        'Converted From'
    ]
    model = {
        k: model[k]
        for k in sorted(
            model.keys(),
            key=lambda k: order.index(k) if k in order else len(order))
    }
    return model
|
|
|
|
|
|
def format_collection(collection: dict):
    """Render a collection as a syntax-highlighted YAML panel."""
    dumped = yaml_dump(collection)
    highlighted = Syntax(dumped, 'yaml', background_color='default')
    return Panel(highlighted, width=150, title='Collection')
|
|
|
|
|
|
def format_model(model: dict):
    """Render a model as a syntax-highlighted YAML panel."""
    dumped = yaml_dump(model)
    highlighted = Syntax(dumped, 'yaml', background_color='default')
    return Panel(highlighted, width=150, title='Model')
|
|
|
|
|
|
def main():
    """Entry point: load the metafile, view or fill it, then dump."""
    args = parse_args()

    if args.src is not None:
        with open(args.src, 'r') as f:
            content = yaml.load(f, yaml.SafeLoader)
    else:
        # No source file: start from an empty metafile.
        content = {}

    if args.view:
        # View-only mode: pretty print everything and exit unchanged.
        collection = content.get('Collections', [{}])[0]
        console.print(format_collection(collection))
        models = content.get('Models', [])
        for model in models:
            console.print(format_model(model))
        return

    # NOTE(review): only the first collection is handled; any additional
    # collections in the metafile would be dropped on dump — confirm.
    collection = content.get('Collections', [{}])[0]
    ori_collection = copy.deepcopy(collection)

    console.print(format_collection(collection))
    collection = fill_collection(collection)
    if ori_collection != collection:
        # Re-print only when the interactive session changed something.
        console.print(format_collection(collection))
    model_defaults = {'In Collection': collection['Name']}

    models = content.get('Models', [])
    updated_models = []

    if args.csv is not None:
        # Batch mode: update models from csv rows keyed by model name.
        import pandas as pd
        df = pd.read_csv(args.csv).rename(columns=lambda x: x.strip().lower())
        assert df['name'].is_unique, 'The csv has duplicated model names.'
        models_dict = {item['Name']: item for item in models}
        for update_dict in df.to_dict('records'):
            assert 'name' in update_dict, 'The csv must have the `Name` field.'
            model_name = update_dict['name'].strip()
            # Unknown names create a fresh model stub.
            model = models_dict.pop(model_name, {'Name': model_name})
            model = update_model_by_dict(model, update_dict, model_defaults)
            updated_models.append(model)
        # Keep the models that were not mentioned in the csv.
        updated_models.extend(models_dict.values())
    else:
        # Interactive mode: walk through every model and prompt for gaps.
        for model in models:
            console.print(format_model(model))
            ori_model = copy.deepcopy(model)
            model = fill_model_by_prompt(model, model_defaults)
            if ori_model != model:
                console.print(format_model(model))
            updated_models.append(model)

        while Confirm.ask('Add new model?'):
            model = fill_model_by_prompt({}, model_defaults)
            updated_models.append(model)

    # Save updated models even error happened.
    # Sort by FLOPs, then by name length as a tiebreaker.
    updated_models.sort(key=lambda item: (item.get('Metadata', {}).get(
        'FLOPs', 0), len(item['Name'])))
    if args.out is not None:
        with open(args.out, 'w') as f:
            # Dump the two sections separately with a blank line between.
            yaml_dump({'Collections': [collection]}, f)
            f.write('\n')
            yaml_dump({'Models': updated_models}, f)
    else:
        # No output path: print the result to the console instead.
        modelindex = {'Collections': [collection], 'Models': updated_models}
        yaml_str = yaml_dump(modelindex)
        console.print(Syntax(yaml_str, 'yaml', background_color='default'))
        console.print('Specify [red]`--out`[/] to dump to file.')
|
|
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|