import argparse
import json
import os
import os.path as osp
import re
from datetime import datetime
from pathlib import Path
from zipfile import ZipFile

from mmcv import Config
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table

console = Console()

# Map from the metric names used in ``bench_train.yml`` to the keys looked up
# in the validation entries of the ``*.log.json`` train logs.
METRICS_MAP = {
    'Top 1 Accuracy': 'accuracy_top-1',
    'Top 5 Accuracy': 'accuracy_top-5'
}


def parse_args():
    """Parse command line arguments for the train benchmark script."""
    parser = argparse.ArgumentParser(
        description='Train models (in bench_train.yml) and compare accuracy.')
    parser.add_argument(
        'partition', type=str, help='Cluster partition to use.')
    parser.add_argument(
        '--job-name',
        type=str,
        default='cls-train-benchmark',
        help='Slurm job name prefix')
    parser.add_argument('--port', type=int, default=29666, help='dist port')
    parser.add_argument(
        '--models', nargs='+', type=str, help='Specify model names to run.')
    parser.add_argument(
        '--work-dir',
        default='work_dirs/benchmark_train',
        help='the dir to save train log')
    parser.add_argument(
        '--run', action='store_true', help='run script directly')
    parser.add_argument(
        '--local',
        action='store_true',
        help='run at local instead of cluster.')
    parser.add_argument(
        '--mail', type=str, help='Mail address to watch train status.')
    parser.add_argument(
        '--mail-type',
        nargs='+',
        default=['BEGIN', 'END', 'FAIL'],
        choices=['NONE', 'BEGIN', 'END', 'FAIL', 'REQUEUE', 'ALL'],
        help='Mail types to watch train status.')
    parser.add_argument(
        '--quotatype',
        default=None,
        choices=['reserved', 'auto', 'spot'],
        help='Quota type, only available for phoenix-slurm>=0.2')
    parser.add_argument(
        '--summary',
        action='store_true',
        help='Summarize benchmark train results.')
    parser.add_argument(
        '--save',
        action='store_true',
        help='Save the summary and archive log files.')
    return parser.parse_args()


def _load_models():
    """Load ``bench_train.yml`` next to this script, indexed by model name."""
    models_cfg = Config.fromfile(Path(__file__).parent / 'bench_train.yml')
    return {model.Name: model for model in models_cfg.Models}


def _select_models(models, patterns):
    """Return the models whose names match any regex in ``patterns``.

    Matching uses ``re.match`` semantics (anchored at the start of the name).
    With no patterns, all models are returned unchanged.

    Returns:
        dict | None: The selected models, or ``None`` when patterns were given
        but nothing matched (the available names are printed in that case).
    """
    if not patterns:
        return models
    regexes = [re.compile(pattern) for pattern in patterns]
    selected = {
        name: info
        for name, info in models.items()
        if any(regex.match(name) for regex in regexes)
    }
    if not selected:
        print('No model found, please specify models in:')
        print('\n'.join(models.keys()))
        return None
    return selected


def create_train_job_batch(commands, model_info, args, port, script_name):
    """Write a Slurm batch script for one model and queue its launch command.

    Args:
        commands (list[str]): Shell commands to run; two lines are appended
            in place (an ``echo`` of the config and the launch command).
        model_info: Model entry from ``bench_train.yml``; ``Name``, ``Gpus``
            and ``Config`` are read.
        args: Parsed command line arguments.
        port (int): Distributed port passed to the train script.
        script_name (str): Path of the train script to run.

    Returns:
        Path: Path of the generated ``job.sh``.

    Raises:
        ValueError: If the model entry has no ``Gpus`` field.
        FileNotFoundError: If the model's config file does not exist.
    """
    fname = model_info.Name
    if 'Gpus' not in model_info:
        raise ValueError(f"Haven't specify gpu numbers for {fname}")
    gpus = model_info.Gpus

    config = Path(model_info.Config)
    if not config.exists():
        raise FileNotFoundError(f'"{fname}": {config} not found.')

    job_name = f'{args.job_name}_{fname}'
    work_dir = Path(args.work_dir) / fname
    work_dir.mkdir(parents=True, exist_ok=True)

    if args.mail is not None and 'NONE' not in args.mail_type:
        # sbatch expects the mail address via --mail-user and the event types
        # as a comma-separated list (not a Python list repr).
        mail_types = ','.join(args.mail_type)
        mail_cfg = (f'#SBATCH --mail-user {args.mail}\n'
                    f'#SBATCH --mail-type {mail_types}\n')
    else:
        mail_cfg = ''

    if args.quotatype is not None:
        quota_cfg = f'#SBATCH --quotatype {args.quotatype}\n'
    else:
        quota_cfg = ''

    launcher = 'none' if args.local else 'slurm'
    runner = 'python' if args.local else 'srun python'

    job_script = (f'#!/bin/bash\n'
                  f'#SBATCH --output {work_dir}/job.%j.out\n'
                  f'#SBATCH --partition={args.partition}\n'
                  f'#SBATCH --job-name {job_name}\n'
                  f'#SBATCH --gres=gpu:8\n'
                  f'{mail_cfg}{quota_cfg}'
                  f'#SBATCH --ntasks-per-node=8\n'
                  f'#SBATCH --ntasks={gpus}\n'
                  f'#SBATCH --cpus-per-task=5\n\n'
                  f'{runner} -u {script_name} {config} '
                  f'--work-dir={work_dir} --cfg-option '
                  f'dist_params.port={port} '
                  f'checkpoint_config.max_keep_ckpts=10 '
                  f'--launcher={launcher}\n')

    job_path = work_dir / 'job.sh'
    with open(job_path, 'w') as f:
        f.write(job_script)

    commands.append(f'echo "{config}"')
    if args.local:
        commands.append(f'bash {work_dir}/job.sh')
    else:
        commands.append(f'sbatch {work_dir}/job.sh')

    return job_path


