mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add GhostNetV2
This commit is contained in:
parent
b407794e3a
commit
e4babe7372
@ -175,7 +175,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
outputs = model.forward_features(input_tensor)
|
||||
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
|
||||
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
|
||||
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
|
||||
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.GhostNetV2, timm.models.VGG)):
|
||||
assert outputs.shape[feat_axis] == model.num_features
|
||||
|
||||
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
||||
@ -188,8 +188,8 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 4
|
||||
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
|
||||
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
|
||||
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet,timm.models.GhostNetV2, timm.models.VGG)):
|
||||
# mobilenetv3/ghostnet/ghostnetv2/vgg forward_features vs removed pooling differ due to location or lack of GAP
|
||||
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
|
||||
|
||||
if 'pruned' not in model_name: # FIXME better pruned model handling
|
||||
|
Loading…
x
Reference in New Issue
Block a user