diff --git a/tests/test_models.py b/tests/test_models.py
index d8ac8d64..76d95207 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -175,7 +175,7 @@ def test_model_default_cfgs(model_name, batch_size):
     outputs = model.forward_features(input_tensor)
     assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
     assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
-    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
+    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.GhostNetV2, timm.models.VGG)):
         assert outputs.shape[feat_axis] == model.num_features

     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
@@ -188,8 +188,8 @@ def test_model_default_cfgs(model_name, batch_size):
     model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
     outputs = model.forward(input_tensor)
     assert len(outputs.shape) == 4
-    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
-        # mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
+    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.GhostNetV2, timm.models.VGG)):
+        # mobilenetv3/ghostnet/ghostnetv2/vgg forward_features vs removed pooling differ due to location or lack of GAP
         assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

     if 'pruned' not in model_name:  # FIXME better pruned model handling