Add default_cfg options for min_input_size / fixed_input_size, queries in model registry, and use for testing self-attn models
commit 16f7aa9f54
parent 4e4b863b15
tests/test_models.py
@@ -5,7 +5,8 @@ import os
 import fnmatch

 import timm
-from timm import list_models, create_model, set_scriptable
+from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
+    get_model_default_value

 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests

@@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
     model.eval()

     input_size = model.default_cfg['input_size']
-    if any([x > MAX_BWD_SIZE for x in input_size]):
-        # cap backward test at 128 * 128 to keep resource usage down
-        input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+    if not is_model_default_key(model_name, 'fixed_input_size'):
+        min_input_size = get_model_default_value(model_name, 'min_input_size')
+        if min_input_size is not None:
+            input_size = min_input_size
+        else:
+            if any([x > MAX_BWD_SIZE for x in input_size]):
+                # cap backward test at 128 * 128 to keep resource usage down
+                input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     outputs.mean().backward()
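
For illustration, a minimal standalone sketch of the size-resolution logic this hunk introduces (the helper name resolve_bwd_input_size and the sample cfg dict are hypothetical, not part of the commit):

    MAX_BWD_SIZE = 128

    def resolve_bwd_input_size(default_cfg):
        input_size = default_cfg['input_size']
        if not default_cfg.get('fixed_input_size', False):
            min_input_size = default_cfg.get('min_input_size')
            if min_input_size is not None:
                # the model defines a floor; test at the smallest supported size
                input_size = min_input_size
            elif any(x > MAX_BWD_SIZE for x in input_size):
                # cap backward test at 128 * 128 to keep resource usage down
                input_size = tuple(min(x, MAX_BWD_SIZE) for x in input_size)
        return input_size

    # e.g. a lambda_resnet-style cfg: trained at 224 but valid down to 128
    assert resolve_bwd_input_size(
        {'input_size': (3, 224, 224), 'min_input_size': (3, 128, 128)}) == (3, 128, 128)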

@@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
     with set_scriptable(True):
         model = create_model(model_name, pretrained=False)
     model.eval()
-    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+
     model = torch.jit.script(model)
     outputs = model(torch.randn((batch_size, *input_size)))

@@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
     model.eval()
     expected_channels = model.feature_info.channels()
     assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
-    input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+
     outputs = model(torch.randn((batch_size, *input_size)))
     assert len(expected_channels) == len(outputs)
     for e, o in zip(expected_channels, outputs):
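
The same three-way resolution appears in both forward tests above; as a hypothetical refactor (not part of this commit), it could be shared via one helper built on the new registry queries:

    from timm import has_model_default_key, get_model_default_value

    def _forward_test_input_size(model_name, fallback=(3, 128, 128)):
        if has_model_default_key(model_name, 'fixed_input_size'):
            # model only runs at its configured resolution (e.g. fixed self-attn grids)
            return get_model_default_value(model_name, 'input_size')
        elif has_model_default_key(model_name, 'min_input_size'):
            # smallest resolution the architecture supports
            return get_model_default_value(model_name, 'min_input_size')
        return fallback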

timm/__init__.py
@@ -1,3 +1,4 @@
 from .version import __version__
 from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
-    is_scriptable, is_exportable, set_scriptable, set_exportable
+    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
+    get_model_default_value, is_model_pretrained

timm/models/__init__.py
@@ -40,4 +40,5 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
-from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules
+from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
+    has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained

timm/models/byoanet.py
@@ -37,23 +37,24 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
         **kwargs
     }


 default_cfgs = {
     # GPU-Efficient (ResNet) weights
-    'botnet50t_224': _cfg(url=''),
-    'botnet50t_c4c5_224': _cfg(url=''),
+    'botnet50t_224': _cfg(url='', fixed_input_size=True),
+    'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),

-    'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet26t': _cfg(url=''),
     'halonet50t': _cfg(url=''),

-    'lambda_resnet26t': _cfg(url=''),
-    'lambda_resnet50t': _cfg(url=''),
+    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
+    'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
 }
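
A quick illustration of how these flags surface at runtime (a hypothetical session, assuming a timm build containing this commit; the model names are the registry keys above):

    import timm

    m = timm.create_model('botnet50t_224', pretrained=False)
    m.default_cfg['fixed_input_size']   # True: the self-attn blocks assume a fixed 224x224 input
    m2 = timm.create_model('lambda_resnet26t', pretrained=False)
    m2.default_cfg['min_input_size']    # (3, 128, 128): smallest input the model accepts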
|
@ -6,13 +6,16 @@ import sys
|
|||||||
import re
|
import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
|
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||||
|
'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
|
||||||
|
|
||||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||||
_model_to_module = {} # mapping of model names to module names
|
_model_to_module = {} # mapping of model names to module names
|
||||||
_model_entrypoints = {} # mapping of model names to entrypoint fns
|
_model_entrypoints = {} # mapping of model names to entrypoint fns
|
||||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
||||||
|
_model_default_cfgs = dict() # central repo for model default_cfgs
|
||||||
|
|
||||||
|
|
||||||
def register_model(fn):
|
def register_model(fn):
|
||||||

@@ -37,6 +40,7 @@ def register_model(fn):
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
         has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
+        _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
     if has_pretrained:
         _model_has_pretrained.add(model_name)
     return fn

@@ -105,3 +109,31 @@ def is_model_in_modules(model_name, module_names):
     assert isinstance(module_names, (tuple, list, set))
     return any(model_name in _module_to_models[n] for n in module_names)

+
+def has_model_default_key(model_name, cfg_key):
+    """ Query model default_cfgs for existence of a specific key.
+    """
+    if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
+        return True
+    return False
+
+
+def is_model_default_key(model_name, cfg_key):
+    """ Return truthy value for specified model default_cfg key, False if does not exist.
+    """
+    if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
+        return True
+    return False
+
+
+def get_model_default_value(model_name, cfg_key):
+    """ Get a specific model default_cfg value by key. None if it doesn't exist.
+    """
+    if model_name in _model_default_cfgs:
+        return _model_default_cfgs[model_name].get(cfg_key, None)
+    else:
+        return None
+
+
+def is_model_pretrained(model_name):
+    return model_name in _model_has_pretrained
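
Usage sketch for the new queries (hypothetical return values, following the 'halonet_h1' cfg defined earlier in this diff):

    from timm import has_model_default_key, is_model_default_key, \
        get_model_default_value, is_model_pretrained

    # has_* checks key existence; is_* checks truthiness of the stored value
    has_model_default_key('halonet_h1', 'fixed_input_size')  # True: key present (set to False)
    is_model_default_key('halonet_h1', 'fixed_input_size')   # False: present but falsy
    get_model_default_value('halonet_h1', 'min_input_size')  # (3, 256, 256)
    get_model_default_value('halonet_h1', 'no_such_key')     # None
    is_model_pretrained('halonet_h1')                        # False: no weight URL registered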