From fdcf7cf5c3b168640482a8acaa76d20e27be12d4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 6 May 2025 12:56:58 -0700 Subject: [PATCH] Check forward_intermediates features against forward_features output --- tests/test_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3ba3615d..6e7af0fc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size): spatial_axis = get_spatial_dim(output_fmt) import math + inpt = torch.randn((batch_size, *input_size)) output, intermediates = model.forward_intermediates( - torch.randn((batch_size, *input_size)), + inpt, output_fmt=output_fmt, ) assert len(expected_channels) == len(intermediates) @@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size): assert o.shape[0] == batch_size assert not torch.isnan(o).any() + output2 = model.forward_features(inpt) + assert torch.allclose(output, output2) + def _create_fx_model(model, train=False): # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode