mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #136 from yoniaflalo/adding_effnet_pruned
added efficientnet pruned weights
This commit is contained in:
commit
8ec554b82e
@ -27,7 +27,7 @@ Hacked together by Ross Wightman
|
||||
from .efficientnet_builder import *
|
||||
from .feature_hooks import FeatureHooks
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .helpers import load_pretrained, adapt_model_from_file
|
||||
from .layers import SelectAdaptivePool2d
|
||||
from timm.models.layers import create_conv2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
@ -131,6 +131,16 @@ default_cfgs = {
|
||||
'efficientnet_lite4': _cfg(
|
||||
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
|
||||
'efficientnet_b1_pruned': _cfg(
|
||||
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'efficientnet_b2_pruned': _cfg(
|
||||
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'efficientnet_b3_pruned': _cfg(
|
||||
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
|
||||
'tf_efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
input_size=(3, 224, 224)),
|
||||
@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False):
|
||||
else:
|
||||
load_strict = True
|
||||
model_class = EfficientNet
|
||||
|
||||
variant = model_kwargs.pop('variant', '')
|
||||
model = model_class(**model_kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if '_pruned' in variant:
|
||||
model = adapt_model_from_file(model, variant)
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
|
||||
channel_multiplier=channel_multiplier,
|
||||
act_layer=Swish,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
variant=variant,
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b1_pruned(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
variant = 'efficientnet_b1_pruned'
|
||||
model = _gen_efficientnet(
|
||||
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b2_pruned(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet(
|
||||
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def efficientnet_b3_pruned(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_efficientnet(
|
||||
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B0. Tensorflow compatible variant """
|
||||
|
1
timm/models/pruned/efficientnet_b1_pruned.txt
Normal file
1
timm/models/pruned/efficientnet_b1_pruned.txt
Normal file
File diff suppressed because one or more lines are too long
1
timm/models/pruned/efficientnet_b2_pruned.txt
Normal file
1
timm/models/pruned/efficientnet_b2_pruned.txt
Normal file
File diff suppressed because one or more lines are too long
1
timm/models/pruned/efficientnet_b3_pruned.txt
Normal file
1
timm/models/pruned/efficientnet_b3_pruned.txt
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user