Check forward_intermediates features against forward_features output
parent
c8c4f256b8
commit
fdcf7cf5c3
|
@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
|||
spatial_axis = get_spatial_dim(output_fmt)
|
||||
import math
|
||||
|
||||
inpt = torch.randn((batch_size, *input_size))
|
||||
output, intermediates = model.forward_intermediates(
|
||||
torch.randn((batch_size, *input_size)),
|
||||
inpt,
|
||||
output_fmt=output_fmt,
|
||||
)
|
||||
assert len(expected_channels) == len(intermediates)
|
||||
|
@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size):
|
|||
assert o.shape[0] == batch_size
|
||||
assert not torch.isnan(o).any()
|
||||
|
||||
output2 = model.forward_features(inpt)
|
||||
assert torch.allclose(output, output2)
|
||||
|
||||
|
||||
def _create_fx_model(model, train=False):
|
||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||
|
|
Loading…
Reference in New Issue