mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update test for cait 448x448 model
This commit is contained in:
parent
5fcddb96a8
commit
d45e50b9db
@ -43,7 +43,9 @@ def test_model_forward(model_name, batch_size):
|
|||||||
|
|
||||||
input_size = model.default_cfg['input_size']
|
input_size = model.default_cfg['input_size']
|
||||||
if any([x > MAX_FWD_SIZE for x in input_size]):
|
if any([x > MAX_FWD_SIZE for x in input_size]):
|
||||||
# cap forward test at max res 448 * 448 to keep resource down
|
if is_model_default_key(model_name, 'fixed_input_size'):
|
||||||
|
pytest.skip("Fixed input size model > limit.")
|
||||||
|
# cap forward test at max res 384 * 384 to keep resource down
|
||||||
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
|
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
|
||||||
inputs = torch.randn((batch_size, *input_size))
|
inputs = torch.randn((batch_size, *input_size))
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user