mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #136 from yoniaflalo/adding_effnet_pruned
added efficientnet pruned weights
This commit is contained in:
commit
8ec554b82e
@ -27,7 +27,7 @@ Hacked together by Ross Wightman
|
|||||||
from .efficientnet_builder import *
|
from .efficientnet_builder import *
|
||||||
from .feature_hooks import FeatureHooks
|
from .feature_hooks import FeatureHooks
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained, adapt_model_from_file
|
||||||
from .layers import SelectAdaptivePool2d
|
from .layers import SelectAdaptivePool2d
|
||||||
from timm.models.layers import create_conv2d
|
from timm.models.layers import create_conv2d
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
@ -131,6 +131,16 @@ default_cfgs = {
|
|||||||
'efficientnet_lite4': _cfg(
|
'efficientnet_lite4': _cfg(
|
||||||
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||||
|
|
||||||
|
'efficientnet_b1_pruned': _cfg(
|
||||||
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
|
||||||
|
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'efficientnet_b2_pruned': _cfg(
|
||||||
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
|
||||||
|
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'efficientnet_b3_pruned': _cfg(
|
||||||
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
|
||||||
|
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
|
||||||
'tf_efficientnet_b0': _cfg(
|
'tf_efficientnet_b0': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||||
input_size=(3, 224, 224)),
|
input_size=(3, 224, 224)),
|
||||||
@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False):
|
|||||||
else:
|
else:
|
||||||
load_strict = True
|
load_strict = True
|
||||||
model_class = EfficientNet
|
model_class = EfficientNet
|
||||||
|
variant = model_kwargs.pop('variant', '')
|
||||||
model = model_class(**model_kwargs)
|
model = model_class(**model_kwargs)
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
|
if '_pruned' in variant:
|
||||||
|
model = adapt_model_from_file(model, variant)
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(
|
load_pretrained(
|
||||||
model,
|
model,
|
||||||
@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
|
|||||||
channel_multiplier=channel_multiplier,
|
channel_multiplier=channel_multiplier,
|
||||||
act_layer=Swish,
|
act_layer=Swish,
|
||||||
norm_kwargs=resolve_bn_args(kwargs),
|
norm_kwargs=resolve_bn_args(kwargs),
|
||||||
|
variant=variant,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||||
@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def efficientnet_b1_pruned(pretrained=False, **kwargs):
|
||||||
|
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
||||||
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||||
|
kwargs['pad_type'] = 'same'
|
||||||
|
variant = 'efficientnet_b1_pruned'
|
||||||
|
model = _gen_efficientnet(
|
||||||
|
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def efficientnet_b2_pruned(pretrained=False, **kwargs):
|
||||||
|
""" EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
||||||
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||||
|
kwargs['pad_type'] = 'same'
|
||||||
|
model = _gen_efficientnet(
|
||||||
|
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def efficientnet_b3_pruned(pretrained=False, **kwargs):
|
||||||
|
""" EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
||||||
|
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||||
|
kwargs['pad_type'] = 'same'
|
||||||
|
model = _gen_efficientnet(
|
||||||
|
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
||||||
""" EfficientNet-B0. Tensorflow compatible variant """
|
""" EfficientNet-B0. Tensorflow compatible variant """
|
||||||
|
1
timm/models/pruned/efficientnet_b1_pruned.txt
Normal file
1
timm/models/pruned/efficientnet_b1_pruned.txt
Normal file
File diff suppressed because one or more lines are too long
1
timm/models/pruned/efficientnet_b2_pruned.txt
Normal file
1
timm/models/pruned/efficientnet_b2_pruned.txt
Normal file
File diff suppressed because one or more lines are too long
1
timm/models/pruned/efficientnet_b3_pruned.txt
Normal file
1
timm/models/pruned/efficientnet_b3_pruned.txt
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user