mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add support back to EfficientNet to disable head_conv / bn2 so mobilnetv1 can be implemented properly
This commit is contained in:
parent
800405d941
commit
1334598462
@ -98,7 +98,6 @@ class EfficientNet(nn.Module):
|
|||||||
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
||||||
se_layer = se_layer or SqueezeExcite
|
se_layer = se_layer or SqueezeExcite
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.num_features = self.head_hidden_size = num_features
|
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
self.grad_checkpointing = False
|
self.grad_checkpointing = False
|
||||||
|
|
||||||
@ -125,8 +124,15 @@ class EfficientNet(nn.Module):
|
|||||||
head_chs = builder.in_chs
|
head_chs = builder.in_chs
|
||||||
|
|
||||||
# Head + Pooling
|
# Head + Pooling
|
||||||
self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
|
if num_features > 0:
|
||||||
self.bn2 = norm_act_layer(self.num_features, inplace=True)
|
self.conv_head = create_conv2d(head_chs, num_features, 1, padding=pad_type)
|
||||||
|
self.bn2 = norm_act_layer(num_features, inplace=True)
|
||||||
|
self.num_features = self.head_hidden_size = num_features
|
||||||
|
else:
|
||||||
|
self.conv_head = nn.Identity()
|
||||||
|
self.bn2 = nn.Identity()
|
||||||
|
self.num_features = head_chs
|
||||||
|
|
||||||
self.global_pool, self.classifier = create_classifier(
|
self.global_pool, self.classifier = create_classifier(
|
||||||
self.num_features, self.num_classes, pool_type=global_pool)
|
self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
@ -481,7 +487,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
|
|||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v1(
|
def _gen_mobilenet_v1(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0,
|
||||||
|
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
||||||
Paper: https://arxiv.org/abs/1801.04381
|
Paper: https://arxiv.org/abs/1801.04381
|
||||||
@ -494,9 +501,10 @@ def _gen_mobilenet_v1(
|
|||||||
['dsa_r2_k3_s2_c1024'],
|
['dsa_r2_k3_s2_c1024'],
|
||||||
]
|
]
|
||||||
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
|
||||||
|
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
||||||
num_features=1024 if fix_stem_head else max(1024, round_chs_fn(1024)),
|
num_features=head_features,
|
||||||
stem_size=32,
|
stem_size=32,
|
||||||
fix_stem=fix_stem_head,
|
fix_stem=fix_stem_head,
|
||||||
round_chs_fn=round_chs_fn,
|
round_chs_fn=round_chs_fn,
|
||||||
@ -1206,6 +1214,7 @@ default_cfgs = generate_default_cfgs({
|
|||||||
hf_hub_id='timm/'),
|
hf_hub_id='timm/'),
|
||||||
|
|
||||||
'mobilenet_100.untrained': _cfg(),
|
'mobilenet_100.untrained': _cfg(),
|
||||||
|
'mobilenet_100h.untrained': _cfg(),
|
||||||
'mobilenet_125.untrained': _cfg(),
|
'mobilenet_125.untrained': _cfg(),
|
||||||
|
|
||||||
'mobilenetv2_035.untrained': _cfg(),
|
'mobilenetv2_035.untrained': _cfg(),
|
||||||
@ -1795,6 +1804,13 @@ def mobilenet_100(pretrained=False, **kwargs) -> EfficientNet:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mobilenet_100h(pretrained=False, **kwargs) -> EfficientNet:
|
||||||
|
""" MobileNet V1 """
|
||||||
|
model = _gen_mobilenet_v1('mobilenet_100h', 1.0, head_conv=True, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def mobilenet_125(pretrained=False, **kwargs) -> EfficientNet:
|
def mobilenet_125(pretrained=False, **kwargs) -> EfficientNet:
|
||||||
""" MobileNet V1 """
|
""" MobileNet V1 """
|
||||||
|
Loading…
x
Reference in New Issue
Block a user