mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add Visformer-small weighs, tweak torchscript jit test img size.
This commit is contained in:
parent
83487e2a0d
commit
c4572cc5aa
@ -33,7 +33,11 @@ else:
|
||||
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
|
||||
TARGET_BWD_SIZE = 128
|
||||
MAX_BWD_SIZE = 320
|
||||
MAX_FWD_FEAT_SIZE = 448
|
||||
MAX_FWD_OUT_SIZE = 448
|
||||
TARGET_JIT_SIZE = 128
|
||||
MAX_JIT_SIZE = 320
|
||||
TARGET_FFEAT_SIZE = 96
|
||||
MAX_FFEAT_SIZE = 256
|
||||
|
||||
|
||||
def _get_input_size(model, target=None):
|
||||
@ -109,10 +113,10 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
pool_size = cfg['pool_size']
|
||||
input_size = model.default_cfg['input_size']
|
||||
|
||||
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
|
||||
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
|
||||
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
|
||||
# output sizes only checked if default res <= 448 * 448 to keep resource down
|
||||
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
|
||||
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
|
||||
input_tensor = torch.randn((batch_size, *input_size))
|
||||
|
||||
# test forward_features (always unpooled)
|
||||
@ -176,8 +180,8 @@ def test_model_forward_torchscript(model_name, batch_size):
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
|
||||
input_size = _get_input_size(model, 128)
|
||||
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
||||
input_size = _get_input_size(model, TARGET_JIT_SIZE)
|
||||
if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
model = torch.jit.script(model)
|
||||
@ -205,8 +209,8 @@ def test_model_forward_features(model_name, batch_size):
|
||||
expected_channels = model.feature_info.channels()
|
||||
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
|
||||
|
||||
input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already...
|
||||
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
||||
input_size = _get_input_size(model, TARGET_FFEAT_SIZE)
|
||||
if max(input_size) > MAX_FFEAT_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
outputs = model(torch.randn((batch_size, *input_size)))
|
||||
|
@ -33,7 +33,9 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
default_cfgs = dict(
|
||||
visformer_tiny=_cfg(),
|
||||
visformer_small=_cfg(),
|
||||
visformer_small=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user