def train(args):
    """Generate train jobs for the selected models and optionally launch them.

    Models may restrict the months they are trained in via a ``Months``
    field; by default every month is allowed. Each job gets its own port,
    starting from ``args.port``.
    """
    models = _select_models(_load_models(), args.models)
    if models is None:
        return

    script_name = osp.join('tools', 'train.py')
    port = args.port
    commands = []
    script_path = None
    current_month = datetime.now().month

    for model_info in models.values():
        months = model_info.get('Months', range(1, 13))
        if current_month in months:
            script_path = create_train_job_batch(commands, model_info, args,
                                                 port, script_name)
            port += 1

    # Without this guard the preview below would hit an unbound
    # ``script_path`` when no model is scheduled for this month.
    if script_path is None:
        console.print('No model is scheduled to train this month.')
        return

    command_str = '\n'.join(commands)

    # Show the last generated job script next to the full command list.
    preview = Table()
    preview.add_column(str(script_path))
    preview.add_column('Shell command preview')
    preview.add_row(
        Syntax.from_path(
            script_path,
            background_color='default',
            line_numbers=True,
            word_wrap=True),
        Syntax(
            command_str,
            'bash',
            background_color='default',
            line_numbers=True,
            word_wrap=True))
    console.print(preview)

    if args.run:
        os.system(command_str)
    else:
        console.print('Please set "--run" to start the job')


def save_summary(summary_data, models_map, work_dir):
    """Write a markdown summary and archive it with the parsed log files.

    Args:
        summary_data (dict): Per-model summaries as built by ``summary``;
            empty entries (no results) are skipped.
        models_map (dict): Model entries from ``bench_train.yml`` by name.
        work_dir (Path): Benchmark work dir; the summary file and the zip
            archive are written here, and log paths are stored relative to it.
    """
    date = datetime.now().strftime('%Y%m%d-%H%M%S')
    zip_path = work_dir / f'archive-{date}.zip'
    summary_path = work_dir / 'benchmark_summary.md'

    headers = [
        'Model', 'Top-1 Expected(%)', 'Top-1 (%)', 'Top-1 best(%)',
        'best epoch', 'Top-5 Expected (%)', 'Top-5 (%)', 'Config', 'Log'
    ]

    with ZipFile(zip_path, 'w') as zip_file:
        with open(summary_path, 'w') as file:
            file.write('# Train Benchmark Regression Summary\n')
            file.write('| ' + ' | '.join(headers) + ' |\n')
            file.write('|:' + ':|:'.join(['---'] * len(headers)) + ':|\n')

            for model_name, model_summary in summary_data.items():
                if len(model_summary) == 0:
                    # Skip models without results
                    continue
                row = [model_name]
                if 'Top 1 Accuracy' in model_summary:
                    metric = model_summary['Top 1 Accuracy']
                    row.append(f"{metric['expect']:.2f}")
                    row.append(f"{metric['last']:.2f}")
                    row.append(f"{metric['best']:.2f}")
                    # The epoch is an integer; don't render it as "12.00".
                    row.append(str(metric['best_epoch']))
                else:
                    row.extend([''] * 4)
                if 'Top 5 Accuracy' in model_summary:
                    metric = model_summary['Top 5 Accuracy']
                    row.append(f"{metric['expect']:.2f}")
                    row.append(f"{metric['last']:.2f}")
                else:
                    row.extend([''] * 2)

                model_info = models_map[model_name]
                log_file = model_summary['log_file']
                log_rel = log_file.relative_to(work_dir)
                row.append(model_info.Config)
                row.append(str(log_rel))
                # Store logs at the same relative path listed in the Log
                # column so the archived summary links resolve in the zip.
                zip_file.write(log_file, arcname=str(log_rel))
                file.write('| ' + ' | '.join(row) + ' |\n')

        # The summary file is closed (flushed) before it is archived.
        zip_file.write(summary_path, arcname=summary_path.name)

    print('Summary file saved at ' + str(summary_path))
    print('Log files archived at ' + str(zip_path))


