mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove unused default_init for EfficientNets, experimenting with fanout calc for #84
This commit is contained in:
parent
cade829105
commit
d0eb59ef46
@ -358,15 +358,24 @@ class EfficientNetBuilder:
|
||||
return stages
|
||||
|
||||
|
||||
def _init_weight_goog(m, n=''):
|
||||
def _init_weight_goog(m, n='', fix_group_fanout=False):
|
||||
""" Weight initialization as per Tensorflow official implementations.
|
||||
|
||||
Args:
|
||||
m (nn.Module): module to init
|
||||
n (str): module name
|
||||
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
|
||||
|
||||
FIXME change fix_group_fanout to default to True if experiments show better training results
|
||||
|
||||
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
|
||||
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
||||
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
||||
"""
|
||||
if isinstance(m, CondConv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
init_weight_fn = get_condconv_initializer(
|
||||
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
||||
init_weight_fn(m.weight)
|
||||
@ -374,6 +383,8 @@ def _init_weight_goog(m, n=''):
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''):
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
def _init_weight_default(m, n=''):
|
||||
""" Basic ResNet (Kaiming) style weight init"""
|
||||
if isinstance(m, CondConv2d):
|
||||
init_fn = get_condconv_initializer(partial(
|
||||
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
|
||||
init_fn(m.weight)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
||||
|
||||
|
||||
def efficientnet_init_weights(model: nn.Module, init_fn=None):
|
||||
init_fn = init_fn or _init_weight_goog
|
||||
for n, m in model.named_modules():
|
||||
|
Loading…
x
Reference in New Issue
Block a user