[Feature] Add some scripts for development. (#1257)
* [Feature] Add some scripts for development. * Add `generate_readme.py`. * Update according to comments
pull/1240/head
parent
6ea59bd846
commit
0e4163668f
|
@ -0,0 +1,186 @@
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
# Module-level rich console; the tree drawing code below prints through it.
console = Console()

# Text passed to argparse as the program description (shown by ``--help``).
prog_description = """\
Draw the state dict tree.
"""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Build the command line parser of the tree drawer and parse argv.

    Returns:
        argparse.Namespace: Parsed options with attributes ``path``,
        ``depth``, ``full_name``, ``shape``, ``state_key``, ``number`` and
        ``node``.
    """
    # Each entry is (flag, add_argument keyword arguments); the order here
    # is the order shown in the ``--help`` output.
    option_specs = (
        ('path',
         dict(
             type=Path,
             help='The path of the checkpoint or model config to draw.')),
        ('--depth', dict(type=int, help='The max depth to draw.')),
        ('--full-name',
         dict(
             action='store_true',
             help='Whether to print the full name of the key.')),
        ('--shape',
         dict(
             action='store_true',
             help='Whether to print the shape of the parameter.')),
        ('--state-key',
         dict(
             type=str,
             help='The key of the state dict in the checkpoint.')),
        ('--number',
         dict(
             action='store_true',
             help='Mark all parameters and their index number.')),
        ('--node',
         dict(
             type=str,
             help='Show the sub-tree of a node, like "backbone.layers".')),
    )

    cli = argparse.ArgumentParser(description=prog_description)
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def ckpt_to_state_dict(checkpoint, key=None):
    """Extract the state dict from a loaded checkpoint mapping.

    Lookup order: the explicit ``key`` if given, then the ``'state_dict'``
    entry, then the ``'model'`` entry, and finally the checkpoint itself
    when its first value is a tensor.

    Args:
        checkpoint (dict): The object loaded from a checkpoint file.
        key (str, optional): The key of the state dict inside
            ``checkpoint``. Defaults to None, which enables auto-detection.

    Returns:
        dict: The state dict found in ``checkpoint``.

    Raises:
        KeyError: If ``key`` is given but missing, or if no state dict can
            be detected automatically (including an empty checkpoint).
    """
    if key is not None:
        return checkpoint[key]

    # 'state_dict' is the mmcls style key; 'model' is tried second.
    for candidate in ('state_dict', 'model'):
        if candidate in checkpoint:
            return checkpoint[candidate]

    # Native style: the checkpoint itself maps names to tensors. The default
    # of ``next`` keeps an empty checkpoint from leaking a StopIteration.
    first_value = next(iter(checkpoint.values()), None)
    if isinstance(first_value, torch.Tensor):
        return checkpoint

    raise KeyError('Please specify the key of state '
                   f'dict from {list(checkpoint.keys())}.')
|
||||||
|
|
||||||
|
|
||||||
|
class StateDictTree:
    """A tree node built from dotted state dict keys.

    A node whose ``value`` is None is rendered as a branch (blue); any
    other node is rendered as a parameter (green).

    Args:
        key (str): The name of this node (one segment of a dotted key).
            Defaults to '', which is used for the root.
        value: The payload stored on the node. Defaults to None.
    """

    def __init__(self, key='', value=None):
        # Maps a child's key segment to its StateDictTree node, in
        # insertion order.
        self.children = {}
        self.key: str = key
        self.value = value

    def add_parameter(self, key, value):
        """Insert ``value`` at the position described by a dotted ``key``.

        Args:
            key (str): Dotted name such as ``'backbone.conv.weight'``.
                Missing intermediate nodes are created.
            value: The payload to store on the final node.
        """
        # Only split off the first segment; the rest is handled recursively.
        keys = key.split('.', 1)
        if len(keys) == 1:
            self.children[key] = StateDictTree(key, value)
        elif keys[0] in self.children:
            self.children[keys[0]].add_parameter(keys[1], value)
        else:
            node = StateDictTree(keys[0])
            node.add_parameter(keys[1], value)
            self.children[keys[0]] = node

    def __getitem__(self, key: str):
        """Return the direct child named ``key``."""
        return self.children[key]

    def __repr__(self) -> str:
        """Return the rendered tree as captured console output."""
        with console.capture() as capture:
            for line in self.iter_tree():
                console.print(line)
        return capture.get()

    def __len__(self):
        """Return the number of direct children."""
        return len(self.children)

    def draw_tree(self,
                  max_depth=None,
                  full_name=False,
                  with_shape=False,
                  with_value=False):
        """Print the tree to the module console.

        The arguments are forwarded unchanged to :meth:`iter_tree`.
        """
        for line in self.iter_tree(
                max_depth=max_depth,
                full_name=full_name,
                with_shape=with_shape,
                with_value=with_value,
        ):
            console.print(line, highlight=False)

    def iter_tree(
        self,
        lead='',
        prefix='',
        max_depth=None,
        full_name=False,
        with_shape=False,
        with_value=False,
    ):
        """Yield the lines of the tree, one per node, in depth-first order.

        Args:
            lead (str): Connector text placed before this node's line.
                Defaults to ''.
            prefix (str): Dotted name of the ancestors, prepended to the key
                when ``full_name`` is set. Defaults to ''.
            max_depth (int, optional): How many levels below this node to
                descend. None means no limit. Defaults to None.
            full_name (bool): Whether to print the full dotted name.
                Defaults to False.
            with_shape (bool): Whether to append ``tuple(value.shape)`` to
                parameter lines. Takes precedence over ``with_value``.
                Defaults to False.
            with_value (bool): Whether to append the value itself to
                parameter lines. Defaults to False.

        Yields:
            str: A line containing rich markup (``[blue]``/``[green]``).
        """
        if self.value is None:
            key_str = f'[blue]{self.key}[/]'
        elif with_shape:
            key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}'
        elif with_value:
            key_str = f'[green]{self.key}[/] {self.value}'
        else:
            key_str = f'[green]{self.key}[/]'

        yield lead + prefix + key_str

        # Turn this node's own connector into the continuation drawn in
        # front of its descendants. Both replacements must keep the width
        # of the two-character connector, otherwise deeper levels drift out
        # of alignment.
        lead = lead.replace('├─', '│ ')
        lead = lead.replace('└─', '  ')
        if self.key and full_name:
            prefix = f'{prefix}{self.key}.'

        if max_depth == 0:
            return
        elif max_depth is not None:
            max_depth -= 1

        last_index = len(self.children) - 1
        for i, child in enumerate(self.children.values()):
            level_lead = '├─' if i < last_index else '└─'
            yield from child.iter_tree(
                lead=f'{lead}{level_lead} ',
                prefix=prefix,
                max_depth=max_depth,
                full_name=full_name,
                with_shape=with_shape,
                with_value=with_value)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Load a state dict from a config or checkpoint and print it as a tree.

    A path ending in a config suffix is built into a model with mmcls and
    its state dict is read from the model; any other path is loaded with
    ``torch.load`` and searched for a state dict.
    """
    args = parse_args()
    # '.yaml' is listed next to '.yml' so that both spellings of a YAML
    # config are built as models instead of being sent to ``torch.load``.
    if args.path.suffix in ['.json', '.py', '.yml', '.yaml']:
        from mmengine.runner import get_state_dict

        from mmcls.apis import init_model
        model = init_model(args.path, device='cpu')
        state_dict = get_state_dict(model)
    else:
        # NOTE(review): ``torch.load`` unpickles the file, so only trusted
        # checkpoints should be passed to this tool.
        ckpt = torch.load(args.path, map_location='cpu')
        state_dict = ckpt_to_state_dict(ckpt, args.state_key)

    root = StateDictTree()
    for k, v in state_dict.items():
        root.add_parameter(k, v)

    para_index = 0
    # Width of the index column. Counting digits directly avoids the
    # floating point error of ``math.log(n, 10)`` at exact powers of ten
    # and does not fail on an empty state dict.
    mark_width = len(str(len(state_dict)))
    if args.node is not None:
        # Walk down to the requested sub-tree, one key segment at a time.
        for key in args.node.split('.'):
            root = root[key]

    for line in root.iter_tree(
            max_depth=args.depth,
            full_name=args.full_name,
            with_shape=args.shape,
    ):
        if not args.number:
            mark = ''
        # A hack method to determine whether a line is parameter.
        elif '[green]' in line:
            mark = f'[red]({str(para_index).ljust(mark_width)})[/]'
            para_index += 1
        else:
            # Pad branch lines by the width of '(' + index + ')'.
            mark = ' ' * (mark_width + 2)
        console.print(mark + line, highlight=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,121 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
from ckpt_tree import StateDictTree, ckpt_to_state_dict
|
||||||
|
from rich.progress import track
|
||||||
|
from scipy import stats
|
||||||
|
|
||||||
|
prog_description = """\
|
||||||
|
Compare the initialization distribution between state dicts by Kolmogorov-Smirnov test.
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Build the command line parser of the comparison tool and parse argv.

    Returns:
        argparse.Namespace: Parsed options with attributes ``model_a``,
        ``model_b``, ``show`` and ``p``.
    """
    cli = argparse.ArgumentParser(
        description=prog_description,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # The two inputs to compare.
    cli.add_argument(
        'model_a',
        help='The path of the first checkpoint or model config.',
        type=Path)
    cli.add_argument(
        'model_b',
        help='The path of the second checkpoint or model config.',
        type=Path)

    # Optional behavior switches.
    cli.add_argument(
        '--show',
        help='Whether to draw the KDE of variables',
        action='store_true')
    threshold_help = ('The threshold of p-value. '
                      'Higher threshold means more strict test.')
    cli.add_argument('-p', help=threshold_help, type=float, default=0.01)

    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def compare_distribution(state_dict_a, state_dict_b, p_thres):
    """Yield the entries whose two-sample KS test p-value is below a bound.

    Args:
        state_dict_a (dict): First mapping from name to tensor.
        state_dict_b (dict): Second mapping; must have the same number of
            entries and contain every key of ``state_dict_a``.
        p_thres (float): Entries with a p-value below this are yielded.

    Yields:
        tuple: ``(key, pvalue, flat_a, flat_b)`` where the last two items
        are the flattened CPU tensors that were compared.
    """
    assert len(state_dict_a) == len(state_dict_b)
    for name, tensor_a in state_dict_a.items():
        assert name in state_dict_b
        tensor_b = state_dict_b[name]
        flat_a = tensor_a.cpu().flatten()
        flat_b = tensor_b.cpu().flatten()
        test_result = stats.kstest(flat_a, flat_b)
        pvalue = test_result.pvalue
        # Written as a negated ``<`` so a NaN p-value is skipped, not yielded.
        if not (pvalue < p_thres):
            continue
        yield name, pvalue, flat_a, flat_b
|
||||||
|
|
||||||
|
|
||||||
|
def state_dict_from_cfg_or_ckpt(path, state_key=None):
    """Get a state dict from a model config or from a checkpoint file.

    Args:
        path (Path): A config file (``.json``, ``.py``, ``.yml`` or
            ``.yaml``) or a checkpoint file.
        state_key (str, optional): The key of the state dict inside the
            checkpoint. Only used for checkpoints. Defaults to None.

    Returns:
        dict: For a config, the state dict of the model after
        ``init_weights()`` has been called; for a checkpoint, the state
        dict extracted by ``ckpt_to_state_dict``.
    """
    # '.yaml' is listed next to '.yml' so that both spellings of a YAML
    # config are built as models instead of being sent to ``torch.load``.
    if path.suffix in ['.json', '.py', '.yml', '.yaml']:
        from mmengine.runner import get_state_dict

        from mmcls.apis import init_model
        model = init_model(path, device='cpu')
        model.init_weights()
        return get_state_dict(model)
    else:
        # NOTE(review): ``torch.load`` unpickles the file, so only trusted
        # checkpoints should be passed to this tool.
        ckpt = torch.load(path, map_location='cpu')
        return ckpt_to_state_dict(ckpt, state_key)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Compare two state dicts parameter by parameter with the KS test.

    Every key present in both state dicts is tested. Parameters whose
    p-value is below ``args.p`` are collected into a tree that is printed
    at the end, and are optionally plotted when ``--show`` is given.

    Raises:
        ValueError: If the state dicts share no key, or a shared key has
            different shapes in the two state dicts.
        ImportError: If ``--show`` is used and ``seaborn`` is missing.
    """
    args = parse_args()

    state_dict_a = state_dict_from_cfg_or_ckpt(args.model_a)
    state_dict_b = state_dict_from_cfg_or_ckpt(args.model_b)
    # Keep the key order of the first state dict. A set intersection would
    # iterate in an arbitrary order that can change between runs.
    compare_keys = [key for key in state_dict_a if key in state_dict_b]
    if len(compare_keys) == 0:
        raise ValueError("The state dicts don't match, please convert "
                         'to the same keys before comparison.')

    root = StateDictTree()
    for key in track(compare_keys):
        if state_dict_a[key].shape != state_dict_b[key].shape:
            raise ValueError(f'The shapes of "{key}" are different. '
                             'Please check models in the same architecture.')

        # Sample at most 30000 items to prevent long-time calculation.
        # The same indices are used for both tensors.
        perm_ids = torch.randperm(state_dict_a[key].numel())[:30000]
        value_a = state_dict_a[key].flatten()[perm_ids]
        value_b = state_dict_b[key].flatten()[perm_ids]
        pvalue = stats.kstest(value_a, value_b).pvalue
        if pvalue < args.p:
            root.add_parameter(key, round(pvalue, 4))
            if args.show:
                try:
                    import seaborn as sns
                except ImportError:
                    raise ImportError('Please install `seaborn` by '
                                      '`pip install seaborn` to show KDE.')
                sample_a = str([round(v.item(), 2) for v in value_a[:10]])
                sample_b = str([round(v.item(), 2) for v in value_b[:10]])
                # ``std`` is only defined for floating point tensors, so
                # integer tensors are cast before plotting.
                plot_a = value_a.float()
                plot_b = value_b.float()
                # A constant tensor has no density to estimate; draw a
                # single point instead.
                if plot_a.std() > 0:
                    sns.kdeplot(plot_a, fill=True)
                else:
                    sns.scatterplot(x=[plot_a[0].item()], y=[1])
                if plot_b.std() > 0:
                    sns.kdeplot(plot_b, fill=True)
                else:
                    sns.scatterplot(x=[plot_b[0].item()], y=[1])
                plt.legend([
                    f'{args.model_a.stem}: {sample_a}',
                    f'{args.model_b.stem}: {sample_b}'
                ])
                plt.title(key)
                plt.show()
    if len(root) > 0:
        root.draw_tree(with_value=True)
        print("Above parameters didn't pass the test, "
              'and the values are their similarity score.')
    else:
        print('The distributions of all weights are the same.')
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,161 @@
|
||||||
|
# flake8: noqa
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from modelindex.load_model_index import load
|
||||||
|
from modelindex.models.ModelIndex import ModelIndex
|
||||||
|
|
||||||
|
prog_description = """\
|
||||||
|
Use metafile to generate a README.md.
|
||||||
|
|
||||||
|
Notice that the tool may fail in some corner cases, and you still need to check and fill some contents manually in the generated README.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Build the command line parser of the README generator and parse argv.

    Returns:
        argparse.Namespace: Parsed options with attributes ``metafile`` and
        ``table``.
    """
    cli = argparse.ArgumentParser(description=prog_description)
    cli.add_argument('metafile', help='The path of metafile', type=Path)
    cli.add_argument(
        '--table', help='Only generate summary tables', action='store_true')
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def add_title(metafile: ModelIndex, readme: list):
    """Append the title section of the README.

    Reads the first collection of ``metafile`` and appends four entries to
    ``readme``: the heading, the quoted paper link, an HTML comment with
    the upper-cased paper type, and an empty line.

    Args:
        metafile (ModelIndex): The loaded metafile.
        readme (list): The list of README lines, modified in place.
    """
    collection = metafile.collections[0]
    paper = collection.paper
    title = paper['Title']
    url = paper['URL']
    abbr = collection.name
    # The paper type falls back to 'Algorithm' when the collection has none.
    papertype = collection.data.get('type', 'Algorithm')

    readme.extend([
        f'# {abbr}\n',
        f'> [{title}]({url})',
        f'<!-- [{papertype.upper()}] -->',
        '',
    ])
|
||||||
|
|
||||||
|
|
||||||
|
def add_abstract(metafile, readme):
    """Append the abstract section of the README.

    When the paper URL of the first collection contains ``'arxiv'``, the
    abstract is looked up with the optional ``arxiv`` package. If the
    package is missing or the lookup returns nothing, a warning is issued
    and the abstract is left empty for manual filling. An empty image
    placeholder is always appended after the abstract.

    Args:
        metafile: The loaded metafile; only ``collections[0].paper['URL']``
            is read.
        readme (list): The list of README lines, modified in place.
    """
    paper = metafile.collections[0].paper
    url = paper['URL']
    abstract = None
    if 'arxiv' in url:
        try:
            import arxiv
        except ImportError:
            warnings.warn('Install arxiv parser by `pip install arxiv` '
                          'to automatically generate abstract.')
        else:
            # The paper id is the last path segment; strip a trailing slash
            # so that it is not an empty string.
            paper_id = url.rstrip('/').split('/')[-1]
            search = arxiv.Search(id_list=[paper_id])
            # Default of None: an empty result must not leak StopIteration.
            info = next(search.results(), None)
            if info is None:
                warnings.warn(f'No arXiv entry is found for "{url}", '
                              'please fill the abstract manually.')
            else:
                abstract = info.summary

    readme.append('## Abstract\n')
    if abstract is not None:
        # The summary may span several lines; keep it as one paragraph.
        readme.append(abstract.replace('\n', ' '))

    readme.append('')
    readme.append('<div align=center>\n'
                  '<img src="" width="50%"/>\n'
                  '</div>')
    readme.append('')
|
||||||
|
|
||||||
|
|
||||||
|
def add_models(metafile, readme):
    """Append the "Results and models" tables of the README.

    Models are grouped by the dataset of their first result; models without
    results are grouped under "Pre-trained Models". One markdown table is
    appended per group. Nothing is appended when the metafile has no model.

    Args:
        metafile: The loaded metafile; only ``metafile.models`` is read.
        readme (list): The list of README lines, modified in place.
    """
    models = metafile.models
    if len(models) == 0:
        return

    readme.append('## Results and models')
    readme.append('')

    datasets = defaultdict(list)
    for model in models:
        if model.results is None:
            # No results on pretrained model.
            datasets['Pre-trained Models'].append(model)
        else:
            datasets[model.results[0].dataset].append(model)

    for dataset, dataset_models in datasets.items():
        if dataset == 'Pre-trained Models':
            readme.append(f'### {dataset}\n')
            readme.append(
                'The pre-trained models are only used to fine-tune, '
                "and therefore cannot be trained and don't have evaluation results.\n"
            )
            # The header must have exactly as many columns as the rows
            # written below, otherwise the markdown table is malformed.
            readme.append(
                '| Model | Params(M) | Flops(G) | Config | Download |\n'
                '|:---------------------:|:---------:|:--------:|:------:|:--------:|'
            )
            converted_from = None
            for model in dataset_models:
                name = model.name.center(21)
                params = model.metadata.parameters / 1e6
                flops = model.metadata.flops / 1e9
                # Remember the first "Converted From" entry of the group.
                converted_from = converted_from or model.data.get(
                    'Converted From', None)
                config = './' + Path(model.config).name
                weights = model.weights
                # An escaped asterisk marks weights converted from elsewhere.
                star = '\\*' if '3rdparty' in weights else ''
                readme.append(
                    f'| {name}{star} | {params:.2f} | {flops:.2f} | [config]({config}) | [model]({weights}) |'
                )
            if converted_from is not None:
                readme.append('')
                readme.append(
                    f"*Models with \\* are converted from the [official repo]({converted_from['Code']}).*\n"
                )
        else:
            readme.append(f'### {dataset}\n')
            readme.append(
                '| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |\n'
                '|:---------------------:|:----------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|'
            )
            converted_from = None
            for model in dataset_models:
                name = model.name.center(21)
                params = model.metadata.parameters / 1e6
                flops = model.metadata.flops / 1e9
                metrics = model.results[0].metrics
                top1 = metrics.get('Top 1 Accuracy')
                top5 = metrics.get('Top 5 Accuracy', 0)
                # A missing Top-1 metric is shown as '-' to be filled in
                # manually instead of failing in the number format below.
                top1_str = '-' if top1 is None else f'{top1:.2f}'
                # Remember the first "Converted From" entry of the group.
                converted_from = converted_from or model.data.get(
                    'Converted From', None)
                config = './' + Path(model.config).name
                weights = model.weights
                # An escaped asterisk marks weights converted from elsewhere.
                star = '\\*' if '3rdparty' in weights else ''
                # The pre-training data is guessed from the weights name.
                if 'in21k-pre' in weights:
                    pretrain = 'ImageNet 21k'
                else:
                    pretrain = 'From scratch'
                readme.append(
                    f'| {name}{star} | {pretrain} | {params:.2f} | {flops:.2f} | {top1_str} | {top5:.2f} | [config]({config}) | [model]({weights}) |'
                )
            if converted_from is not None:
                readme.append('')
                readme.append(
                    f"*Models with \\* are converted from the [official repo]({converted_from['Code']}). "
                    'The config files of these models are only for inference. '
                    "We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*\n"
                )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Generate the README text from a metafile and print it to stdout.

    With ``--table`` only the model tables are produced; otherwise the
    title, the abstract and an empty citation block are included as well.
    """
    args = parse_args()
    metafile = load(str(args.metafile))
    with_all_sections = not args.table

    readme_lines = []
    if with_all_sections:
        add_title(metafile, readme_lines)
        add_abstract(metafile, readme_lines)

    add_models(metafile, readme_lines)

    if with_all_sections:
        # The bibtex block is left empty to be filled in manually.
        readme_lines.extend(['## Citation\n', '```bibtex\n\n```\n'])

    print('\n'.join(readme_lines))
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
Loading…
Reference in New Issue