[Feature] Add truncated normal weight init (#935)

* [Feature] Add truncated normal weight init

* [Feature] Add truncated normal weight init

* [Feature] Add truncated normal weight init

* update docstring

* delete modelA.pth

* modify according to review comments

* use kstest to check truncated normal

* delete modelA.pth

* fix test.txt
pull/1045/head
Zaida Zhou 2021-05-23 20:54:54 +08:00 committed by GitHub
parent 4bd3b5027a
commit 55b4847a41
7 changed files with 220 additions and 27 deletions


@@ -15,25 +15,27 @@ from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
NormalInit, PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, initialize,
kaiming_init, normal_init, uniform_init, xavier_init)
NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
XavierInit, bias_init_with_prob, caffe2_xavier_init,
constant_init, fuse_conv_bn, get_model_complexity_info,
initialize, kaiming_init, normal_init, trunc_normal_init,
uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer
__all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]


@@ -2,15 +2,17 @@
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
KaimingInit, NormalInit, PretrainedInit, UniformInit,
XavierInit, bias_init_with_prob, caffe2_xavier_init,
KaimingInit, NormalInit, PretrainedInit,
TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
uniform_init, xavier_init)
trunc_normal_init, uniform_init, xavier_init)
__all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
'PretrainedInit', 'Caffe2XavierInit'
'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit'
]


@@ -1,9 +1,12 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
import math
import warnings
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
@@ -35,6 +38,18 @@ def normal_init(module, mean=0, std=1, bias=0):
nn.init.constant_(module.bias, bias)
def trunc_normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) # type: ignore
def uniform_init(module, a=0, b=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
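As a quick illustration of the new trunc_normal_init helper added in this hunk, the sketch below applies it directly to a single layer; the layer and the numeric values are illustrative only and do not appear in the patch.

import torch.nn as nn
from mmcv.cnn import trunc_normal_init

# Draw the conv weights from N(0, 0.02^2), clipped to [-0.04, 0.04],
# and fill the bias with the constant 0.1.
conv = nn.Conv2d(3, 16, 3)
trunc_normal_init(conv, mean=0., std=0.02, a=-0.04, b=0.04, bias=0.1)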
@@ -211,6 +226,55 @@ class NormalInit(BaseInit):
module.apply(init)
@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
outside :math:`[a, b]`.
Args:
mean (float): the mean of the normal distribution. Defaults to 0.
std (float): the standard deviation of the normal distribution.
Defaults to 1.
a (float): the minimum cutoff value. Defaults to -2.
b (float): the maximum cutoff value. Defaults to 2.
bias (float): the value to fill the bias or define
initialization type for bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer(s) to be initialized.
Defaults to None.
"""
def __init__(self,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
**kwargs) -> None:
super().__init__(**kwargs)
self.mean = mean
self.std = std
self.a = a
self.b = b
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
trunc_normal_init(m, self.mean, self.std, self.a,
self.b, self.bias)
module.apply(init)
@INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
r"""Initialize module parameters with values drawn from the uniform
@@ -468,3 +532,68 @@ def initialize(module, init_cfg):
else:
# All attributes in module have same initialization.
pass
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
b: float) -> Tensor:
# Method based on
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# Modified from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [lower, upper], then translate
# to [2lower-1, 2upper-1].
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor: Tensor,
mean: float = 0.,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Args:
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
mean (float): the mean of the normal distribution.
std (float): the standard deviation of the normal distribution.
a (float): the minimum cutoff value.
b (float): the maximum cutoff value.
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
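Putting the pieces of this file together: the registered 'TruncNormal' initializer can be driven from a config through initialize(), and trunc_normal_ can be called on a tensor directly. A minimal sketch, assuming the module paths implied by the imports above and using made-up values:

import torch
import torch.nn as nn
from mmcv.cnn import initialize
from mmcv.cnn.utils.weight_init import trunc_normal_

# Config-driven use: apply the TruncNormal initializer to every Conv2d
# and Linear layer of the model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
init_cfg = dict(
    type='TruncNormal', mean=0., std=0.02, a=-0.04, b=0.04, bias=0.,
    layer=['Conv2d', 'Linear'])
initialize(model, init_cfg)

# Direct tensor use: because samples are produced via the inverse CDF
# restricted to [a, b] (and clamped), no value falls outside the bounds.
w = torch.empty(256, 256)
trunc_normal_(w, mean=0., std=1., a=-2., b=2.)
assert w.min() >= -2. and w.max() <= 2.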


@@ -0,0 +1 @@
commit_id = '59b2b1c'


@@ -5,4 +5,5 @@ onnxoptimizer
onnxruntime==1.4.0
pytest
PyTurboJPEG
scipy
tiffile


@@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
known_first_party = mmcv
known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY


@@ -1,16 +1,18 @@
# Copyright (c) Open-MMLab. All rights reserved.
import random
from tempfile import TemporaryDirectory
import numpy as np
import pytest
import torch
from scipy import stats
from torch import nn
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, UniformInit, XavierInit,
PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
initialize, kaiming_init, normal_init, uniform_init,
xavier_init)
initialize, kaiming_init, normal_init, trunc_normal_init,
uniform_init, xavier_init)
def test_constant_init():
@@ -47,6 +49,35 @@ def test_normal_init():
# TODO: sanity check distribution, e.g. mean, std
def test_trunc_normal_init():
def _random_float(a, b):
return (b - a) * random.random() + a
def _is_trunc_normal(tensor, mean, std, a, b):
# scipy's trunc norm is suited for data drawn from N(0, 1),
# so we need to transform our data to test it using scipy.
z_samples = (tensor.view(-1) - mean) / std
z_samples = z_samples.tolist()
a0 = (a - mean) / std
b0 = (b - mean) / std
p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
return p_value > 0.0001
conv_module = nn.Conv2d(3, 16, 3)
mean = _random_float(-3, 3)
std = _random_float(.01, 1)
a = _random_float(mean - 2 * std, mean)
b = _random_float(mean, mean + 2 * std)
trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
trunc_normal_init(conv_module_no_bias)
# TODO: sanity check distribution, e.g. mean, std
def test_uniform_init():
conv_module = nn.Conv2d(3, 16, 3)
uniform_init(conv_module, bias=0.1)
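The KS-test check used in test_trunc_normal_init above can also be run as a standalone sanity check outside pytest; the sample size and distribution parameters below are arbitrary and only illustrate the transformation to scipy's standard-normal parameterization.

import torch
from scipy import stats
from mmcv.cnn.utils.weight_init import trunc_normal_

mean, std, a, b = 0., 0.5, -1., 1.
samples = trunc_normal_(torch.empty(10000), mean, std, a, b)

# scipy's truncnorm is parameterized for N(0, 1), so standardize the
# samples and the cutoffs before running the Kolmogorov-Smirnov test.
z = ((samples - mean) / std).tolist()
a0, b0 = (a - mean) / std, (b - mean) / std
p_value = stats.kstest(z, 'truncnorm', args=(a0, b0)).pvalue
print(p_value)  # a large p-value is consistent with the truncated normal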
@@ -168,6 +199,33 @@ def test_normalinit():
assert model[2].bias.allclose(torch.tensor(res))
def test_truncnormalinit():
"""test TruncNormalInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = TruncNormalInit(
mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
assert model[0].bias.allclose(torch.tensor(200.))
assert model[2].bias.allclose(torch.tensor(200.))
func = TruncNormalInit(
mean=300,
std=1e-5,
a=100,
b=400,
bias_prob=0.01,
layer=['Conv2d', 'Linear'])
res = bias_init_with_prob(0.01)
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
def test_uniforminit():
""""test UniformInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))