fast-reid/fastreid/utils/weight_init.py

40 lines
1.1 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
2020-02-10 07:38:56 +08:00
"""
import math
2020-02-10 07:38:56 +08:00
from torch import nn
# Public names exported by ``from ... import *``: the two initialisers below.
# Each takes a single module argument (presumably used with ``nn.Module.apply``
# — confirm against callers).
__all__ = [
'weights_init_classifier',
'weights_init_kaiming',
]
2020-02-10 07:38:56 +08:00
def weights_init_kaiming(m):
    """Initialise one module's parameters in place, dispatching on its class name.

    The class name is matched by substring, checked in this order (first hit wins):

    * contains ``Linear``    -> weight ~ normal(mean=0, std=0.01);
      bias, when present, set to 0.
    * contains ``Conv``      -> weight via ``kaiming_normal_`` with
      ``mode='fan_out'`` and ``nonlinearity='relu'``; bias, when present, set to 0.
    * contains ``BatchNorm`` -> only if ``m.affine`` is true:
      weight ~ normal(mean=1.0, std=0.02), bias set to 0.

    Any other module is left untouched.

    Args:
        m: The module whose parameters are (re)initialised.

    Returns:
        None. ``m`` is modified in place.
    """
    cls_name = m.__class__.__name__

    if 'Linear' in cls_name:
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
        return

    if 'Conv' in cls_name:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
        return

    # Non-affine BatchNorm has no learnable weight/bias, so it is skipped.
    if 'BatchNorm' in cls_name and m.affine:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
    """Initialise a classifier-head module in place, dispatching on its class name.

    * class name contains ``Linear``              -> weight ~ normal(std=0.001);
      bias, when present, set to 0.
    * class name contains ``Arcface`` or ``Circle`` -> weight via
      ``kaiming_uniform_`` with ``a=sqrt(5)`` (such modules are assumed to
      expose a ``weight`` tensor — confirm against their definitions).

    ``Linear`` is checked first. Every other module is left untouched.

    Args:
        m: The module whose parameters are (re)initialised.

    Returns:
        None. ``m`` is modified in place.
    """
    cls_name = m.__class__.__name__

    if 'Linear' in cls_name:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
        return

    if any(tag in cls_name for tag in ("Arcface", "Circle")):
        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))