mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix / improve tests for features
This commit is contained in:
parent
4b2565e4cb
commit
fe3cf542fa
@ -55,8 +55,8 @@ FEAT_INTER_FILTERS = [
|
||||
# transformer models don't support many of the spatial / feature based model functionalities
|
||||
NON_STD_FILTERS = [
|
||||
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
|
||||
]
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
@ -356,7 +356,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
|
||||
|
||||
@pytest.mark.features
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_features(model_name, batch_size):
|
||||
"""Run a single forward pass with each model in feature extraction mode"""
|
||||
@ -364,7 +364,7 @@ def test_model_forward_features(model_name, batch_size):
|
||||
model.eval()
|
||||
expected_channels = model.feature_info.channels()
|
||||
expected_reduction = model.feature_info.reduction()
|
||||
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
|
||||
assert len(expected_channels) >= 3 # all models here should have at least 3 default feat levels
|
||||
|
||||
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
|
||||
if max(input_size) > MAX_FFEAT_SIZE:
|
||||
@ -387,7 +387,7 @@ def test_model_forward_features(model_name, batch_size):
|
||||
|
||||
@pytest.mark.features
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True))
|
||||
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_intermediates_features(model_name, batch_size):
|
||||
"""Run a single forward pass with each model in feature extraction mode"""
|
||||
@ -419,7 +419,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):
|
||||
|
||||
@pytest.mark.features
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, include_tags=True))
|
||||
@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_intermediates(model_name, batch_size):
|
||||
"""Run a single forward pass with each model in feature extraction mode"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user