# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import math

import torch
import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a truncated normal.

    Values are produced by inverse-CDF sampling and finally clamped to
    ``[a, b]``. All tensor operations run under ``torch.no_grad()``.

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: Tensor that is overwritten in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, after it has been filled.
    """

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        logging.warning(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower_cdf = norm_cdf((a - mean) / std)
        upper_cdf = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower_cdf, upper_cdf],
        # then translate to [2 * lower_cdf - 1, 2 * upper_cdf - 1].
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor: torch.Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> torch.Tensor:
    """Fill ``tensor`` in place with truncated-normal values.

    Thin public wrapper that forwards every argument to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: Tensor that is overwritten in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, after it has been filled.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def _init_weights(module, init_linear='normal', std=0.01, bias=0.):
    """Initialise the linear and normalisation layers found in ``module``.

    Every sub-module returned by ``module.modules()`` is visited:

    * ``nn.Linear`` layers are initialised with ``normal_init`` (when
      ``init_linear == 'normal'``) or ``kaiming_init`` with
      ``mode='fan_in'`` and ``nonlinearity='relu'`` (when
      ``init_linear == 'kaiming'``).
    * ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, ``nn.GroupNorm`` and
      ``nn.SyncBatchNorm`` layers get weight 1 and bias 0, for whichever of
      the two parameters is not ``None``.

    Args:
        module: Root module whose sub-modules are initialised.
        init_linear: Either ``'normal'`` or ``'kaiming'``.
        std: Standard deviation passed to ``normal_init``; unused for
            ``'kaiming'``.
        bias: Bias value passed to the linear-layer initialiser in both
            modes.

    Raises:
        AssertionError: If ``init_linear`` is not one of the supported
            values.
    """
    assert init_linear in ['normal', 'kaiming'], \
        'Undefined init_linear: {}'.format(init_linear)
    for m in module.modules():
        if isinstance(m, nn.Linear):
            if init_linear == 'normal':
                normal_init(m, std=std, bias=bias)
            else:
                # Forward the requested bias so that both initialisation
                # modes honour the ``bias`` argument; it used to be dropped
                # on this branch.
                kaiming_init(
                    m, mode='fan_in', nonlinearity='relu', bias=bias)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm,
                            nn.SyncBatchNorm)):
            # Norm layers may carry no weight/bias parameter, hence the
            # None checks.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)