Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Merge pull request #1354 from rwightman/fix_tests
Attempting to fix unit test failures...
commit 4547920f85
@@ -4,6 +4,8 @@ import platform
 import os
 import fnmatch
 
+_IS_MAC = platform.system() == 'Darwin'
+
 try:
     from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
     has_fx_feature_extraction = True
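This first hunk is the crux of the fix: a module-level platform check. Because the flag is evaluated at import time, everything gated on it later in the file is simply never defined on a Darwin runner, so pytest never even collects those tests. A minimal sketch of the same pattern, with a hypothetical placeholder test standing in for the real ones:

import platform

_IS_MAC = platform.system() == 'Darwin'

if not _IS_MAC:
    # Defined only on non-macOS platforms; pytest never collects this
    # test on a Darwin runner, so no skip entries appear in the report.
    def test_expensive_tracing():  # hypothetical placeholder test
        assert 1 + 1 == 2

Guarding the definitions this way avoids even collection overhead; the alternative, decorating each test with @pytest.mark.skipif(_IS_MAC, reason='slow on macOS runners'), keeps the file flat but still collects and reports every skipped test.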
@@ -322,7 +324,10 @@ def test_model_forward_features(model_name, batch_size):
         assert not torch.isnan(o).any()
 
 
-def _create_fx_model(model, train=False):
-    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
-    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
-    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+if not _IS_MAC:
+    # MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
+
+    def _create_fx_model(model, train=False):
+        # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+        # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+        # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
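The comment block refers to torchvision's FX feature-extraction API: get_graph_node_names returns parallel lists of node names for train and eval mode (same index, same underlying operation), and create_feature_extractor retraces the model into a GraphModule that returns the requested nodes. A rough sketch of that flow on a stock torchvision model (the model and node choices here are illustrative, not what timm's helper does):

import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

model = resnet18()

# Two parallel lists; train_nodes[i] and eval_nodes[i] name the same op.
train_nodes, eval_nodes = get_graph_node_names(model)

# Rebuild the model so it returns the final node plus an intermediate one.
extractor = create_feature_extractor(model, return_nodes=[eval_nodes[-1], 'layer4'])
out = extractor(torch.randn(1, 3, 224, 224))  # dict: node name -> tensor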
@@ -354,9 +359,9 @@ def _create_fx_model(model, train=False):
-    return fx_model
-
-
-EXCLUDE_FX_FILTERS = ['vit_gi*']
-# not enough memory to run fx on more models than other tests
-if 'GITHUB_ACTIONS' in os.environ:
-    EXCLUDE_FX_FILTERS += [
-        'beit_large*',
-        'mixer_l*',
+        return fx_model
+
+
+    EXCLUDE_FX_FILTERS = ['vit_gi*']
+    # not enough memory to run fx on more models than other tests
+    if 'GITHUB_ACTIONS' in os.environ:
+        EXCLUDE_FX_FILTERS += [
+            'beit_large*',
+            'mixer_l*',
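The exclusion entries are fnmatch-style glob patterns, which is why the file imports fnmatch; timm's list_models(exclude_filters=...) drops any model name matching one of them. A condensed sketch of that kind of filtering over a hard-coded name list (filter_models here is a stand-in, not timm's implementation):

import fnmatch

EXCLUDE_FX_FILTERS = ['vit_gi*', 'beit_large*', 'mixer_l*']

def filter_models(names, exclude_filters):
    # Keep only names matching none of the glob-style exclude patterns.
    return [n for n in names
            if not any(fnmatch.fnmatch(n, pat) for pat in exclude_filters)]

names = ['beit_large_patch16_224', 'mixer_b16_224', 'vit_gigantic_patch14_224']
print(filter_models(names, EXCLUDE_FX_FILTERS))  # ['mixer_b16_224']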
@@ -373,10 +378,10 @@ if 'GITHUB_ACTIONS' in os.environ:
-    ]
-
-
-@pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
-@pytest.mark.parametrize('batch_size', [1])
-def test_model_forward_fx(model_name, batch_size):
-    """
-    Symbolically trace each model and run single forward pass through the resulting GraphModule
-    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
+        ]
+
+
+    @pytest.mark.timeout(120)
+    @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
+    @pytest.mark.parametrize('batch_size', [1])
+    def test_model_forward_fx(model_name, batch_size):
+        """
+        Symbolically trace each model and run single forward pass through the resulting GraphModule
+        Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
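The docstring spells out the contract: trace the model into a GraphModule, run one forward pass, and require the traced output to be NaN-free and identical to the original module's output. A minimal sketch of that comparison on a toy module (using plain torch.fx.symbolic_trace rather than timm's _create_fx_model helper):

import torch
from torch.fx import symbolic_trace

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
fx_model = symbolic_trace(model)  # GraphModule with equivalent semantics

x = torch.randn(1, 8)
with torch.no_grad():
    out = model(x)
    fx_out = fx_model(x)

assert not torch.isnan(fx_out).any(), 'Output included NaNs'
assert torch.allclose(out, fx_out), 'GraphModule output differs from original Module'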
@@ -406,11 +411,11 @@ def test_model_forward_fx(model_name, batch_size):
-    assert not torch.isnan(outputs).any(), 'Output included NaNs'
-
-
-@pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(
-    exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
-@pytest.mark.parametrize('batch_size', [2])
-def test_model_backward_fx(model_name, batch_size):
-    """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
-    if not has_fx_feature_extraction:
-        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+        assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+    @pytest.mark.timeout(120)
+    @pytest.mark.parametrize('model_name', list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
+    @pytest.mark.parametrize('batch_size', [2])
+    def test_model_backward_fx(model_name, batch_size):
+        """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
+        if not has_fx_feature_extraction:
+            pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
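The backward variant runs one backward pass through the GraphModule and checks that gradients flow through the traced graph. A condensed sketch in the same spirit (again a toy module, not the parametrized timm test body):

import torch
from torch.fx import symbolic_trace

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
fx_model = symbolic_trace(model)

outputs = fx_model(torch.randn(2, 8))
outputs.mean().backward()  # single backward pass through the GraphModule

# Every parameter of the traced module should have received a gradient.
for name, param in fx_model.named_parameters():
    assert param.grad is not None, f'No gradient for {name}'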
@@ -439,7 +444,7 @@ def test_model_backward_fx(model_name, batch_size):
-    assert not torch.isnan(outputs).any(), 'Output included NaNs'
-
-
-if 'GITHUB_ACTIONS' not in os.environ:
-    # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
-
-    # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+        assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+    if 'GITHUB_ACTIONS' not in os.environ:
+        # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
+
+        # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow