# encoding: utf-8
"""
@author:  xingyu liao
@contact: liaoxingyu5@jd.com
"""

from torch import nn

__all__ = [
    'weights_init_classifier',
    'weights_init_kaiming',
]


def weights_init_kaiming(m: nn.Module) -> None:
    """Initialise the parameters of a single module in place.

    The module kind is chosen by substring match on the class name, checked
    in this order (first match wins):

    * ``'Linear'``    -- weight ~ N(0, 0.01); bias (if present) set to 0.
    * ``'Conv'``      -- weight from Kaiming-normal init (``mode='fan_out'``,
      ``nonlinearity='relu'``); bias (if present) set to 0.
    * ``'BatchNorm'`` -- only when the module is affine: weight ~ N(1.0, 0.02),
      bias set to 0.

    Any other module is left untouched. The function takes one module and
    returns nothing, so it can be used as the callback of
    ``nn.Module.apply``, e.g. ``model.apply(weights_init_kaiming)``.

    Args:
        m: The module whose parameters are (re-)initialised in place.
    """
    classname = type(m).__name__

    if 'Linear' in classname:
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in classname:
        # Non-affine batch norm has no learnable weight/bias. A module that is
        # merely *named* like a batch norm but exposes no ``affine`` flag is
        # skipped as well rather than raising AttributeError.
        if getattr(m, 'affine', False):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m: nn.Module) -> None:
    """Initialise a classifier-style linear module in place.

    If the class name of ``m`` contains ``'Linear'``, its weight is drawn
    from N(0, 0.001) and its bias (if present) is set to 0. Any other module
    is left untouched.

    Args:
        m: The module whose parameters are (re-)initialised in place.
    """
    classname = type(m).__name__

    if 'Linear' in classname:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)