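"""Train benchmark regression tool.

Reads the model list in ``models.yml``, generates one Slurm (or local)
training job per selected model, and can summarize the resulting training
logs against the expected metrics recorded in ``model-index.yml``.

Illustrative invocations (the script name below is a placeholder; adjust the
path and partition name to your cluster):

    python benchmark_train.py my-partition --range quarter --run
    python benchmark_train.py my-partition --summary --save
"""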
import argparse
import json
import os
import os.path as osp
import re
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile

import yaml
from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table

console = Console()
MMSELFSUP_ROOT = Path(__file__).absolute().parents[2]
CYCLE_LEVELS = ['month', 'quarter', 'half-year', 'no-training']
METRICS_MAP = {
    'Top 1 Accuracy': 'accuracy/top1',
    'Top 5 Accuracy': 'accuracy/top5'
}


class RangeAction(argparse.Action):

    def __call__(self, parser, namespace, values: str, option_string):
        matches = re.match(r'([><=]*)([-\w]+)', values)
        if matches is None:
            raise ValueError(f'Unavailable range option {values}')
        symbol, range_str = matches.groups()
        assert range_str in CYCLE_LEVELS, \
            f'{range_str} is not in {CYCLE_LEVELS}.'
        level = CYCLE_LEVELS.index(range_str)
        symbol = symbol or '<='
        ranges = set()
        if '=' in symbol:
            ranges.add(level)
        if '>' in symbol:
            ranges.update(range(level + 1, len(CYCLE_LEVELS)))
        if '<' in symbol:
            ranges.update(range(level))
        assert len(ranges) > 0, 'No range is selected.'
        setattr(namespace, self.dest, ranges)


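# The `--range` option below is parsed by RangeAction: a cycle name with an
# optional comparison prefix, e.g. "quarter" (defaults to "<=", selecting
# month and quarter models) or ">=half-year" (half-year and no-training).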
def parse_args():
    parser = argparse.ArgumentParser(
        description='Train models (in models.yml) and compare accuracy.')
    parser.add_argument(
        'partition', type=str, help='Cluster partition to use.')
    parser.add_argument(
        '--job-name',
        type=str,
        default='selfsup-benchmark',
        help='Slurm job name prefix')
    parser.add_argument('--port', type=int, default=29777, help='dist port')
    parser.add_argument(
        '--models', nargs='+', type=str, help='Specify model names to run.')
    parser.add_argument(
        '--range',
        type=str,
        default={0},
        action=RangeAction,
        metavar='{month,quarter,half-year,no-training}',
        help='The training benchmark range, "no-training" means all models '
        "including those we haven't trained.")
    parser.add_argument(
        '--work-dir',
        default='work_dirs/benchmark_pretrain_cls',
        help='the dir to save train log')
    parser.add_argument(
        '--run', action='store_true', help='run script directly')
    parser.add_argument(
        '--local',
        action='store_true',
        help='run at local instead of cluster.')
    parser.add_argument(
        '--mail', type=str, help='Mail address to watch train status.')
    parser.add_argument(
        '--mail-type',
        nargs='+',
        default=['BEGIN', 'END', 'FAIL'],
        choices=['NONE', 'BEGIN', 'END', 'FAIL', 'REQUEUE', 'ALL'],
        help='Mail types to watch train status.')
    parser.add_argument(
        '--quotatype',
        default=None,
        choices=['reserved', 'auto', 'spot'],
        help='Quota type, only available for phoenix-slurm>=0.2')
    parser.add_argument(
        '--summary',
        action='store_true',
        help='Summarize benchmark train results.')
    parser.add_argument(
        '--save',
        action='store_true',
        help='Save the summary and archive log files.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        type=str,
        default=[],
        help='Config options for all config files.')

    args = parser.parse_args()
    return args


def get_gpu_number(model_info):
    """Parse the GPU number from the config file name.

    For example, a config named like ``*_8xb32-coslr-200e_*.py`` is trained
    with 8 GPUs (batch size 32 per GPU).
    """
    config = osp.basename(model_info.config)
    matches = re.match(r'.*[-_](\d+)xb(\d+).*', config)
    if matches is None:
        raise ValueError(
            f'Cannot get gpu numbers from the config name {config}')
    gpus = int(matches.groups()[0])
    return gpus


def get_pretrain_epoch(model_info):
    """Parse the total pre-training epochs from the config file name.

    For example, a config named like ``*-200e_*.py`` is pre-trained for 200
    epochs, so the final checkpoint is ``epoch_200.pth``.
    """
    config = osp.basename(model_info.config)
    matches = re.match(r'.*[-_](\d+)e[-_].*', config)
    if matches is None:
        raise ValueError(
            f'Cannot get epoch setting from the config name {config}')
    epoch = int(matches.groups()[0])
    return epoch


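# create_train_job_batch() writes a `job.sh` under each model's work dir.
# Roughly (illustrative layout; actual values depend on the config name and
# CLI options), the generated script contains:
#
#   #!/bin/bash
#   srun -p <partition> --job-name <job-name>_<model> --gres=gpu:<n> ... \
#       python -u tools/train.py <config> --work-dir=<work_dir> ... --launcher=slurm
#   mim train mmcls <downstream_config> ...   # only when a downstream task is set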
def create_train_job_batch(commands, model_info, args, port, script_name):

    fname = model_info.name

    gpus = get_gpu_number(model_info)
    gpus_per_node = min(gpus, 8)

    config = Path(model_info.config)
    assert config.exists(), f'"{fname}": {config} not found.'

    work_dir = Path(args.work_dir) / fname
    work_dir.mkdir(parents=True, exist_ok=True)

    if args.quotatype is not None:
        quota_cfg = f'--quotatype {args.quotatype} '
    else:
        quota_cfg = ''

    launcher = 'none' if args.local else 'slurm'
    job_name = f'{args.job_name}_{fname}'
    job_script = (f'#!/bin/bash\n'
                  f'srun -p {args.partition} '
                  f'--job-name {job_name} '
                  f'--gres=gpu:{gpus_per_node} '
                  f'{quota_cfg}'
                  f'--ntasks-per-node={gpus_per_node} '
                  f'--ntasks={gpus} '
                  f'--cpus-per-task=12 '
                  f'--kill-on-bad-exit=1 '
                  f'python -u {script_name} {config} '
                  f'--work-dir={work_dir} '
                  f'--cfg-option env_cfg.dist_cfg.port={port} '
                  f'{" ".join(args.cfg_options)} '
                  f'default_hooks.checkpoint.max_keep_ckpts=1 '
                  f'--launcher={launcher}\n')

    commands.append(f'echo "{config}"')

    # downstream classification task
    cls_config = None
    task = getattr(model_info, 'task', None)
    if task is not None:
        for downstream in model_info.data['Downstream']:
            if task == downstream['Results'][0]['Task']:
                cls_config = downstream['Config']
                break
        else:
            cls_config = None

    if cls_config:
        fname = model_info.name

        gpus = get_gpu_number(model_info)
        gpus_per_node = min(gpus, 8)

        cls_config_path = Path(cls_config)
        assert cls_config_path.exists(), f'"{fname}": {cls_config} not found.'

        job_name = f'{args.job_name}_{fname}'

        cls_work_dir = work_dir / Path(
            cls_config.split('/')[-1].replace('.py', ''))
        cls_work_dir.mkdir(parents=True, exist_ok=True)

        srun_args = ''
        if args.quotatype is not None:
            srun_args = f'--quotatype {args.quotatype}'

        # get pretrain weights path
        epoch = get_pretrain_epoch(model_info)
        ckpt = work_dir / f'epoch_{epoch}.pth'

        launcher = 'none' if args.local else 'slurm'
        cls_job_script = (
            f'\n'
            f'mim train mmcls {cls_config} '
            f'--launcher {launcher} '
            f'-G {gpus} '
            f'--gpus-per-node {gpus_per_node} '
            f'--cpus-per-task 12 '
            f'--partition {args.partition} '
            f'--srun-args "{srun_args}" '
            f'--work-dir {cls_work_dir} '
            f'--cfg-option model.backbone.init_cfg.type=Pretrained '
            f'model.backbone.init_cfg.checkpoint={ckpt} '
            f'model.backbone.init_cfg.prefix=backbone. '
            f'default_hooks.checkpoint.max_keep_ckpts=1 '
            f'{" ".join(args.cfg_options)}\n')

        commands.append(f'echo "{cls_config}"')

    with open(work_dir / 'job.sh', 'w') as f:
        f.write(job_script)
        if cls_config:
            f.write(cls_job_script)

    commands.append(
        f'nohup bash {work_dir}/job.sh > {work_dir}/out.log 2>&1 &')

    return work_dir / 'job.sh'


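# train() builds one job script per model, previews the generated script and
# the collected shell commands with `rich`, and only executes the
# `nohup bash .../job.sh` commands when `--run` is passed.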
def train(models, args):
    script_name = osp.join('tools', 'train.py')
    port = args.port

    commands = []

    for model_info in models.values():
        script_path = create_train_job_batch(commands, model_info, args, port,
                                             script_name)
        port += 1

    command_str = '\n'.join(commands)

    preview = Table()
    preview.add_column(str(script_path))
    preview.add_column('Shell command preview')
    preview.add_row(
        Syntax.from_path(
            script_path,
            background_color='default',
            line_numbers=True,
            word_wrap=True),
        Syntax(
            command_str,
            'bash',
            background_color='default',
            line_numbers=True,
            word_wrap=True))
    console.print(preview)

    if args.run:
        os.system(command_str)
    else:
        console.print('Please set "--run" to start the job')


def save_summary(summary_data, models_map, work_dir):
    date = datetime.now().strftime('%Y%m%d-%H%M%S')
    zip_path = work_dir / f'archive-{date}.zip'
    zip_file = ZipFile(zip_path, 'w')
    summary_path = work_dir / 'benchmark_summary.md'
    file = open(summary_path, 'w')
    headers = [
        'Model', 'Top-1 Expected (%)', 'Top-1 (%)', 'Top-1 best (%)',
        'best epoch', 'Top-5 Expected (%)', 'Top-5 (%)', 'Config', 'Log'
    ]
    file.write('# Train Benchmark Regression Summary\n')
    file.write('| ' + ' | '.join(headers) + ' |\n')
    file.write('|:' + ':|:'.join(['---'] * len(headers)) + ':|\n')
    for model_name, summary in summary_data.items():
        if len(summary) == 0:
            # Skip models without results
            continue
        row = [model_name]
        if 'Top 1 Accuracy' in summary:
            metric = summary['Top 1 Accuracy']
            row.append(f"{metric['expect']:.2f}")
            row.append(f"{metric['last']:.2f}")
            row.append(f"{metric['best']:.2f}")
            row.append(f"{metric['best_epoch']}")
        else:
            row.extend([''] * 4)
        if 'Top 5 Accuracy' in summary:
            metric = summary['Top 5 Accuracy']
            row.append(f"{metric['expect']:.2f}")
            row.append(f"{metric['last']:.2f}")
        else:
            row.extend([''] * 2)

        model_info = models_map[model_name]
        row.append(model_info.config)
        row.append(str(summary['log_file'].relative_to(work_dir)))
        zip_file.write(summary['log_file'])
        file.write('| ' + ' | '.join(row) + ' |\n')
    file.close()
    zip_file.write(summary_path)
    zip_file.close()
    print('Summary file saved at ' + str(summary_path))
    print('Log files archived at ' + str(zip_path))


def show_summary(summary_data):
    table = Table(title='Train Benchmark Regression Summary')
    table.add_column('Model')
    for metric in METRICS_MAP:
        table.add_column(f'{metric} (expect)')
        table.add_column(f'{metric}')
        table.add_column(f'{metric} (best)')

    def set_color(value, expect):
        if value > expect:
            return 'green'
        elif value > expect - 0.2:
            return 'white'
        else:
            return 'red'

    for model_name, summary in summary_data.items():
        row = [model_name]
        for metric_key in METRICS_MAP:
            if metric_key in summary:
                metric = summary[metric_key]
                expect = metric['expect']
                last = metric['last']
                last_epoch = metric['last_epoch']
                last_color = set_color(last, expect)
                best = metric['best']
                best_color = set_color(best, expect)
                best_epoch = metric['best_epoch']
                row.append(f'{expect:.2f}')
                row.append(
                    f'[{last_color}]{last:.2f}[/{last_color}] ({last_epoch})')
                row.append(
                    f'[{best_color}]{best:.2f}[/{best_color}] ({best_epoch})')
        table.add_row(*row)

    console.print(table)


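# summary() reads the latest `*/*/vis_data/scalars.json` log under each
# model's work dir, extracts the validation accuracy entries, and compares
# them with the expected metrics recorded in the model index.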
def summary(models, args):

    work_dir = Path(args.work_dir)
    dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}

    summary_data = {}
    for model_name, model_info in models.items():

        summary_data[model_name] = {}

        if model_name not in dir_map:
            continue

        # Skip if not found any vis_data folder.
        sub_dir = dir_map[model_name]
        log_files = [f for f in sub_dir.glob('*/*/vis_data/scalars.json')]
        if len(log_files) == 0:
            continue
        log_file = sorted(log_files)[-1]

        # parse train log
        with open(log_file) as f:
            json_logs = [json.loads(s) for s in f.readlines()]
            val_logs = [
                log for log in json_logs
                # TODO: need a better method to extract validate log
                if 'loss' not in log and 'accuracy/top1' in log
            ]

        if len(val_logs) == 0:
            continue

        for downstream in model_info.data['Downstream']:
            if model_info.task == downstream['Results'][0]['Task']:
                expect_metrics = downstream['Results'][0]['Metrics']
                break

        # extract metrics
        summary = {'log_file': log_file}
        for key_yml, key_res in METRICS_MAP.items():
            if key_yml in expect_metrics:
                assert key_res in val_logs[-1], \
                    f'{model_name}: No metric "{key_res}"'
                expect_result = float(expect_metrics[key_yml])
                last = float(val_logs[-1][key_res])
                best_log, best_epoch = sorted(
                    zip(val_logs, range(len(val_logs))),
                    key=lambda x: x[0][key_res])[-1]
                best = float(best_log[key_res])

                summary[key_yml] = dict(
                    expect=expect_result,
                    last=last,
                    last_epoch=len(val_logs),
                    best=best,
                    best_epoch=best_epoch + 1)
        summary_data[model_name].update(summary)

    show_summary(summary_data)
    if args.save:
        save_summary(summary_data, models, work_dir)


def main():
    args = parse_args()

    model_index_file = MMSELFSUP_ROOT / 'model-index.yml'
    model_index = load(str(model_index_file))
    model_index.build_models_with_collections()
    all_models = {model.name: model for model in model_index.models}

    with open(Path(__file__).parent / 'models.yml', 'r') as f:
        train_items = yaml.safe_load(f)
    models = OrderedDict()
    for item in train_items:
        name = item['Name']
        model_info = all_models[name]
        model_info.cycle = item.get('Cycle', None)
        model_info.task = item.get('Task', None)
        # Default to the "month" cycle when not specified in models.yml.
        cycle = model_info.cycle or 'month'
        cycle_level = CYCLE_LEVELS.index(cycle)
        if cycle_level in args.range:
            models[name] = model_info

    if args.models:
        patterns = [re.compile(pattern) for pattern in args.models]
        filter_models = {}
        for k, v in models.items():
            if any([re.match(pattern, k) for pattern in patterns]):
                filter_models[k] = v
        if len(filter_models) == 0:
            print('No model found, please specify models in:')
            print('\n'.join(models.keys()))
            return
        models = filter_models

    if args.summary:
        summary(models, args)
    else:
        train(models, args)


if __name__ == '__main__':
    main()