[Feature] Add some scripts for development. (#1257)

* [Feature] Add some scripts for development.

* Add `generate_readme.py`.

* Update according to comments
pull/1240/head
Ma Zerun 2022-12-19 13:53:13 +08:00 committed by GitHub
parent 6ea59bd846
commit 0e4163668f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 468 additions and 0 deletions

View File

@ -0,0 +1,186 @@
import argparse
import math
from pathlib import Path
import torch
from rich.console import Console
console = Console()
prog_description = """\
Draw the state dict tree.
"""
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument(
'path',
type=Path,
help='The path of the checkpoint or model config to draw.')
parser.add_argument('--depth', type=int, help='The max depth to draw.')
parser.add_argument(
'--full-name',
action='store_true',
help='Whether to print the full name of the key.')
parser.add_argument(
'--shape',
action='store_true',
help='Whether to print the shape of the parameter.')
parser.add_argument(
'--state-key',
type=str,
help='The key of the state dict in the checkpoint.')
parser.add_argument(
'--number',
action='store_true',
help='Mark all parameters and their index number.')
parser.add_argument(
'--node',
type=str,
help='Show the sub-tree of a node, like "backbone.layers".')
args = parser.parse_args()
return args
def ckpt_to_state_dict(checkpoint, key=None):
if key is not None:
state_dict = checkpoint[key]
elif 'state_dict' in checkpoint:
# try mmcls style
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
elif isinstance(next(iter(checkpoint.values())), torch.Tensor):
# try native style
state_dict = checkpoint
else:
raise KeyError('Please specify the key of state '
f'dict from {list(checkpoint.keys())}.')
return state_dict
class StateDictTree:
def __init__(self, key='', value=None):
self.children = {}
self.key: str = key
self.value = value
def add_parameter(self, key, value):
keys = key.split('.', 1)
if len(keys) == 1:
self.children[key] = StateDictTree(key, value)
elif keys[0] in self.children:
self.children[keys[0]].add_parameter(keys[1], value)
else:
node = StateDictTree(keys[0])
node.add_parameter(keys[1], value)
self.children[keys[0]] = node
def __getitem__(self, key: str):
return self.children[key]
def __repr__(self) -> str:
with console.capture() as capture:
for line in self.iter_tree():
console.print(line)
return capture.get()
def __len__(self):
return len(self.children)
def draw_tree(self,
max_depth=None,
full_name=False,
with_shape=False,
with_value=False):
for line in self.iter_tree(
max_depth=max_depth,
full_name=full_name,
with_shape=with_shape,
with_value=with_value,
):
console.print(line, highlight=False)
def iter_tree(
self,
lead='',
prefix='',
max_depth=None,
full_name=False,
with_shape=False,
with_value=False,
):
if self.value is None:
key_str = f'[blue]{self.key}[/]'
elif with_shape:
key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}'
elif with_value:
key_str = f'[green]{self.key}[/] {self.value}'
else:
key_str = f'[green]{self.key}[/]'
yield lead + prefix + key_str
lead = lead.replace('├─', '')
lead = lead.replace('└─', ' ')
if self.key and full_name:
prefix = f'{prefix}{self.key}.'
if max_depth == 0:
return
elif max_depth is not None:
max_depth -= 1
for i, child in enumerate(self.children.values()):
level_lead = '├─' if i < len(self.children) - 1 else '└─'
yield from child.iter_tree(
lead=f'{lead}{level_lead} ',
prefix=prefix,
max_depth=max_depth,
full_name=full_name,
with_shape=with_shape,
with_value=with_value)
def main():
args = parse_args()
if args.path.suffix in ['.json', '.py', '.yml']:
from mmengine.runner import get_state_dict
from mmcls.apis import init_model
model = init_model(args.path, device='cpu')
state_dict = get_state_dict(model)
else:
ckpt = torch.load(args.path, map_location='cpu')
state_dict = ckpt_to_state_dict(ckpt, args.state_key)
root = StateDictTree()
for k, v in state_dict.items():
root.add_parameter(k, v)
para_index = 0
mark_width = math.floor(math.log(len(state_dict), 10) + 1)
if args.node is not None:
for key in args.node.split('.'):
root = root[key]
for line in root.iter_tree(
max_depth=args.depth,
full_name=args.full_name,
with_shape=args.shape,
):
if not args.number:
mark = ''
# A hack method to determine whether a line is parameter.
elif '[green]' in line:
mark = f'[red]({str(para_index).ljust(mark_width)})[/]'
para_index += 1
else:
mark = ' ' * (mark_width + 2)
console.print(mark + line, highlight=False)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,121 @@
#!/usr/bin/env python
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from ckpt_tree import StateDictTree, ckpt_to_state_dict
from rich.progress import track
from scipy import stats
prog_description = """\
Compare the initialization distribution between state dicts by Kolmogorov-Smirnov test.
""" # noqa: E501
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=prog_description)
parser.add_argument(
'model_a',
type=Path,
help='The path of the first checkpoint or model config.')
parser.add_argument(
'model_b',
type=Path,
help='The path of the second checkpoint or model config.')
parser.add_argument(
'--show',
action='store_true',
help='Whether to draw the KDE of variables')
parser.add_argument(
'-p',
default=0.01,
type=float,
help='The threshold of p-value. '
'Higher threshold means more strict test.')
args = parser.parse_args()
return args
def compare_distribution(state_dict_a, state_dict_b, p_thres):
assert len(state_dict_a) == len(state_dict_b)
for k, v1 in state_dict_a.items():
assert k in state_dict_b
v2 = state_dict_b[k]
v1 = v1.cpu().flatten()
v2 = v2.cpu().flatten()
pvalue = stats.kstest(v1, v2).pvalue
if pvalue < p_thres:
yield k, pvalue, v1, v2
def state_dict_from_cfg_or_ckpt(path, state_key=None):
if path.suffix in ['.json', '.py', '.yml']:
from mmengine.runner import get_state_dict
from mmcls.apis import init_model
model = init_model(path, device='cpu')
model.init_weights()
return get_state_dict(model)
else:
ckpt = torch.load(path, map_location='cpu')
return ckpt_to_state_dict(ckpt, state_key)
def main():
args = parse_args()
state_dict_a = state_dict_from_cfg_or_ckpt(args.model_a)
state_dict_b = state_dict_from_cfg_or_ckpt(args.model_b)
compare_keys = state_dict_a.keys() & state_dict_b.keys()
if len(compare_keys) == 0:
raise ValueError("The state dicts don't match, please convert "
'to the same keys before comparison.')
root = StateDictTree()
for key in track(compare_keys):
if state_dict_a[key].shape != state_dict_b[key].shape:
raise ValueError(f'The shapes of "{key}" are different. '
'Please check models in the same architecture.')
# Sample at most 30000 items to prevent long-time calcuation.
perm_ids = torch.randperm(state_dict_a[key].numel())[:30000]
value_a = state_dict_a[key].flatten()[perm_ids]
value_b = state_dict_b[key].flatten()[perm_ids]
pvalue = stats.kstest(value_a, value_b).pvalue
if pvalue < args.p:
root.add_parameter(key, round(pvalue, 4))
if args.show:
try:
import seaborn as sns
except ImportError:
raise ImportError('Please install `seaborn` by '
'`pip install seaborn` to show KDE.')
sample_a = str([round(v.item(), 2) for v in value_a[:10]])
sample_b = str([round(v.item(), 2) for v in value_b[:10]])
if value_a.std() > 0:
sns.kdeplot(value_a, fill=True)
else:
sns.scatterplot(x=[value_a[0].item()], y=[1])
if value_b.std() > 0:
sns.kdeplot(value_b, fill=True)
else:
sns.scatterplot(x=[value_b[0].item()], y=[1])
plt.legend([
f'{args.model_a.stem}: {sample_a}',
f'{args.model_b.stem}: {sample_b}'
])
plt.title(key)
plt.show()
if len(root) > 0:
root.draw_tree(with_value=True)
print("Above parameters didn't pass the test, "
'and the values are their similarity score.')
else:
print('The distributions of all weights are the same.')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,161 @@
# flake8: noqa
import argparse
import warnings
from collections import defaultdict
from pathlib import Path
from modelindex.load_model_index import load
from modelindex.models.ModelIndex import ModelIndex
prog_description = """\
Use metafile to generate a README.md.
Notice that the tool may fail in some corner cases, and you still need to check and fill some contents manually in the generated README.
"""
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument('metafile', type=Path, help='The path of metafile')
parser.add_argument(
'--table', action='store_true', help='Only generate summary tables')
args = parser.parse_args()
return args
def add_title(metafile: ModelIndex, readme: list):
paper = metafile.collections[0].paper
title = paper['Title']
url = paper['URL']
abbr = metafile.collections[0].name
papertype = metafile.collections[0].data.get('type', 'Algorithm')
readme.append(f'# {abbr}\n')
readme.append(f'> [{title}]({url})')
readme.append(f'<!-- [{papertype.upper()}] -->')
readme.append('')
def add_abstract(metafile, readme):
paper = metafile.collections[0].paper
url = paper['URL']
if 'arxiv' in url:
try:
import arxiv
search = arxiv.Search(id_list=[url.split('/')[-1]])
info = next(search.results())
abstract = info.summary
except ImportError:
warnings.warn('Install arxiv parser by `pip install arxiv` '
'to automatically generate abstract.')
abstract = None
else:
abstract = None
readme.append('## Abstract\n')
if abstract is not None:
readme.append(abstract.replace('\n', ' '))
readme.append('')
readme.append('<div align=center>\n'
'<img src="" width="50%"/>\n'
'</div>')
readme.append('')
def add_models(metafile, readme):
models = metafile.models
if len(models) == 0:
return
readme.append('## Results and models')
readme.append('')
datasets = defaultdict(list)
for model in models:
if model.results is None:
# No results on pretrained model.
datasets['Pre-trained Models'].append(model)
else:
datasets[model.results[0].dataset].append(model)
for dataset, models in datasets.items():
if dataset == 'Pre-trained Models':
readme.append(f'### {dataset}\n')
readme.append(
'The pre-trained models are only used to fine-tune, '
"and therefore cannot be trained and don't have evaluation results.\n"
)
readme.append(
'| Model | Pretrain | Params(M) | Flops(G) | Config | Download |\n'
'|:---------------------:|:---------:|:---------:|:--------:|:------:|:--------:|'
)
converted_from = None
for model in models:
name = model.name.center(21)
params = model.metadata.parameters / 1e6
flops = model.metadata.flops / 1e9
converted_from = converted_from or model.data.get(
'Converted From', None)
config = './' + Path(model.config).name
weights = model.weights
star = '\*' if '3rdparty' in weights else ''
readme.append(
f'| {name}{star} | {params:.2f} | {flops:.2f} | [config]({config}) | [model]({weights}) |'
),
if converted_from is not None:
readme.append('')
readme.append(
f"*Models with \* are converted from the [official repo]({converted_from['Code']}).*\n"
)
else:
readme.append(f'### {dataset}\n')
readme.append(
'| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |\n'
'|:---------------------:|:----------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|'
)
converted_from = None
for model in models:
name = model.name.center(21)
params = model.metadata.parameters / 1e6
flops = model.metadata.flops / 1e9
metrics = model.results[0].metrics
top1 = metrics.get('Top 1 Accuracy')
top5 = metrics.get('Top 5 Accuracy', 0)
converted_from = converted_from or model.data.get(
'Converted From', None)
config = './' + Path(model.config).name
weights = model.weights
star = '\*' if '3rdparty' in weights else ''
if 'in21k-pre' in weights:
pretrain = 'ImageNet 21k'
else:
pretrain = 'From scratch'
readme.append(
f'| {name}{star} | {pretrain} | {params:.2f} | {flops:.2f} | {top1:.2f} | {top5:.2f} | [config]({config}) | [model]({weights}) |'
),
if converted_from is not None:
readme.append('')
readme.append(
f"*Models with \* are converted from the [official repo]({converted_from['Code']}). "
'The config files of these models are only for inference. '
"We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*\n"
)
def main():
args = parse_args()
metafile = load(str(args.metafile))
readme_lines = []
if not args.table:
add_title(metafile, readme_lines)
add_abstract(metafile, readme_lines)
add_models(metafile, readme_lines)
if not args.table:
readme_lines.append('## Citation\n')
readme_lines.append('```bibtex\n\n```\n')
print('\n'.join(readme_lines))
if __name__ == '__main__':
main()