mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add try/except guards
This commit is contained in:
parent
b25ff96768
commit
d2994016e9
@ -4,7 +4,11 @@ import platform
|
|||||||
import os
|
import os
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
|
||||||
|
try:
|
||||||
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
|
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
|
||||||
|
has_fx_feature_extraction = True
|
||||||
|
except ImportError:
|
||||||
|
has_fx_feature_extraction = False
|
||||||
|
|
||||||
import timm
|
import timm
|
||||||
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
|
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
|
||||||
@ -307,6 +311,9 @@ def test_model_forward_features(model_name, batch_size):
|
|||||||
@pytest.mark.parametrize('batch_size', [1])
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
def test_model_forward_fx(model_name, batch_size):
|
def test_model_forward_fx(model_name, batch_size):
|
||||||
"""Symbolically trace each model and run single forward pass through the resulting GraphModule"""
|
"""Symbolically trace each model and run single forward pass through the resulting GraphModule"""
|
||||||
|
if not has_fx_feature_extraction:
|
||||||
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
||||||
|
|
||||||
model = create_model(model_name, pretrained=False)
|
model = create_model(model_name, pretrained=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -332,6 +339,9 @@ def test_model_forward_fx(model_name, batch_size):
|
|||||||
@pytest.mark.parametrize('batch_size', [2])
|
@pytest.mark.parametrize('batch_size', [2])
|
||||||
def test_model_backward_fx(model_name, batch_size):
|
def test_model_backward_fx(model_name, batch_size):
|
||||||
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
||||||
|
if not has_fx_feature_extraction:
|
||||||
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
||||||
|
|
||||||
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
||||||
if max(input_size) > MAX_BWD_SIZE:
|
if max(input_size) > MAX_BWD_SIZE:
|
||||||
pytest.skip("Fixed input size model > limit.")
|
pytest.skip("Fixed input size model > limit.")
|
||||||
@ -387,6 +397,9 @@ EXCLUDE_FX_JIT_FILTERS = [
|
|||||||
@pytest.mark.parametrize('batch_size', [1])
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
def test_model_forward_fx_torchscript(model_name, batch_size):
|
def test_model_forward_fx_torchscript(model_name, batch_size):
|
||||||
"""Symbolically trace each model, script it, and run single forward pass"""
|
"""Symbolically trace each model, script it, and run single forward pass"""
|
||||||
|
if not has_fx_feature_extraction:
|
||||||
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
||||||
|
|
||||||
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
||||||
if max(input_size) > MAX_JIT_SIZE:
|
if max(input_size) > MAX_JIT_SIZE:
|
||||||
pytest.skip("Fixed input size model > limit.")
|
pytest.skip("Fixed input size model > limit.")
|
||||||
|
@ -8,8 +8,9 @@ from .features import _get_feature_info
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from torchvision.models.feature_extraction import create_feature_extractor
|
from torchvision.models.feature_extraction import create_feature_extractor
|
||||||
|
has_fx_feature_extraction = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
has_fx_feature_extraction = False
|
||||||
|
|
||||||
# Layers we went to treat as leaf modules
|
# Layers we went to treat as leaf modules
|
||||||
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
|
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
|
||||||
@ -58,6 +59,7 @@ def register_autowrap_function(func: Callable):
|
|||||||
class FeatureGraphNet(nn.Module):
|
class FeatureGraphNet(nn.Module):
|
||||||
def __init__(self, model, out_indices, out_map=None):
|
def __init__(self, model, out_indices, out_map=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||||
self.feature_info = _get_feature_info(model, out_indices)
|
self.feature_info = _get_feature_info(model, out_indices)
|
||||||
if out_map is not None:
|
if out_map is not None:
|
||||||
assert len(out_map) == len(out_indices)
|
assert len(out_map) == len(out_indices)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user