Update bulk_runner with improved filtering options for benchmarking / val runs
parent
dfb8658100
commit
25cf2c2dbb
|
@ -21,7 +21,7 @@ import time
|
|||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
|
||||
from timm.models import is_model, list_models
|
||||
from timm.models import is_model, list_models, get_pretrained_cfg
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Per-model process launcher')
|
||||
|
@ -98,16 +98,32 @@ def main():
|
|||
cmd, cmd_args = cmd_from_args(args)
|
||||
|
||||
model_cfgs = []
|
||||
model_names = []
|
||||
if args.model_list == 'all':
|
||||
# NOTE should make this config, for validation / benchmark runs the focus is 1k models,
|
||||
# so we filter out 21/22k and some other unusable heads. This will change in the future...
|
||||
exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
|
||||
model_names = list_models(
|
||||
pretrained=args.pretrained, # only include models w/ pretrained checkpoints if set
|
||||
exclude_filters=exclude_model_filters
|
||||
)
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
elif args.model_list == 'all_in1k':
|
||||
model_names = list_models(pretrained=True)
|
||||
model_cfgs = []
|
||||
for n in model_names:
|
||||
pt_cfg = get_pretrained_cfg(n)
|
||||
if getattr(pt_cfg, 'num_classes', 0) == 1000:
|
||||
print(n, pt_cfg.num_classes)
|
||||
model_cfgs.append((n, None))
|
||||
elif args.model_list == 'all_res':
|
||||
model_names = list_models()
|
||||
model_names += [n.split('.')[0] for n in list_models(pretrained=True)]
|
||||
model_cfgs = set()
|
||||
for n in model_names:
|
||||
pt_cfg = get_pretrained_cfg(n)
|
||||
if pt_cfg is None:
|
||||
print(f'Model {n} is missing pretrained cfg, skipping.')
|
||||
continue
|
||||
model_cfgs.add((n, pt_cfg.input_size[-1]))
|
||||
if pt_cfg.test_input_size is not None:
|
||||
model_cfgs.add((n, pt_cfg.test_input_size[-1]))
|
||||
model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
|
||||
elif not is_model(args.model_list):
|
||||
# model name doesn't exist, try as wildcard filter
|
||||
model_names = list_models(args.model_list)
|
||||
|
@ -122,7 +138,8 @@ def main():
|
|||
results_file = args.results_file or './results.csv'
|
||||
results = []
|
||||
errors = []
|
||||
print('Running script on these models: {}'.format(', '.join(model_names)))
|
||||
model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
|
||||
print(f"Running script on these models:\n {model_strings}")
|
||||
if not args.sort_key:
|
||||
if 'benchmark' in args.script:
|
||||
if any(['train' in a for a in args.script_args]):
|
||||
|
@ -136,10 +153,14 @@ def main():
|
|||
print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
|
||||
|
||||
try:
|
||||
for m, _ in model_cfgs:
|
||||
for m, ax in model_cfgs:
|
||||
if not m:
|
||||
continue
|
||||
args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
|
||||
if ax is not None:
|
||||
extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
|
||||
extra_args = [i for t in extra_args for i in t]
|
||||
args_str += tuple(extra_args)
|
||||
try:
|
||||
o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
|
||||
r = json.loads(o)
|
||||
|
@ -157,7 +178,11 @@ def main():
|
|||
if errors:
|
||||
print(f'{len(errors)} models had errors during run.')
|
||||
for e in errors:
|
||||
print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
|
||||
if 'model' in e:
|
||||
print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
|
||||
else:
|
||||
print(e)
|
||||
|
||||
results = list(filter(lambda x: 'error' not in x, results))
|
||||
|
||||
no_sortkey = list(filter(lambda x: sort_key not in x, results))
|
||||
|
|
Loading…
Reference in New Issue