mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More FX test tweaks
This commit is contained in:
parent
f0507f6da6
commit
1e51c2d02e
@ -348,6 +348,7 @@ if 'GITHUB_ACTIONS' in os.environ:
|
||||
'vgg*',
|
||||
'vit_large*',
|
||||
'xcit_large*',
|
||||
'mixer_l*',
|
||||
]
|
||||
|
||||
|
||||
@ -368,6 +369,7 @@ def test_model_forward_fx(model_name, batch_size):
|
||||
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
|
||||
if max(input_size) > MAX_FWD_FX_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
with torch.no_grad():
|
||||
inputs = torch.randn((batch_size, *input_size))
|
||||
outputs = model(inputs)
|
||||
if isinstance(outputs, tuple):
|
||||
@ -440,6 +442,7 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
|
||||
model.eval()
|
||||
|
||||
model = torch.jit.script(_create_fx_model(model))
|
||||
with torch.no_grad():
|
||||
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = torch.cat(outputs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user