33 lines
1000 B
Python
33 lines
1000 B
Python
# Copyright (c) Malong Technologies Co., Ltd.
|
|
# All rights reserved.
|
|
#
|
|
# Contact: github@malong.com
|
|
#
|
|
# This source code is licensed under the LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
def weights_init_kaiming(m):
    """Initialise a single module in place, dispatching on its class name.

    The branch is chosen by substring match on ``m.__class__.__name__``:

    * name contains ``'Linear'``: weight gets Kaiming-normal init with
      ``mode='fan_out'``; bias (when present) is set to 0.
    * name contains ``'Conv'``: weight gets Kaiming-normal init with
      ``mode='fan_in'``; bias (when present) is set to 0.
    * name contains ``'BatchNorm'``: when ``m.affine`` is true, weight is
      set to 1 and bias to 0.

    Modules matching none of these are left untouched. The function takes
    one module and returns ``None``, so it can be handed to
    ``nn.Module.apply``.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # A layer created with bias=False exposes ``bias`` as None; calling
        # an initialiser on None would raise, so guard it like the Conv
        # branch does.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in classname:
        # Non-affine batch norm has no learnable weight/bias to reset.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
|
|
|
|
|
|
def weights_init_classifier(m):
    """Initialise a linear layer in place with small normal weights.

    Only modules whose class name contains ``'Linear'`` are touched: the
    weight is drawn from a normal distribution with ``std=0.001`` and the
    bias, when the layer has one, is set to 0. Any other module is left
    unchanged.

    Args:
        m: The module to initialise.
    """
    # Anything that is not a Linear-named module is ignored.
    if 'Linear' not in m.__class__.__name__:
        return

    nn.init.normal_(m.weight, std=0.001)

    # Bias-free layers expose ``bias`` as None; nothing more to do.
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0.0)
|