Check forward_intermediates features against forward_features output

pull/2478/merge
Ross Wightman 2025-05-06 12:56:58 -07:00 committed by Ross Wightman
parent c8c4f256b8
commit fdcf7cf5c3
1 changed files with 5 additions and 1 deletions

View File

@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size):
spatial_axis = get_spatial_dim(output_fmt)
import math
inpt = torch.randn((batch_size, *input_size))
output, intermediates = model.forward_intermediates(
torch.randn((batch_size, *input_size)),
inpt,
output_fmt=output_fmt,
)
assert len(expected_channels) == len(intermediates)
@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size):
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
output2 = model.forward_features(inpt)
assert torch.allclose(output, output2)
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode