mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add FX test exclusion since it uses more ram and barfs on GitHub actions. Will take a few iterations to include needed models :(
This commit is contained in:
parent
c976a410d9
commit
3819bef93e
@ -334,8 +334,14 @@ def _create_fx_model(model, train=False):
|
|||||||
return fx_model
|
return fx_model
|
||||||
|
|
||||||
|
|
||||||
|
EXCLUDE_FX_FILTERS = []
|
||||||
|
# not enough memory to run fx on more models than other tests
|
||||||
|
if 'GITHUB_ACTIONS' in os.environ:
|
||||||
|
EXCLUDE_FX_FILTERS += ['beit_large*', 'swin_large*']
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(120)
|
@pytest.mark.timeout(120)
|
||||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
|
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
|
||||||
@pytest.mark.parametrize('batch_size', [1])
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
def test_model_forward_fx(model_name, batch_size):
|
def test_model_forward_fx(model_name, batch_size):
|
||||||
"""
|
"""
|
||||||
@ -367,7 +373,8 @@ def test_model_forward_fx(model_name, batch_size):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(120)
|
@pytest.mark.timeout(120)
|
||||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
|
@pytest.mark.parametrize('model_name', list_models(
|
||||||
|
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
|
||||||
@pytest.mark.parametrize('batch_size', [2])
|
@pytest.mark.parametrize('batch_size', [2])
|
||||||
def test_model_backward_fx(model_name, batch_size):
|
def test_model_backward_fx(model_name, batch_size):
|
||||||
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
||||||
@ -400,7 +407,7 @@ EXCLUDE_FX_JIT_FILTERS = [
|
|||||||
'deit_*_distilled_patch16_224',
|
'deit_*_distilled_patch16_224',
|
||||||
'levit*',
|
'levit*',
|
||||||
'pit_*_distilled_224',
|
'pit_*_distilled_224',
|
||||||
]
|
] + EXCLUDE_FX_FILTERS
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(120)
|
@pytest.mark.timeout(120)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user