Mirror of https://github.com/huggingface/pytorch-image-models.git
Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...
Parent: 4547920f85
Commit: 0dbd9352ce

benchmark.py (33 changed lines)

@@ -21,7 +21,7 @@ import torch.nn.parallel
 from timm.data import resolve_data_config
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

 has_apex = False
 try:
@@ -506,34 +506,31 @@ class ProfileRunner(BenchmarkRunner):
         return results


-def decay_batch_exp(batch_size, factor=0.5, divisor=16):
-    out_batch_size = batch_size * factor
-    if out_batch_size > divisor:
-        out_batch_size = (out_batch_size + 1) // divisor * divisor
-    else:
-        out_batch_size = batch_size - 1
-    return max(0, int(out_batch_size))
-
-
-def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
+def _try_run(
+        model_name,
+        bench_fn,
+        bench_kwargs,
+        initial_batch_size,
+        no_batch_size_retry=False
+):
     batch_size = initial_batch_size
     results = dict()
     error_str = 'Unknown'
-    while batch_size >= 1:
-        torch.cuda.empty_cache()
+    while batch_size:
         try:
+            torch.cuda.empty_cache()
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
             error_str = str(e)
-            if 'channels_last' in error_str:
-                _logger.error(f'{model_name} not supported in channels_last, skipping.')
-                break
             _logger.error(f'"{error_str}" while running benchmark.')
+            if not check_batch_size_retry(error_str):
+                _logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.')
+                break
             if no_batch_size_retry:
                 break
-            batch_size = decay_batch_exp(batch_size)
+            batch_size = decay_batch_step(batch_size)
             _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     return results
@@ -586,6 +583,8 @@ def benchmark(args):
         if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
+        if 'error' in run_results:
+            break
     if 'error' not in model_results:
         param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
         model_results.setdefault('param_count', param_count)
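
Aside (not part of the diff): the retry flow that the reworked _try_run adopts can be condensed into a short sketch. Here run_once is a hypothetical callable standing in for the benchmark call; decay_batch_step and check_batch_size_retry are the helpers this commit adds to timm.utils.

from timm.utils import decay_batch_step, check_batch_size_retry

def run_with_retry(run_once, initial_batch_size):
    """Run `run_once(batch_size)`, stepping the batch size down on recoverable failures."""
    batch_size = initial_batch_size
    error_str = 'Unknown'
    while batch_size:  # decay_batch_step() returns 0 once batch_size <= 1, ending the loop
        try:
            return run_once(batch_size)
        except RuntimeError as e:
            error_str = str(e)
            if not check_batch_size_retry(error_str):
                break  # e.g. illegal memory access, not worth retrying at a smaller batch
            batch_size = decay_batch_step(batch_size)
    return {'error': error_str}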

bulk_runner.py (new executable file, 184 lines)

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
""" Bulk Model Script Runner

Run validation or benchmark script in separate process for each model

Benchmark all 'vit*' models:
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512

Validate all models:
python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import sys
import csv
import json
import subprocess
import time
from typing import Callable, List, Tuple, Union


from timm.models import is_model, list_models


parser = argparse.ArgumentParser(description='Per-model process launcher')

# model and results args
parser.add_argument(
    '--model-list', metavar='NAME', default='',
    help='txt file based list of model names to benchmark')
parser.add_argument(
    '--results-file', default='', type=str, metavar='FILENAME',
    help='Output csv file for validation results (summary)')
parser.add_argument(
    '--sort-key', default='', type=str, metavar='COL',
    help='Specify sort key for results csv')
parser.add_argument(
    "--pretrained", action='store_true',
    help="only run models with pretrained weights")

parser.add_argument(
    "--delay",
    type=float,
    default=0,
    help="Interval, in seconds, to delay between model invocations.",
)
parser.add_argument(
    "--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
    help="Multiprocessing start method to use when creating workers.",
)
parser.add_argument(
    "--no_python",
    help="Skip prepending the script with 'python' - just execute it directly. Useful "
         "when the script is not a Python script.",
)
parser.add_argument(
    "-m",
    "--module",
    help="Change each process to interpret the launch script as a Python module, executing "
         "with the same behavior as 'python -m'.",
)

# positional
parser.add_argument(
    "script", type=str,
    help="Full path to the program/script to be launched for each model config.",
)
parser.add_argument("script_args", nargs=argparse.REMAINDER)


def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
    # If ``args`` not passed, defaults to ``sys.argv[:1]``
    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    if with_python:
        cmd = os.getenv("PYTHON_EXEC", sys.executable)
        cmd_args.append("-u")
        if args.module:
            cmd_args.append("-m")
        cmd_args.append(args.script)
    else:
        if args.module:
            raise ValueError(
                "Don't use both the '--no_python' flag"
                " and the '--module' flag at the same time."
            )
        cmd = args.script
    cmd_args.extend(args.script_args)

    return cmd, cmd_args


def main():
    args = parser.parse_args()
    cmd, cmd_args = cmd_from_args(args)

    model_cfgs = []
    model_names = []
    if args.model_list == 'all':
        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
        # so we filter out 21/22k and some other unusable heads. This will change in the future...
        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
        model_names = list_models(
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
            exclude_filters=exclude_model_filters
        )
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model_list):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model_list)
        model_cfgs = [(n, None) for n in model_names]

    if not model_cfgs and os.path.exists(args.model_list):
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './results.csv'
        results = []
        errors = []
        print('Running script on these models: {}'.format(', '.join(model_names)))
        if not args.sort_key:
            if 'benchmark' in args.script:
                if any(['train' in a for a in args.script_args]):
                    sort_key = 'train_samples_per_sec'
                else:
                    sort_key = 'infer_samples_per_sec'
            else:
                sort_key = 'top1'
        else:
            sort_key = args.sort_key
        print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')

        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
                try:
                    o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                    r = json.loads(o)
                    results.append(r)
                except Exception as e:
                    # FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
                    # for further robustness (but more overhead), we may want to manage that by looping here...
                    errors.append(dict(model=m, error=str(e)))
                if args.delay:
                    time.sleep(args.delay)
        except KeyboardInterrupt as e:
            pass

        errors.extend(list(filter(lambda x: 'error' in x, results)))
        if errors:
            print(f'{len(errors)} models had errors during run.')
            for e in errors:
                print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
        results = list(filter(lambda x: 'error' not in x, results))

        no_sortkey = list(filter(lambda x: sort_key not in x, results))
        if no_sortkey:
            print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
        else:
            results = sorted(results, key=lambda x: x[sort_key], reverse=True)

        if len(results):
            print(f'{len(results)} models run successfully. Saving results to {results_file}.')
            write_results(results_file, results)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
            cf.flush()


if __name__ == '__main__':
    main()
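
Aside (not part of the diff): a minimal sketch of the contract between bulk_runner.py and the scripts it launches, assuming the launched script ends its stdout with the '--result' delimiter followed by a JSON dict, as validate.py does in this commit. The model name and flags are illustrative, not prescribed by the commit.

import json
import subprocess

# Single-model invocation mirroring what bulk_runner.py runs per model.
cmd = ['python', '-u', 'benchmark.py', '--amp', '-b', '512', '--model', 'vit_base_patch16_224']
out = subprocess.check_output(cmd).decode('utf-8')

# Everything after the last '--result' delimiter parses as the result record.
result = json.loads(out.split('--result')[-1])
print(result.get('infer_samples_per_sec'))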

timm/utils/__init__.py

@@ -2,6 +2,7 @@ from .agc import adaptive_clip_grad
 from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
+from .decay_batch import decay_batch_step, check_batch_size_retry
 from .distributed import distribute_bn, reduce_tensor
 from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo

timm/utils/decay_batch.py (new file, 43 lines)

@@ -0,0 +1,43 @@
""" Batch size decay and retry helpers.

Copyright 2022 Ross Wightman
"""
import math


def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
    """ power of two batch-size decay with intra steps

    Decay by stepping between powers of 2:
    * determine power-of-2 floor of current batch size (base batch size)
    * divide above value by num_intra_steps to determine step size
    * floor batch_size to nearest multiple of step_size (from base batch size)
    Examples:
     num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
     num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
     num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
     num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
    """
    if batch_size <= 1:
        # return 0 for stopping value so easy to use in loop
        return 0
    base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
    step_size = max(base_batch_size // num_intra_steps, 1)
    batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
    if no_odd and batch_size % 2:
        batch_size -= 1
    return batch_size


def check_batch_size_retry(error_str):
    """ check failure error string for conditions where batch decay retry should not be attempted
    """
    error_str = error_str.lower()
    if 'required rank' in error_str:
        # Errors involving phrase 'required rank' typically happen when a conv is used that's
        # not compatible with channels_last memory format.
        return False
    if 'illegal' in error_str:
        # 'Illegal memory access' errors in CUDA typically leave process in unusable state
        return False
    return True
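
Aside (not part of the diff): a small usage sketch that walks the decay schedule down from a batch size of 64 with the default num_intra_steps=2, then probes check_batch_size_retry on two illustrative error strings.

from timm.utils import decay_batch_step, check_batch_size_retry

# Walk the schedule down from 64 with the default num_intra_steps=2.
bs, schedule = 64, []
while bs:
    schedule.append(bs)
    bs = decay_batch_step(bs)
print(schedule)  # [64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1], matching the docstring example

# Illustrative error strings: retry on a plain OOM, give up on an illegal memory access.
print(check_batch_size_retry('CUDA out of memory. Tried to allocate 2.00 GiB'))        # True
print(check_batch_size_retry('CUDA error: an illegal memory access was encountered'))  # False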

validate.py (23 changed lines)

@@ -22,7 +22,8 @@ from contextlib import suppress

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
+    decay_batch_step, check_batch_size_retry

 has_apex = False
 try:
@@ -122,6 +123,8 @@ parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                     help='Real labels JSON file for imagenet evaluation')
 parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                     help='Valid label indices txt file for validation of partial label space')
+parser.add_argument('--retry', default=False, action='store_true',
+                    help='Enable batch size decay & retry for single model validation')


 def validate(args):
@@ -303,18 +306,19 @@ def _try_run(args, initial_batch_size):
     batch_size = initial_batch_size
     results = OrderedDict()
     error_str = 'Unknown'
-    while batch_size >= 1:
-        args.batch_size = batch_size
-        torch.cuda.empty_cache()
+    while batch_size:
+        args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
+            torch.cuda.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
             error_str = str(e)
-            if 'channels_last' in error_str:
+            _logger.error(f'"{error_str}" while running validation.')
+            if not check_batch_size_retry(error_str):
                 break
-            _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
-            batch_size = batch_size // 2
+            batch_size = decay_batch_step(batch_size)
+            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     _logger.error(f'{args.model} failed to validate ({error_str}).')
     return results
@@ -368,7 +372,10 @@ def main():
         if len(results):
             write_results(results_file, results)
     else:
-        results = validate(args)
+        if args.retry:
+            results = _try_run(args, args.batch_size)
+        else:
+            results = validate(args)

     # output results in JSON to stdout w/ delimiter for runner script
     print(f'--result\n{json.dumps(results, indent=4)}')
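
Aside (not part of the diff): the new --retry flag brings the same decay-and-retry behavior to single-model validation. An illustrative invocation, with dataset path and model name as placeholders, mirroring the command examples in the bulk_runner.py docstring:

python validate.py /imagenet/validation/ --model resnet50 --amp -b 512 --retry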