Add support back to EfficientNet to disable head_conv / bn2 so mobilnetv1 can be implemented properly

small_things
Ross Wightman 2024-07-08 13:51:26 -07:00
parent 800405d941
commit 1334598462
1 changed files with 21 additions and 5 deletions

View File

@ -98,7 +98,6 @@ class EfficientNet(nn.Module):
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
self.num_classes = num_classes
self.num_features = self.head_hidden_size = num_features
self.drop_rate = drop_rate
self.grad_checkpointing = False
@ -125,8 +124,15 @@ class EfficientNet(nn.Module):
head_chs = builder.in_chs
# Head + Pooling
self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
self.bn2 = norm_act_layer(self.num_features, inplace=True)
if num_features > 0:
self.conv_head = create_conv2d(head_chs, num_features, 1, padding=pad_type)
self.bn2 = norm_act_layer(num_features, inplace=True)
self.num_features = self.head_hidden_size = num_features
else:
self.conv_head = nn.Identity()
self.bn2 = nn.Identity()
self.num_features = head_chs
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
@ -481,7 +487,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
def _gen_mobilenet_v1(
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0,
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
"""
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
Paper: https://arxiv.org/abs/1801.04381
@ -494,9 +501,10 @@ def _gen_mobilenet_v1(
['dsa_r2_k3_s2_c1024'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
num_features=1024 if fix_stem_head else max(1024, round_chs_fn(1024)),
num_features=head_features,
stem_size=32,
fix_stem=fix_stem_head,
round_chs_fn=round_chs_fn,
@ -1206,6 +1214,7 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/'),
'mobilenet_100.untrained': _cfg(),
'mobilenet_100h.untrained': _cfg(),
'mobilenet_125.untrained': _cfg(),
'mobilenetv2_035.untrained': _cfg(),
@ -1795,6 +1804,13 @@ def mobilenet_100(pretrained=False, **kwargs) -> EfficientNet:
return model
@register_model
def mobilenet_100h(pretrained=False, **kwargs) -> EfficientNet:
""" MobileNet V1 """
model = _gen_mobilenet_v1('mobilenet_100h', 1.0, head_conv=True, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenet_125(pretrained=False, **kwargs) -> EfficientNet:
""" MobileNet V1 """