mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Check forward_intermediates features against forward_features output
This commit is contained in:
parent
c8c4f256b8
commit
907a32e699
@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
||||
spatial_axis = get_spatial_dim(output_fmt)
|
||||
import math
|
||||
|
||||
inpt = torch.randn((batch_size, *input_size))
|
||||
output, intermediates = model.forward_intermediates(
|
||||
torch.randn((batch_size, *input_size)),
|
||||
inpt,
|
||||
output_fmt=output_fmt,
|
||||
)
|
||||
assert len(expected_channels) == len(intermediates)
|
||||
@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
||||
assert o.shape[0] == batch_size
|
||||
assert not torch.isnan(o).any()
|
||||
|
||||
output2 = model.forward_features(inpt)
|
||||
assert torch.allclose(output, output2)
|
||||
|
||||
|
||||
def _create_fx_model(model, train=False):
|
||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||
|
Loading…
x
Reference in New Issue
Block a user