2020-02-10 07:38:56 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
2020-05-01 09:02:46 +08:00
|
|
|
@author: xingyu liao
|
2020-07-29 17:43:39 +08:00
|
|
|
@contact: sherlockliao01@gmail.com
|
2020-02-10 07:38:56 +08:00
|
|
|
"""
|
2020-05-01 09:02:46 +08:00
|
|
|
|
2020-02-10 07:38:56 +08:00
|
|
|
from torch import nn
|
|
|
|
|
2020-05-01 09:02:46 +08:00
|
|
|
# Public API of this module: the two per-module weight initializers below.
__all__ = [
    'weights_init_classifier',
    'weights_init_kaiming',
]
|
2020-02-10 07:38:56 +08:00
|
|
|
|
|
|
|
|
|
|
|
def weights_init_kaiming(m):
    """Re-initialize the parameters of a single module in place.

    The function takes one module and returns nothing, so it can be used as
    the callback of ``nn.Module.apply``. The layer kind is detected from a
    substring of the module's class name, which means subclasses and custom
    layers are matched by name rather than by type:

    * ``'Linear'``    -- weight drawn from a normal distribution with mean 0
      and std 0.01; bias (if any) set to 0.
    * ``'Conv'``      -- weight drawn with ``nn.init.kaiming_normal_`` using
      ``mode='fan_out'``; bias (if any) set to 0.
    * ``'BatchNorm'`` -- when the module is affine, weight set to 1 and bias
      set to 0.

    Modules whose class name matches none of the above are left untouched.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__

    # Matching is done on the class name only, so a container such as a
    # hypothetical ``ConvBlock`` (or a non-affine BatchNorm, whose weight is
    # None) can match while owning no weight. There is nothing to initialize
    # in that case, so skip it instead of raising AttributeError.
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    bias = getattr(m, 'bias', None)

    if 'Linear' in classname:
        nn.init.normal_(weight, 0, 0.01)
        if bias is not None:
            nn.init.constant_(bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(weight, mode='fan_out')
        if bias is not None:
            nn.init.constant_(bias, 0.0)
    elif 'BatchNorm' in classname:
        # Name-matched normalization layers that do not define ``affine``
        # are treated as non-affine and left as they are.
        if getattr(m, 'affine', False):
            nn.init.constant_(weight, 1.0)
            if bias is not None:
                nn.init.constant_(bias, 0.0)
|
|
|
|
|
|
|
|
|
|
|
|
def weights_init_classifier(m):
    """Re-initialize a linear module in place with a small-variance normal.

    The function takes one module and returns nothing, so it can be used as
    the callback of ``nn.Module.apply``. Only modules whose class name
    contains ``'Linear'`` are touched: the weight is drawn from a normal
    distribution with std 0.001 (default mean of ``nn.init.normal_``) and
    the bias, when present, is set to 0. Every other module is left as is.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__
    if 'Linear' not in classname:
        return

    # The match is by name only, so a wrapper whose name merely contains
    # 'Linear' may own no weight; skip it rather than raise AttributeError.
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    nn.init.normal_(weight, std=0.001)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)
|