mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update model img-size/crop expansion for bulk runner
This commit is contained in:
parent
4515a435e4
commit
907555fda6
@ -93,6 +93,30 @@ def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
|
||||
return cmd, cmd_args
|
||||
|
||||
|
||||
def _get_model_cfgs(
|
||||
model_names,
|
||||
num_classes=None,
|
||||
expand_train_test=False,
|
||||
include_crop=True,
|
||||
):
|
||||
model_cfgs = []
|
||||
for n in model_names:
|
||||
pt_cfg = get_pretrained_cfg(n)
|
||||
if num_classes is not None and getattr(pt_cfg, 'num_classes', 0) != num_classes:
|
||||
continue
|
||||
model_cfgs.append((n, pt_cfg.input_size[-1], pt_cfg.crop_pct))
|
||||
if expand_train_test and pt_cfg.test_input_size is not None:
|
||||
if pt_cfg.test_crop_pct is not None:
|
||||
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.test_crop_pct))
|
||||
else:
|
||||
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.crop_pct))
|
||||
if include_crop:
|
||||
model_cfgs = [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)]
|
||||
else:
|
||||
model_cfgs = [(n, {'img-size': r}) for n, r, cp in sorted(model_cfgs)]
|
||||
return model_cfgs
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
cmd, cmd_args = cmd_from_args(args)
|
||||
@ -105,26 +129,10 @@ def main():
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
elif args.model_list == 'all_in1k':
|
||||
model_names = list_models(pretrained=True)
|
||||
model_cfgs = []
|
||||
for n in model_names:
|
||||
pt_cfg = get_pretrained_cfg(n)
|
||||
if getattr(pt_cfg, 'num_classes', 0) == 1000:
|
||||
print(n, pt_cfg.num_classes)
|
||||
model_cfgs.append((n, None))
|
||||
model_cfgs = _get_model_cfgs(model_names, num_classes=1000, expand_train_test=True)
|
||||
elif args.model_list == 'all_res':
|
||||
model_names = list_models()
|
||||
model_names += list_models(pretrained=True)
|
||||
model_cfgs = set()
|
||||
for n in model_names:
|
||||
pt_cfg = get_pretrained_cfg(n)
|
||||
if pt_cfg is None:
|
||||
print(f'Model {n} is missing pretrained cfg, skipping.')
|
||||
continue
|
||||
n = n.split('.')[0]
|
||||
model_cfgs.add((n, pt_cfg.input_size[-1]))
|
||||
if pt_cfg.test_input_size is not None:
|
||||
model_cfgs.add((n, pt_cfg.test_input_size[-1]))
|
||||
model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
|
||||
model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False)
|
||||
elif not is_model(args.model_list):
|
||||
# model name doesn't exist, try as wildcard filter
|
||||
model_names = list_models(args.model_list)
|
||||
@ -132,8 +140,14 @@ def main():
|
||||
|
||||
if not model_cfgs and os.path.exists(args.model_list):
|
||||
with open(args.model_list) as f:
|
||||
model_cfgs = []
|
||||
model_names = [line.rstrip() for line in f]
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
_get_model_cfgs(
|
||||
model_names,
|
||||
#num_classes=1000,
|
||||
expand_train_test=True,
|
||||
#include_crop=False,
|
||||
)
|
||||
|
||||
if len(model_cfgs):
|
||||
results_file = args.results_file or './results.csv'
|
||||
|
Loading…
x
Reference in New Issue
Block a user