mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2217 from dsuess/2216_fix_script_on_features_fx
Fix jit.script breaking with features_fx
This commit is contained in:
commit
20fe56bd90
@ -631,3 +631,35 @@ if 'GITHUB_ACTIONS' not in os.environ:
|
|||||||
|
|
||||||
assert outputs.shape[0] == batch_size
|
assert outputs.shape[0] == batch_size
|
||||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||||
|
|
||||||
|
@pytest.mark.timeout(120)
|
||||||
|
@pytest.mark.parametrize('model_name', ["regnetx_002"])
|
||||||
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
|
def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
|
||||||
|
"""Create a model with feature extraction based on fx, script it, and run
|
||||||
|
a single forward pass"""
|
||||||
|
if not has_fx_feature_extraction:
|
||||||
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||||
|
|
||||||
|
allowed_models = list_models(
|
||||||
|
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS,
|
||||||
|
name_matches_cfg=True
|
||||||
|
)
|
||||||
|
assert model_name in allowed_models, f"{model_name=} not supported for this test"
|
||||||
|
|
||||||
|
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
||||||
|
assert max(input_size) <= MAX_JIT_SIZE, "Fixed input size model > limit. Pick a different model to run this test"
|
||||||
|
|
||||||
|
with set_scriptable(True):
|
||||||
|
model = create_model(model_name, pretrained=False, features_only=True, feature_cfg={"feature_cls": "fx"})
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(torch.randn((batch_size, *input_size)))
|
||||||
|
|
||||||
|
assert isinstance(outputs, list)
|
||||||
|
|
||||||
|
for tensor in outputs:
|
||||||
|
assert tensor.shape[0] == batch_size
|
||||||
|
assert not torch.isnan(tensor).any(), 'Output included NaNs'
|
@ -116,6 +116,8 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
|
|||||||
class FeatureGraphNet(nn.Module):
|
class FeatureGraphNet(nn.Module):
|
||||||
""" A FX Graph based feature extractor that works with the model feature_info metadata
|
""" A FX Graph based feature extractor that works with the model feature_info metadata
|
||||||
"""
|
"""
|
||||||
|
return_dict: torch.jit.Final[bool]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -155,6 +157,8 @@ class GraphExtractNet(nn.Module):
|
|||||||
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
||||||
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
|
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
|
||||||
"""
|
"""
|
||||||
|
return_dict: torch.jit.Final[bool]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user