mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add default_cfg options for min_input_size / fixed_input_size, queries in model registry, and use for testing self-attn models
This commit is contained in:
parent
4e4b863b15
commit
16f7aa9f54
@ -5,7 +5,8 @@ import os
|
||||
import fnmatch
|
||||
|
||||
import timm
|
||||
from timm import list_models, create_model, set_scriptable
|
||||
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
|
||||
get_model_default_value
|
||||
|
||||
if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
# legacy executor is too slow to compile large models for unit tests
|
||||
@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
|
||||
model.eval()
|
||||
|
||||
input_size = model.default_cfg['input_size']
|
||||
if any([x > MAX_BWD_SIZE for x in input_size]):
|
||||
# cap backward test at 128 * 128 to keep resource usage down
|
||||
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
|
||||
if not is_model_default_key(model_name, 'fixed_input_size'):
|
||||
min_input_size = get_model_default_value(model_name, 'min_input_size')
|
||||
if min_input_size is not None:
|
||||
input_size = min_input_size
|
||||
else:
|
||||
if any([x > MAX_BWD_SIZE for x in input_size]):
|
||||
# cap backward test at 128 * 128 to keep resource usage down
|
||||
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
|
||||
|
||||
inputs = torch.randn((batch_size, *input_size))
|
||||
outputs = model(inputs)
|
||||
outputs.mean().backward()
|
||||
@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
|
||||
with set_scriptable(True):
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
|
||||
|
||||
if has_model_default_key(model_name, 'fixed_input_size'):
|
||||
input_size = get_model_default_value(model_name, 'input_size')
|
||||
elif has_model_default_key(model_name, 'min_input_size'):
|
||||
input_size = get_model_default_value(model_name, 'min_input_size')
|
||||
else:
|
||||
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
|
||||
|
||||
model = torch.jit.script(model)
|
||||
outputs = model(torch.randn((batch_size, *input_size)))
|
||||
|
||||
@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
|
||||
model.eval()
|
||||
expected_channels = model.feature_info.channels()
|
||||
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
|
||||
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
|
||||
|
||||
if has_model_default_key(model_name, 'fixed_input_size'):
|
||||
input_size = get_model_default_value(model_name, 'input_size')
|
||||
elif has_model_default_key(model_name, 'min_input_size'):
|
||||
input_size = get_model_default_value(model_name, 'min_input_size')
|
||||
else:
|
||||
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
|
||||
|
||||
outputs = model(torch.randn((batch_size, *input_size)))
|
||||
assert len(expected_channels) == len(outputs)
|
||||
for e, o in zip(expected_channels, outputs):
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .version import __version__
|
||||
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
|
||||
is_scriptable, is_exportable, set_scriptable, set_exportable
|
||||
is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
|
||||
get_model_default_value, is_model_pretrained
|
||||
|
@ -40,4 +40,5 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
|
||||
from .layers import TestTimePoolHead, apply_test_time_pool
|
||||
from .layers import convert_splitbn_model
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
|
||||
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules
|
||||
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
|
||||
has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained
|
||||
|
@ -37,23 +37,24 @@ def _cfg(url='', **kwargs):
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
||||
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
|
||||
'fixed_input_size': False, 'min_input_size': (3, 224, 224),
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# GPU-Efficient (ResNet) weights
|
||||
'botnet50t_224': _cfg(url=''),
|
||||
'botnet50t_c4c5_224': _cfg(url=''),
|
||||
'botnet50t_224': _cfg(url='', fixed_input_size=True),
|
||||
'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),
|
||||
|
||||
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
|
||||
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
|
||||
'halonet26t': _cfg(url=''),
|
||||
'halonet50t': _cfg(url=''),
|
||||
|
||||
'lambda_resnet26t': _cfg(url=''),
|
||||
'lambda_resnet50t': _cfg(url=''),
|
||||
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
|
||||
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
|
||||
}
|
||||
|
||||
|
||||
|
@ -6,13 +6,16 @@ import sys
|
||||
import re
|
||||
import fnmatch
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
|
||||
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
|
||||
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
|
||||
|
||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module = {} # mapping of model names to module names
|
||||
_model_entrypoints = {} # mapping of model names to entrypoint fns
|
||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs = dict() # central repo for model default_cfgs
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
@ -37,6 +40,7 @@ def register_model(fn):
|
||||
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
|
||||
# entrypoints or non-matching combos
|
||||
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
|
||||
_model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
|
||||
if has_pretrained:
|
||||
_model_has_pretrained.add(model_name)
|
||||
return fn
|
||||
@ -105,3 +109,31 @@ def is_model_in_modules(model_name, module_names):
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(model_name in _module_to_models[n] for n in module_names)
|
||||
|
||||
|
||||
def has_model_default_key(model_name, cfg_key):
|
||||
""" Query model default_cfgs for existence of a specific key.
|
||||
"""
|
||||
if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_model_default_key(model_name, cfg_key):
|
||||
""" Return truthy value for specified model default_cfg key, False if does not exist.
|
||||
"""
|
||||
if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_model_default_value(model_name, cfg_key):
|
||||
""" Get a specific model default_cfg value by key. None if it doesn't exist.
|
||||
"""
|
||||
if model_name in _model_default_cfgs:
|
||||
return _model_default_cfgs[model_name].get(cfg_key, None)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def is_model_pretrained(model_name):
|
||||
return model_name in _model_has_pretrained
|
||||
|
Loading…
x
Reference in New Issue
Block a user