mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Filter test models before creation for backward/torchscript tests
This commit is contained in:
parent
c4572cc5aa
commit
d400f1dbdd
@ -40,14 +40,25 @@ TARGET_FFEAT_SIZE = 96
|
||||
MAX_FFEAT_SIZE = 256
|
||||
|
||||
|
||||
def _get_input_size(model, target=None):
|
||||
default_cfg = model.default_cfg
|
||||
input_size = default_cfg['input_size']
|
||||
if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']:
|
||||
def _get_input_size(model=None, model_name='', target=None):
|
||||
if model is None:
|
||||
assert model_name, "One of model or model_name must be provided"
|
||||
input_size = get_model_default_value(model_name, 'input_size')
|
||||
fixed_input_size = get_model_default_value(model_name, 'fixed_input_size')
|
||||
min_input_size = get_model_default_value(model_name, 'min_input_size')
|
||||
else:
|
||||
default_cfg = model.default_cfg
|
||||
input_size = default_cfg['input_size']
|
||||
fixed_input_size = default_cfg.get('fixed_input_size', None)
|
||||
min_input_size = default_cfg.get('min_input_size', None)
|
||||
assert input_size is not None
|
||||
|
||||
if fixed_input_size:
|
||||
return input_size
|
||||
if 'min_input_size' in default_cfg:
|
||||
|
||||
if min_input_size:
|
||||
if target and max(input_size) > target:
|
||||
input_size = default_cfg['min_input_size']
|
||||
input_size = min_input_size
|
||||
else:
|
||||
if target and max(input_size) > target:
|
||||
input_size = tuple([min(x, target) for x in input_size])
|
||||
@ -73,18 +84,18 @@ def test_model_forward(model_name, batch_size):
|
||||
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [2])
|
||||
def test_model_backward(model_name, batch_size):
|
||||
"""Run a single forward pass with each model"""
|
||||
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
||||
if max(input_size) > MAX_BWD_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
model = create_model(model_name, pretrained=False, num_classes=42)
|
||||
num_params = sum([x.numel() for x in model.parameters()])
|
||||
model.train()
|
||||
|
||||
input_size = _get_input_size(model, TARGET_BWD_SIZE)
|
||||
if max(input_size) > MAX_BWD_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
inputs = torch.randn((batch_size, *input_size))
|
||||
outputs = model(inputs)
|
||||
if isinstance(outputs, tuple):
|
||||
@ -172,18 +183,19 @@ EXCLUDE_JIT_FILTERS = [
|
||||
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS))
|
||||
@pytest.mark.parametrize(
|
||||
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_torchscript(model_name, batch_size):
|
||||
"""Run a single forward pass with each model"""
|
||||
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
||||
if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
with set_scriptable(True):
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
|
||||
input_size = _get_input_size(model, TARGET_JIT_SIZE)
|
||||
if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
model = torch.jit.script(model)
|
||||
outputs = model(torch.randn((batch_size, *input_size)))
|
||||
|
||||
|
@ -50,7 +50,7 @@ def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(filter='', module='', pretrained=False, exclude_filters=''):
|
||||
def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
|
||||
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
|
||||
pretrained (bool) - Include only models with pretrained weights if True
|
||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
|
||||
if filter:
|
||||
models = fnmatch.filter(models, filter) # include these models
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, list):
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
exclude_filters = [exclude_filters]
|
||||
for xf in exclude_filters:
|
||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||
@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
|
||||
models = set(models).difference(exclude_models)
|
||||
if pretrained:
|
||||
models = _model_has_pretrained.intersection(models)
|
||||
if name_matches_cfg:
|
||||
models = set(_model_default_cfgs).intersection(models)
|
||||
return list(sorted(models, key=_natural_key))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user