Add support back to EfficientNet to disable head_conv / bn2 so mobilenetv1 can be implemented properly
parent 800405d941
commit 1334598462
@@ -98,7 +98,6 @@ class EfficientNet(nn.Module):
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
         self.num_classes = num_classes
-        self.num_features = self.head_hidden_size = num_features
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
 
@@ -125,8 +124,15 @@ class EfficientNet(nn.Module):
         head_chs = builder.in_chs
 
         # Head + Pooling
-        self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
-        self.bn2 = norm_act_layer(self.num_features, inplace=True)
+        if num_features > 0:
+            self.conv_head = create_conv2d(head_chs, num_features, 1, padding=pad_type)
+            self.bn2 = norm_act_layer(num_features, inplace=True)
+            self.num_features = self.head_hidden_size = num_features
+        else:
+            self.conv_head = nn.Identity()
+            self.bn2 = nn.Identity()
+            self.num_features = head_chs
+            self.head_hidden_size = head_chs
 
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)
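Note: the head wiring above follows a common conditional-head pattern: when num_features == 0, conv_head and bn2 collapse to nn.Identity and the classifier consumes the backbone's final channel count directly, which is what a faithful MobileNetV1 (no 1x1 head conv) needs. A minimal standalone sketch of the same pattern, with toy channel sizes and plain BatchNorm2d standing in for norm_act_layer (illustrative only, not timm code):

import torch
import torch.nn as nn

class ToyHead(nn.Module):
    """Mirrors the change above: num_features == 0 disables the head conv / bn."""
    def __init__(self, head_chs: int = 320, num_features: int = 1280):
        super().__init__()
        if num_features > 0:
            self.conv_head = nn.Conv2d(head_chs, num_features, kernel_size=1, bias=False)
            self.bn2 = nn.BatchNorm2d(num_features)
            self.num_features = num_features
        else:
            self.conv_head = nn.Identity()
            self.bn2 = nn.Identity()
            self.num_features = head_chs  # classifier sees the backbone width

    def forward(self, x):
        return self.bn2(self.conv_head(x))

x = torch.randn(1, 320, 7, 7)
assert ToyHead(num_features=1280)(x).shape[1] == 1280  # head conv active
assert ToyHead(num_features=0)(x).shape[1] == 320      # head disabled, MobileNetV1 style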
@@ -481,7 +487,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
 
 
 def _gen_mobilenet_v1(
-        variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
+        variant, channel_multiplier=1.0, depth_multiplier=1.0,
+        fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
     """
     Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
     Paper: https://arxiv.org/abs/1801.04381
@@ -494,9 +501,10 @@ def _gen_mobilenet_v1(
         ['dsa_r2_k3_s2_c1024'],
     ]
     round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+    head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
-        num_features=1024 if fix_stem_head else max(1024, round_chs_fn(1024)),
+        num_features=head_features,
         stem_size=32,
         fix_stem=fix_stem_head,
         round_chs_fn=round_chs_fn,
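For intuition, head_features works out as follows: with head_conv=False it is 0 (head disabled); with head_conv=True it is a fixed 1024 when fix_stem_head is set, otherwise 1024 scaled by the channel multiplier and never below 1024. A small sketch of the arithmetic, using a stand-in for timm's round_channels (assumed here to round to the nearest multiple of 8 with a ~10% round-down limit, which may not match timm exactly):

def round_chs(channels, multiplier=1.0, divisor=8, round_limit=0.9):
    # Approximation of timm's round_channels / make_divisible behaviour.
    if not multiplier:
        return channels
    v = channels * multiplier
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < round_limit * v:  # avoid rounding down by more than ~10%
        new_v += divisor
    return new_v

def head_features(channel_multiplier, fix_stem_head, head_conv):
    if not head_conv:
        return 0  # conv_head / bn2 become nn.Identity in the model
    return 1024 if fix_stem_head else max(1024, round_chs(1024, channel_multiplier))

print(head_features(1.0, False, False))  # 0    -> mobilenet_100 (no head conv)
print(head_features(1.0, False, True))   # 1024 -> mobilenet_100h
print(head_features(1.25, False, True))  # 1280 -> a hypothetical 1.25x variant with head conv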
@@ -1206,6 +1214,7 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/'),
 
     'mobilenet_100.untrained': _cfg(),
+    'mobilenet_100h.untrained': _cfg(),
     'mobilenet_125.untrained': _cfg(),
 
     'mobilenetv2_035.untrained': _cfg(),
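Since the new variant gets an .untrained default config above, it should be discoverable through the normal registry; a quick check (assuming a timm build that includes this commit):

import timm

# Glob filter over registered model names; expected to list
# mobilenet_100, mobilenet_100h and mobilenet_125 among others.
print(timm.list_models('mobilenet_1*'))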
@@ -1795,6 +1804,13 @@ def mobilenet_100(pretrained=False, **kwargs) -> EfficientNet:
     return model
 
 
+@register_model
+def mobilenet_100h(pretrained=False, **kwargs) -> EfficientNet:
+    """ MobileNet V1 """
+    model = _gen_mobilenet_v1('mobilenet_100h', 1.0, head_conv=True, pretrained=pretrained, **kwargs)
+    return model
+
+
 @register_model
 def mobilenet_125(pretrained=False, **kwargs) -> EfficientNet:
     """ MobileNet V1 """
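A usage sketch contrasting the two heads (again assuming a build with this commit; the configs are .untrained, so there are no pretrained weights to load):

import timm
import torch.nn as nn

# mobilenet_100: built with head_conv=False -> conv_head / bn2 are Identity,
# so num_features stays at the backbone's final width (1024 at 1.0x).
m1 = timm.create_model('mobilenet_100', pretrained=False)
assert isinstance(m1.conv_head, nn.Identity)
assert m1.num_features == 1024

# mobilenet_100h: built with head_conv=True -> real 1x1 head conv to 1024 features.
m1h = timm.create_model('mobilenet_100h', pretrained=False)
assert not isinstance(m1h.conv_head, nn.Identity)
assert m1h.num_features == 1024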