mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2274 from huggingface/bulk_runner_tweaks
Better all res resolution for bulk runner
This commit is contained in:
commit
f81cbdcca9
@ -21,7 +21,7 @@ import time
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
|
||||
from timm.models import is_model, list_models, get_pretrained_cfg
|
||||
from timm.models import is_model, list_models, get_pretrained_cfg, get_arch_pretrained_cfgs
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Per-model process launcher')
|
||||
@ -98,23 +98,44 @@ def _get_model_cfgs(
|
||||
num_classes=None,
|
||||
expand_train_test=False,
|
||||
include_crop=True,
|
||||
expand_arch=False,
|
||||
):
|
||||
model_cfgs = []
|
||||
for n in model_names:
|
||||
pt_cfg = get_pretrained_cfg(n)
|
||||
if num_classes is not None and getattr(pt_cfg, 'num_classes', 0) != num_classes:
|
||||
continue
|
||||
model_cfgs.append((n, pt_cfg.input_size[-1], pt_cfg.crop_pct))
|
||||
if expand_train_test and pt_cfg.test_input_size is not None:
|
||||
if pt_cfg.test_crop_pct is not None:
|
||||
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.test_crop_pct))
|
||||
model_cfgs = set()
|
||||
|
||||
for name in model_names:
|
||||
if expand_arch:
|
||||
pt_cfgs = get_arch_pretrained_cfgs(name).values()
|
||||
else:
|
||||
pt_cfg = get_pretrained_cfg(name)
|
||||
pt_cfgs = [pt_cfg] if pt_cfg is not None else []
|
||||
|
||||
for cfg in pt_cfgs:
|
||||
if cfg.input_size is None:
|
||||
continue
|
||||
if num_classes is not None and getattr(cfg, 'num_classes', 0) != num_classes:
|
||||
continue
|
||||
|
||||
# Add main configuration
|
||||
size = cfg.input_size[-1]
|
||||
if include_crop:
|
||||
model_cfgs.add((name, size, cfg.crop_pct))
|
||||
else:
|
||||
model_cfgs.append((n, pt_cfg.test_input_size[-1], pt_cfg.crop_pct))
|
||||
model_cfgs.add((name, size))
|
||||
|
||||
# Add test configuration if required
|
||||
if expand_train_test and cfg.test_input_size is not None:
|
||||
test_size = cfg.test_input_size[-1]
|
||||
if include_crop:
|
||||
test_crop = cfg.test_crop_pct or cfg.crop_pct
|
||||
model_cfgs.add((name, test_size, test_crop))
|
||||
else:
|
||||
model_cfgs.add((name, test_size))
|
||||
|
||||
# Format the output
|
||||
if include_crop:
|
||||
model_cfgs = [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)]
|
||||
return [(n, {'img-size': r, 'crop-pct': cp}) for n, r, cp in sorted(model_cfgs)]
|
||||
else:
|
||||
model_cfgs = [(n, {'img-size': r}) for n, r, cp in sorted(model_cfgs)]
|
||||
return model_cfgs
|
||||
return [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
|
||||
|
||||
|
||||
def main():
|
||||
@ -132,7 +153,7 @@ def main():
|
||||
model_cfgs = _get_model_cfgs(model_names, num_classes=1000, expand_train_test=True)
|
||||
elif args.model_list == 'all_res':
|
||||
model_names = list_models()
|
||||
model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False)
|
||||
model_cfgs = _get_model_cfgs(model_names, expand_train_test=True, include_crop=False, expand_arch=True)
|
||||
elif not is_model(args.model_list):
|
||||
# model name doesn't exist, try as wildcard filter
|
||||
model_names = list_models(args.model_list)
|
||||
@ -140,9 +161,8 @@ def main():
|
||||
|
||||
if not model_cfgs and os.path.exists(args.model_list):
|
||||
with open(args.model_list) as f:
|
||||
model_cfgs = []
|
||||
model_names = [line.rstrip() for line in f]
|
||||
_get_model_cfgs(
|
||||
model_cfgs = _get_model_cfgs(
|
||||
model_names,
|
||||
#num_classes=1000,
|
||||
expand_train_test=True,
|
||||
|
@ -95,4 +95,5 @@ from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
|
||||
from ._prune import adapt_model_from_string
|
||||
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
|
||||
register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
|
||||
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
||||
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value, \
|
||||
get_arch_pretrained_cfgs
|
||||
|
@ -16,7 +16,7 @@ from ._pretrained import PretrainedCfg, DefaultCfg
|
||||
__all__ = [
|
||||
'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained'
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_pretrained_cfgs_for_arch'
|
||||
]
|
||||
|
||||
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||
@ -341,3 +341,12 @@ def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
|
||||
"""
|
||||
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
|
||||
return getattr(cfg, cfg_key, None)
|
||||
|
||||
|
||||
def get_arch_pretrained_cfgs(model_name: str) -> Dict[str, PretrainedCfg]:
|
||||
""" Get all pretrained cfgs for a given architecture.
|
||||
"""
|
||||
arch_name, _ = split_model_name_tag(model_name)
|
||||
model_names = _model_with_tags[arch_name]
|
||||
cfgs = {m: _model_pretrained_cfgs[m] for m in model_names}
|
||||
return cfgs
|
||||
|
Loading…
x
Reference in New Issue
Block a user