mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add Visformer-small weighs, tweak torchscript jit test img size.
This commit is contained in:
parent
83487e2a0d
commit
c4572cc5aa
@ -33,7 +33,11 @@ else:
|
|||||||
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
|
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
|
||||||
TARGET_BWD_SIZE = 128
|
TARGET_BWD_SIZE = 128
|
||||||
MAX_BWD_SIZE = 320
|
MAX_BWD_SIZE = 320
|
||||||
MAX_FWD_FEAT_SIZE = 448
|
MAX_FWD_OUT_SIZE = 448
|
||||||
|
TARGET_JIT_SIZE = 128
|
||||||
|
MAX_JIT_SIZE = 320
|
||||||
|
TARGET_FFEAT_SIZE = 96
|
||||||
|
MAX_FFEAT_SIZE = 256
|
||||||
|
|
||||||
|
|
||||||
def _get_input_size(model, target=None):
|
def _get_input_size(model, target=None):
|
||||||
@ -109,10 +113,10 @@ def test_model_default_cfgs(model_name, batch_size):
|
|||||||
pool_size = cfg['pool_size']
|
pool_size = cfg['pool_size']
|
||||||
input_size = model.default_cfg['input_size']
|
input_size = model.default_cfg['input_size']
|
||||||
|
|
||||||
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
|
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
|
||||||
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
|
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
|
||||||
# output sizes only checked if default res <= 448 * 448 to keep resource down
|
# output sizes only checked if default res <= 448 * 448 to keep resource down
|
||||||
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
|
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
|
||||||
input_tensor = torch.randn((batch_size, *input_size))
|
input_tensor = torch.randn((batch_size, *input_size))
|
||||||
|
|
||||||
# test forward_features (always unpooled)
|
# test forward_features (always unpooled)
|
||||||
@ -176,8 +180,8 @@ def test_model_forward_torchscript(model_name, batch_size):
|
|||||||
model = create_model(model_name, pretrained=False)
|
model = create_model(model_name, pretrained=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
input_size = _get_input_size(model, 128)
|
input_size = _get_input_size(model, TARGET_JIT_SIZE)
|
||||||
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
||||||
pytest.skip("Fixed input size model > limit.")
|
pytest.skip("Fixed input size model > limit.")
|
||||||
|
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
@ -205,8 +209,8 @@ def test_model_forward_features(model_name, batch_size):
|
|||||||
expected_channels = model.feature_info.channels()
|
expected_channels = model.feature_info.channels()
|
||||||
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
|
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
|
||||||
|
|
||||||
input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already...
|
input_size = _get_input_size(model, TARGET_FFEAT_SIZE)
|
||||||
if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
|
if max(input_size) > MAX_FFEAT_SIZE:
|
||||||
pytest.skip("Fixed input size model > limit.")
|
pytest.skip("Fixed input size model > limit.")
|
||||||
|
|
||||||
outputs = model(torch.randn((batch_size, *input_size)))
|
outputs = model(torch.randn((batch_size, *input_size)))
|
||||||
|
@ -33,7 +33,9 @@ def _cfg(url='', **kwargs):
|
|||||||
|
|
||||||
default_cfgs = dict(
|
default_cfgs = dict(
|
||||||
visformer_tiny=_cfg(),
|
visformer_tiny=_cfg(),
|
||||||
visformer_small=_cfg(),
|
visformer_small=_cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user