Add GhostNetV2

This commit is contained in:
yehuitang 2023-08-13 18:23:52 +08:00 committed by GitHub
parent b407794e3a
commit e4babe7372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -175,7 +175,7 @@ def test_model_default_cfgs(model_name, batch_size):
outputs = model.forward_features(input_tensor)
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.GhostNetV2, timm.models.VGG)):
assert outputs.shape[feat_axis] == model.num_features
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
@ -188,8 +188,8 @@ def test_model_default_cfgs(model_name, batch_size):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet,timm.models.GhostNetV2, timm.models.VGG)):
# mobilenetv3/ghostnet/ghostnetv2/vgg forward_features vs removed pooling differ due to location or lack of GAP
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
if 'pruned' not in model_name: # FIXME better pruned model handling