Improve bulk_runner for 'all resolution' mode

This commit is contained in:
Ross Wightman 2023-11-23 12:44:07 -08:00
parent 40d55ab4bc
commit e7abc320f5

View File

@ -113,13 +113,14 @@ def main():
model_cfgs.append((n, None))
elif args.model_list == 'all_res':
model_names = list_models()
model_names += [n.split('.')[0] for n in list_models(pretrained=True)]
model_names += list_models(pretrained=True)
model_cfgs = set()
for n in model_names:
pt_cfg = get_pretrained_cfg(n)
if pt_cfg is None:
print(f'Model {n} is missing pretrained cfg, skipping.')
continue
n = n.split('.')[0]
model_cfgs.add((n, pt_cfg.input_size[-1]))
if pt_cfg.test_input_size is not None:
model_cfgs.add((n, pt_cfg.test_input_size[-1]))