mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Would like to pass GitHub tests again disabling both FX feature extract backward and torchscript tests
This commit is contained in:
parent
a22b85c1b9
commit
f83b0b01e3
@ -422,37 +422,37 @@ if 'GITHUB_ACTIONS' not in os.environ:
|
|||||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||||
|
|
||||||
|
|
||||||
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
||||||
EXCLUDE_FX_JIT_FILTERS = [
|
EXCLUDE_FX_JIT_FILTERS = [
|
||||||
'deit_*_distilled_patch16_224',
|
'deit_*_distilled_patch16_224',
|
||||||
'levit*',
|
'levit*',
|
||||||
'pit_*_distilled_224',
|
'pit_*_distilled_224',
|
||||||
] + EXCLUDE_FX_FILTERS
|
] + EXCLUDE_FX_FILTERS
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(120)
|
@pytest.mark.timeout(120)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'model_name', list_models(
|
'model_name', list_models(
|
||||||
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
|
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
|
||||||
@pytest.mark.parametrize('batch_size', [1])
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
def test_model_forward_fx_torchscript(model_name, batch_size):
|
def test_model_forward_fx_torchscript(model_name, batch_size):
|
||||||
"""Symbolically trace each model, script it, and run single forward pass"""
|
"""Symbolically trace each model, script it, and run single forward pass"""
|
||||||
if not has_fx_feature_extraction:
|
if not has_fx_feature_extraction:
|
||||||
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||||
|
|
||||||
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
||||||
if max(input_size) > MAX_JIT_SIZE:
|
if max(input_size) > MAX_JIT_SIZE:
|
||||||
pytest.skip("Fixed input size model > limit.")
|
pytest.skip("Fixed input size model > limit.")
|
||||||
|
|
||||||
with set_scriptable(True):
|
with set_scriptable(True):
|
||||||
model = create_model(model_name, pretrained=False)
|
model = create_model(model_name, pretrained=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
model = torch.jit.script(_create_fx_model(model))
|
model = torch.jit.script(_create_fx_model(model))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, tuple):
|
||||||
outputs = torch.cat(outputs)
|
outputs = torch.cat(outputs)
|
||||||
|
|
||||||
assert outputs.shape[0] == batch_size
|
assert outputs.shape[0] == batch_size
|
||||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user