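"""Benchmark regression test for the models listed in model-index.yml.

Generates one test job script per model (for a local run or slurm
submission) and can summarize the collected results against the metafile.
"""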
import argparse
import fnmatch
import logging
import os
import os.path as osp
import pickle
from collections import OrderedDict, defaultdict
from datetime import datetime
from pathlib import Path

from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table
from utils import METRICS_MAP, MMCLS_ROOT, substitute_weights

# Avoid importing MMPretrain here to speed up showing the summary.

console = Console()
logger = logging.getLogger('test')
logger.addHandler(logging.StreamHandler())
logger.addHandler(logging.FileHandler('benchmark_test.log', mode='w'))


def parse_args():
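    """Parse command-line arguments for the benchmark test tool."""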
    parser = argparse.ArgumentParser(
        description="Test all models' accuracy in model-index.yml")
    parser.add_argument('checkpoint_root', help='Checkpoint file root path.')
    parser.add_argument(
        '--local',
        action='store_true',
        help='Run on the local machine instead of slurm.')
    parser.add_argument(
        '--models', nargs='+', type=str, help='Specify model names to run.')
    parser.add_argument(
        '--run', action='store_true', help='Run the script directly.')
    parser.add_argument(
        '--summary',
        action='store_true',
        help='Summarize benchmark test results.')
    parser.add_argument('--save', action='store_true', help='Save the summary')
    parser.add_argument(
        '--gpus', type=int, default=1, help='How many GPUs to use.')
    parser.add_argument(
        '--no-skip',
        action='store_true',
        help='Do not skip models without results recorded in the metafile.')
    parser.add_argument(
        '--work-dir',
        default='work_dirs/benchmark_test',
        help='The directory to save metrics.')
    parser.add_argument('--port', type=int, default=29666, help='dist port')
    parser.add_argument(
        '--partition',
        type=str,
        default='mm_model',
        help='(for slurm) Cluster partition to use.')
    parser.add_argument(
        '--job-name',
        type=str,
        default='cls-test-benchmark',
        help='(for slurm) Slurm job name prefix')
    parser.add_argument(
        '--quotatype',
        default=None,
        choices=['reserved', 'auto', 'spot'],
        help='(for slurm) Quota type, only available for phoenix-slurm>=0.2')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        type=str,
        default=[],
        help='Config options for all config files.')

    args = parser.parse_args()
    return args


def create_test_job_batch(commands, model_info, args, port, script_name):
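    """Create the test job script for one model and queue its launch command.

    The generated ``job.sh`` is written to the model's work directory. The
    function returns the script path, or None if no checkpoint is available.
    """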
    model_name = model_info.name
    config = Path(model_info.config)

    if model_info.weights is not None:
        checkpoint = substitute_weights(model_info.weights,
                                        args.checkpoint_root)
        if checkpoint is None:
            logger.warning(f'{model_name}: {model_info.weights} not found.')
            return None
    else:
        return None

    job_name = f'{args.job_name}_{model_name}'
    work_dir = Path(args.work_dir) / model_name
    work_dir.mkdir(parents=True, exist_ok=True)
    result_file = work_dir / 'result.pkl'

    if args.quotatype is not None:
        quota_cfg = f'#SBATCH --quotatype {args.quotatype}'
    else:
        quota_cfg = ''

    if not args.local:
        launcher = 'slurm'
        runner = 'srun python'
    elif args.gpus > 1:
        launcher = 'pytorch'
        runner = ('torchrun --master_addr="127.0.0.1" '
                  f'--master_port={port} --nproc_per_node={args.gpus}')
    else:
        launcher = 'none'
        runner = 'python -u'

    job_script = (f'#!/bin/bash\n'
                  f'#SBATCH --output {work_dir}/job.%j.out\n'
                  f'#SBATCH --partition={args.partition}\n'
                  f'#SBATCH --job-name {job_name}\n'
                  f'#SBATCH --gres=gpu:{min(8, args.gpus)}\n'
                  f'{quota_cfg}\n'
                  f'#SBATCH --ntasks-per-node={min(8, args.gpus)}\n'
                  f'#SBATCH --ntasks={args.gpus}\n'
                  f'#SBATCH --cpus-per-task=5\n\n'
                  f'{runner} {script_name} {config} {checkpoint} '
                  f'--work-dir={work_dir} --cfg-option '
                  f'env_cfg.dist_cfg.port={port} '
                  f'{" ".join(args.cfg_options)} '
                  f'--out={result_file} --out-item="metrics" '
                  f'--launcher={launcher}\n')

    with open(work_dir / 'job.sh', 'w') as f:
        f.write(job_script)

    commands.append(f'echo "{config}"')
    if args.local:
        commands.append(f'bash {work_dir}/job.sh')
    else:
        commands.append(f'sbatch {work_dir}/job.sh')

    return work_dir / 'job.sh'


def test(models, args):
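    """Generate job scripts for all selected models and preview the commands.

    The jobs are only launched (via ``bash`` or ``sbatch``) when ``--run``
    is specified.
    """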
    script_name = osp.join('tools', 'test.py')
    port = args.port

    commands = []

    preview_script = ''
    for model_info in models.values():

        if model_info.results is None:
            # Skip models without results (e.g. pre-trained only checkpoints).
            continue

        script_path = create_test_job_batch(commands, model_info, args, port,
                                            script_name)
        preview_script = script_path or preview_script
        port += 1

    command_str = '\n'.join(commands)

    preview = Table()
    preview.add_column(str(preview_script))
    preview.add_column('Shell command preview')
    preview.add_row(
        Syntax.from_path(
            preview_script,
            background_color='default',
            line_numbers=True,
            word_wrap=True),
        Syntax(
            command_str,
            'bash',
            background_color='default',
            line_numbers=True,
            word_wrap=True))
    console.print(preview)

    if args.run:
        os.system(command_str)
    else:
        console.print('Please set "--run" to start the job')


def save_summary(summary_data, work_dir):
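    """Save the benchmark summary as a CSV file in the work directory."""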
    summary_path = work_dir / 'test_benchmark_summary.csv'
    columns = defaultdict(list)
    for model_name, summary in summary_data.items():
        if len(summary) == 0:
            # Skip models without results
            continue
        columns['Name'].append(model_name)

        for metric_key in METRICS_MAP:
            if metric_key in summary:
                metric = summary[metric_key]
                expect = round(metric['expect'], 2)
                result = round(metric['result'], 2)
                columns[f'{metric_key} (expect)'].append(str(expect))
                columns[f'{metric_key}'].append(str(result))
            else:
                columns[f'{metric_key} (expect)'].append('')
                columns[f'{metric_key}'].append('')

    # Drop columns that are empty for every model.
    columns = {
        field: column
        for field, column in columns.items() if ''.join(column)
    }

    with open(summary_path, 'w') as file:
        file.write(','.join(columns.keys()) + '\n')
        for row in zip(*columns.values()):
            file.write(','.join(row) + '\n')
    logger.info('Summary file saved at ' + str(summary_path))


def show_summary(summary_data):
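    """Print the benchmark summary as a rich table, coloring results that
    deviate from the expected metrics."""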
    table = Table(title='Test Benchmark Regression Summary')
    table.add_column('Name')
    for metric in METRICS_MAP:
        table.add_column(f'{metric} (expect)')
        table.add_column(f'{metric}')
    table.add_column('Date')

    def set_color(value, expect):
        if value > expect + 0.01:
            return 'green'
        elif value >= expect - 0.01:
            return 'white'
        else:
            return 'red'

    for model_name, summary in summary_data.items():
        row = [model_name]
        for metric_key in METRICS_MAP:
            if metric_key in summary:
                metric = summary[metric_key]
                expect = round(metric['expect'], 2)
                result = round(metric['result'], 2)
                color = set_color(result, expect)
                row.append(f'{expect:.2f}')
                row.append(f'[{color}]{result:.2f}[/{color}]')
            else:
                row.extend([''] * 2)
        if 'date' in summary:
            row.append(summary['date'])
        else:
            row.append('')
        table.add_row(*row)

    # Remove empty columns
    table.columns = [
        column for column in table.columns if ''.join(column._cells)
    ]
    console.print(table)


def summary(models, args):
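    """Collect test results from the work directory and compare them with
    the expected metrics in the metafile."""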
    work_dir = Path(args.work_dir)

    summary_data = {}
    for model_name, model_info in models.items():

        if model_info.results is None and not args.no_skip:
            continue

        # Skip models whose result file is not found.
        result_file = work_dir / model_name / 'result.pkl'
        if not result_file.exists():
            summary_data[model_name] = {}
            continue

        with open(result_file, 'rb') as file:
            results = pickle.load(file)
        date = datetime.fromtimestamp(result_file.lstat().st_mtime)

        expect_metrics = model_info.results[0].metrics

        # extract metrics
        summary = {'date': date.strftime('%Y-%m-%d')}
        for key_yml, key_res in METRICS_MAP.items():
            if key_yml in expect_metrics and key_res in results:
                expect_result = float(expect_metrics[key_yml])
                result = float(results[key_res])
                summary[key_yml] = dict(expect=expect_result, result=result)

        summary_data[model_name] = summary

    show_summary(summary_data)
    if args.save:
        save_summary(summary_data, work_dir)


def main():
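    """Entry point: load model-index.yml, filter the requested models, then
    run the tests or summarize existing results."""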
    args = parse_args()

    # parse model-index.yml
    model_index_file = MMCLS_ROOT / 'model-index.yml'
    model_index = load(str(model_index_file))
    model_index.build_models_with_collections()
    models = OrderedDict({model.name: model for model in model_index.models})

    if args.models:
        filter_models = {}
        for pattern in args.models:
            filter_models.update({
                name: models[name]
                for name in fnmatch.filter(models, pattern + '*')
            })
        if len(filter_models) == 0:
            logger.error('No model found, please specify models in:\n' +
                         '\n'.join(models.keys()))
            return
        models = filter_models

    if args.summary:
        summary(models, args)
    else:
        test(models, args)


if __name__ == '__main__':
    main()