def show_summary(summary_data):
    """Print a colored table comparing last/best metrics with expectations.

    Values are green when above the expectation, white when within 0.2 below
    it, and red otherwise.
    """
    table = Table(title='Train Benchmark Regression Summary')
    table.add_column('Model')
    for metric in METRICS_MAP:
        table.add_column(f'{metric} (expect)')
        table.add_column(f'{metric}')
        table.add_column(f'{metric} (best)')

    def set_color(value, expect):
        if value > expect:
            return 'green'
        elif value > expect - 0.2:
            return 'white'
        else:
            return 'red'

    for model_name, model_summary in summary_data.items():
        row = [model_name]
        for metric_key in METRICS_MAP:
            if metric_key not in model_summary:
                # Pad missing metrics so later metrics stay in their columns.
                row.extend([''] * 3)
                continue
            metric = model_summary[metric_key]
            expect = metric['expect']
            last = metric['last']
            last_color = set_color(last, expect)
            best = metric['best']
            best_color = set_color(best, expect)
            best_epoch = metric['best_epoch']
            row.append(f'{expect:.2f}')
            row.append(f'[{last_color}]{last:.2f}[/{last_color}]')
            row.append(
                f'[{best_color}]{best:.2f}[/{best_color}] ({best_epoch})')
        table.add_row(*row)

    console.print(table)


def _summarize_model(model_name, model_info, sub_dir):
    """Compare the latest train log of one model with its expected metrics.

    Args:
        model_name (str): Model name, used in error messages.
        model_info: Model entry from ``bench_train.yml``; the metrics of
            ``Results[0]`` are used as expectations.
        sub_dir (Path | None): The model's work dir, or ``None`` if absent.

    Returns:
        dict: ``{}`` when no usable log exists; otherwise ``log_file`` plus,
        for every expected metric, a dict with ``expect``, ``last``, ``best``
        and ``best_epoch``.

    Raises:
        ValueError: If an expected metric is missing from the last val log.
    """
    if sub_dir is None:
        return {}
    log_files = sorted(sub_dir.glob('*.log.json'))
    if not log_files:
        return {}
    # The last file by name is taken as the latest run — assumes log file
    # names sort chronologically (e.g. timestamped).
    log_file = log_files[-1]

    with open(log_file) as f:
        json_logs = [json.loads(line) for line in f if line.strip()]
    val_logs = [log for log in json_logs if log.get('mode') == 'val']
    if not val_logs:
        return {}

    expect_metrics = model_info.Results[0].Metrics
    result = {'log_file': log_file}
    for key_yml, key_res in METRICS_MAP.items():
        if key_yml not in expect_metrics:
            continue
        if key_res not in val_logs[-1]:
            raise ValueError(f'{model_name}: No metric "{key_res}"')
        expect_result = float(expect_metrics[key_yml])
        last = float(val_logs[-1][key_res])
        logs_with_metric = [log for log in val_logs if key_res in log]
        # Iterate in reverse so that, on ties, the latest epoch wins.
        best_log = max(reversed(logs_with_metric), key=lambda x: x[key_res])
        result[key_yml] = dict(
            expect=expect_result,
            last=last,
            best=float(best_log[key_res]),
            best_epoch=int(best_log['epoch']))
    return result


def summary(args):
    """Summarize train results of the selected models and optionally save."""
    models = _select_models(_load_models(), args.models)
    if models is None:
        return

    work_dir = Path(args.work_dir)
    if not work_dir.is_dir():
        print(f'Work dir {work_dir} not found.')
        return
    dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}

    # Models without usable logs get an empty entry so they still appear in
    # the printed table (save_summary skips them).
    summary_data = {
        model_name: _summarize_model(model_name, model_info,
                                     dir_map.get(model_name))
        for model_name, model_info in models.items()
    }

    show_summary(summary_data)
    if args.save:
        save_summary(summary_data, models, work_dir)


def main():
    """Entry point: summarize results with ``--summary``, otherwise train."""
    args = parse_args()
    if args.summary:
        summary(args)
    else:
        train(args)


if __name__ == '__main__':
    main()