mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Check forward_intermediates features against forward_features output
This commit is contained in:
parent
c8c4f256b8
commit
fdcf7cf5c3
@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
|||||||
spatial_axis = get_spatial_dim(output_fmt)
|
spatial_axis = get_spatial_dim(output_fmt)
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
inpt = torch.randn((batch_size, *input_size))
|
||||||
output, intermediates = model.forward_intermediates(
|
output, intermediates = model.forward_intermediates(
|
||||||
torch.randn((batch_size, *input_size)),
|
inpt,
|
||||||
output_fmt=output_fmt,
|
output_fmt=output_fmt,
|
||||||
)
|
)
|
||||||
assert len(expected_channels) == len(intermediates)
|
assert len(expected_channels) == len(intermediates)
|
||||||
@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
|||||||
assert o.shape[0] == batch_size
|
assert o.shape[0] == batch_size
|
||||||
assert not torch.isnan(o).any()
|
assert not torch.isnan(o).any()
|
||||||
|
|
||||||
|
output2 = model.forward_features(inpt)
|
||||||
|
assert torch.allclose(output, output2)
|
||||||
|
|
||||||
|
|
||||||
def _create_fx_model(model, train=False):
|
def _create_fx_model(model, train=False):
|
||||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||||
|
Loading…
x
Reference in New Issue
Block a user