mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2264 from huggingface/group_size_eff
Allow group_size override for more efficientnet and mobilenetv3 based…
This commit is contained in:
commit
ed7aaf8d6d
@ -488,7 +488,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
|
|||||||
|
|
||||||
def _gen_mobilenet_v1(
|
def _gen_mobilenet_v1(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
||||||
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
|
group_size=None, fix_stem_head=False, head_conv=False, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
||||||
Paper: https://arxiv.org/abs/1801.04381
|
Paper: https://arxiv.org/abs/1801.04381
|
||||||
@ -503,7 +504,12 @@ def _gen_mobilenet_v1(
|
|||||||
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
||||||
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
|
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
block_args=decode_arch_def(
|
||||||
|
arch_def,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
fix_first_last=fix_stem_head,
|
||||||
|
group_size=group_size,
|
||||||
|
),
|
||||||
num_features=head_features,
|
num_features=head_features,
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
fix_stem=fix_stem_head,
|
fix_stem=fix_stem_head,
|
||||||
@ -517,7 +523,9 @@ def _gen_mobilenet_v1(
|
|||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v2(
|
def _gen_mobilenet_v2(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
||||||
|
group_size=None, fix_stem_head=False, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Generate MobileNet-V2 network
|
""" Generate MobileNet-V2 network
|
||||||
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
||||||
Paper: https://arxiv.org/abs/1801.04381
|
Paper: https://arxiv.org/abs/1801.04381
|
||||||
@ -533,7 +541,12 @@ def _gen_mobilenet_v2(
|
|||||||
]
|
]
|
||||||
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
block_args=decode_arch_def(
|
||||||
|
arch_def,
|
||||||
|
depth_multiplier=depth_multiplier,
|
||||||
|
fix_first_last=fix_stem_head,
|
||||||
|
group_size=group_size,
|
||||||
|
),
|
||||||
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
|
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
fix_stem=fix_stem_head,
|
fix_stem=fix_stem_head,
|
||||||
@ -613,7 +626,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
|||||||
|
|
||||||
def _gen_efficientnet(
|
def _gen_efficientnet(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
|
||||||
group_size=None, pretrained=False, **kwargs):
|
group_size=None, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
"""Creates an EfficientNet model.
|
"""Creates an EfficientNet model.
|
||||||
|
|
||||||
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
||||||
@ -661,7 +675,8 @@ def _gen_efficientnet(
|
|||||||
|
|
||||||
|
|
||||||
def _gen_efficientnet_edge(
|
def _gen_efficientnet_edge(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Creates an EfficientNet-EdgeTPU model
|
""" Creates an EfficientNet-EdgeTPU model
|
||||||
|
|
||||||
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
|
||||||
@ -692,7 +707,8 @@ def _gen_efficientnet_edge(
|
|||||||
|
|
||||||
|
|
||||||
def _gen_efficientnet_condconv(
|
def _gen_efficientnet_condconv(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
"""Creates an EfficientNet-CondConv model.
|
"""Creates an EfficientNet-CondConv model.
|
||||||
|
|
||||||
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
|
||||||
@ -764,7 +780,8 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|||||||
|
|
||||||
|
|
||||||
def _gen_efficientnetv2_base(
|
def _gen_efficientnetv2_base(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Creates an EfficientNet-V2 base model
|
""" Creates an EfficientNet-V2 base model
|
||||||
|
|
||||||
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
||||||
@ -780,7 +797,7 @@ def _gen_efficientnetv2_base(
|
|||||||
]
|
]
|
||||||
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier),
|
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
|
||||||
num_features=round_chs_fn(1280),
|
num_features=round_chs_fn(1280),
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
round_chs_fn=round_chs_fn,
|
round_chs_fn=round_chs_fn,
|
||||||
@ -793,7 +810,8 @@ def _gen_efficientnetv2_base(
|
|||||||
|
|
||||||
|
|
||||||
def _gen_efficientnetv2_s(
|
def _gen_efficientnetv2_s(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Creates an EfficientNet-V2 Small model
|
""" Creates an EfficientNet-V2 Small model
|
||||||
|
|
||||||
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
||||||
@ -831,7 +849,9 @@ def _gen_efficientnetv2_s(
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_efficientnetv2_m(
|
||||||
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Creates an EfficientNet-V2 Medium model
|
""" Creates an EfficientNet-V2 Medium model
|
||||||
|
|
||||||
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
||||||
@ -849,7 +869,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
|||||||
]
|
]
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier),
|
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
|
||||||
num_features=1280,
|
num_features=1280,
|
||||||
stem_size=24,
|
stem_size=24,
|
||||||
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
||||||
@ -861,7 +881,9 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_efficientnetv2_l(
|
||||||
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Creates an EfficientNet-V2 Large model
|
""" Creates an EfficientNet-V2 Large model
|
||||||
|
|
||||||
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
||||||
@ -879,7 +901,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
|||||||
]
|
]
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier),
|
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
|
||||||
num_features=1280,
|
num_features=1280,
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
||||||
@ -891,7 +913,9 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_efficientnetv2_xl(
|
||||||
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
""" Creates an EfficientNet-V2 Xtra-Large model
|
""" Creates an EfficientNet-V2 Xtra-Large model
|
||||||
|
|
||||||
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
|
||||||
@ -909,7 +933,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|||||||
]
|
]
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier),
|
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
|
||||||
num_features=1280,
|
num_features=1280,
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
||||||
@ -923,7 +947,8 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|||||||
|
|
||||||
def _gen_efficientnet_x(
|
def _gen_efficientnet_x(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
|
||||||
group_size=None, version=1, pretrained=False, **kwargs):
|
group_size=None, version=1, pretrained=False, **kwargs
|
||||||
|
):
|
||||||
"""Creates an EfficientNet model.
|
"""Creates an EfficientNet model.
|
||||||
|
|
||||||
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
||||||
@ -1069,9 +1094,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_tinynet(
|
def _gen_tinynet(variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
||||||
variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs
|
|
||||||
):
|
|
||||||
"""Creates a TinyNet model.
|
"""Creates a TinyNet model.
|
||||||
"""
|
"""
|
||||||
arch_def = [
|
arch_def = [
|
||||||
@ -1183,8 +1206,7 @@ def _gen_mobilenet_edgetpu(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_test_efficientnet(
|
def _gen_test_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
|
||||||
""" Minimal test EfficientNet generator.
|
""" Minimal test EfficientNet generator.
|
||||||
"""
|
"""
|
||||||
arch_def = [
|
arch_def = [
|
||||||
|
@ -412,7 +412,9 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
|
def _gen_mobilenet_v3_rw(
|
||||||
|
variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs
|
||||||
|
) -> MobileNetV3:
|
||||||
"""Creates a MobileNet-V3 model.
|
"""Creates a MobileNet-V3 model.
|
||||||
|
|
||||||
Ref impl: ?
|
Ref impl: ?
|
||||||
@ -450,7 +452,9 @@ def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrain
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
|
def _gen_mobilenet_v3(
|
||||||
|
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
|
||||||
|
) -> MobileNetV3:
|
||||||
"""Creates a MobileNet-V3 model.
|
"""Creates a MobileNet-V3 model.
|
||||||
|
|
||||||
Ref impl: ?
|
Ref impl: ?
|
||||||
@ -533,7 +537,7 @@ def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained:
|
|||||||
]
|
]
|
||||||
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
|
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def),
|
block_args=decode_arch_def(arch_def, group_size=group_size),
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
stem_size=16,
|
stem_size=16,
|
||||||
fix_stem=channel_multiplier < 0.75,
|
fix_stem=channel_multiplier < 0.75,
|
||||||
@ -646,7 +650,9 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
|
def _gen_mobilenet_v4(
|
||||||
|
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs,
|
||||||
|
) -> MobileNetV3:
|
||||||
"""Creates a MobileNet-V4 model.
|
"""Creates a MobileNet-V4 model.
|
||||||
|
|
||||||
Ref impl: ?
|
Ref impl: ?
|
||||||
@ -877,7 +883,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
|
|||||||
assert False, f'Unknown variant {variant}.'
|
assert False, f'Unknown variant {variant}.'
|
||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def),
|
block_args=decode_arch_def(arch_def, group_size=group_size),
|
||||||
head_bias=False,
|
head_bias=False,
|
||||||
head_norm=True,
|
head_norm=True,
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user