mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add truncated normal weight init (#935)
Squashed commits:
* [Feature] Add truncated normal weight init
* update docstring
* delete modelA.pth
* modify according to comment
* use kstest to check truncated normal
* fix test.txt
parent 4bd3b5027a
commit 55b4847a41
mmcv/cnn/__init__.py
@@ -15,25 +15,27 @@ from .builder import MODELS, build_model_from_cfg
 # yapf: enable
 from .resnet import ResNet, make_res_layer
 from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
-                    NormalInit, PretrainedInit, UniformInit, XavierInit,
-                    bias_init_with_prob, caffe2_xavier_init, constant_init,
-                    fuse_conv_bn, get_model_complexity_info, initialize,
-                    kaiming_init, normal_init, uniform_init, xavier_init)
+                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                    constant_init, fuse_conv_bn, get_model_complexity_info,
+                    initialize, kaiming_init, normal_init, trunc_normal_init,
+                    uniform_init, xavier_init)
 from .vgg import VGG, make_vgg_layer

 __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
-    'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
-    'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
-    'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
-    'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
-    'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
-    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
-    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
-    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
-    'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
-    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
     'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
 ]

mmcv/cnn/utils/__init__.py
@@ -2,15 +2,17 @@
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
 from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
-                          KaimingInit, NormalInit, PretrainedInit, UniformInit,
-                          XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                          KaimingInit, NormalInit, PretrainedInit,
+                          TruncNormalInit, UniformInit, XavierInit,
+                          bias_init_with_prob, caffe2_xavier_init,
                           constant_init, initialize, kaiming_init, normal_init,
-                          uniform_init, xavier_init)
+                          trunc_normal_init, uniform_init, xavier_init)

 __all__ = [
     'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
-    'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
-    'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
-    'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
-    'PretrainedInit', 'Caffe2XavierInit'
+    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit'
 ]

mmcv/cnn/utils/weight_init.py
@@ -1,9 +1,12 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
+import math
 import warnings

 import numpy as np
+import torch
 import torch.nn as nn
+from torch import Tensor

 from mmcv.utils import Registry, build_from_cfg, get_logger, print_log

mmcv/cnn/utils/weight_init.py
@@ -35,6 +38,18 @@ def normal_init(module, mean=0, std=1, bias=0):
         nn.init.constant_(module.bias, bias)


+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0,
+                      std: float = 1,
+                      a: float = -2,
+                      b: float = 2,
+                      bias: float = 0) -> None:
+    if hasattr(module, 'weight') and module.weight is not None:
+        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)  # type: ignore
+
+
 def uniform_init(module, a=0, b=1, bias=0):
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.uniform_(module.weight, a, b)

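A quick illustration of how the new helper is called (a minimal sketch; the
`nn.Linear` layer and all numbers are illustrative, not taken from the patch):

    import torch.nn as nn
    from mmcv.cnn import trunc_normal_init

    fc = nn.Linear(16, 8)
    # Draw weights from N(0, 0.02^2) truncated to [-0.04, 0.04];
    # fill the bias with the constant 0.
    trunc_normal_init(fc, mean=0., std=0.02, a=-0.04, b=0.04, bias=0.)
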
mmcv/cnn/utils/weight_init.py
@@ -211,6 +226,55 @@ class NormalInit(BaseInit):
         module.apply(init)


+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+    outside :math:`[a, b]` redrawn until they are within the bounds.
+
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): the minimum cutoff value. Defaults to -2.
+        b (float): the maximum cutoff value. Defaults to 2.
+        bias (float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+
+    def __call__(self, module: nn.Module) -> None:
+
+        def init(m):
+            if self.wholemodule:
+                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                  self.bias)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        trunc_normal_init(m, self.mean, self.std, self.a,
+                                          self.b, self.bias)
+
+        module.apply(init)
+
+
 @INITIALIZERS.register_module(name='Uniform')
 class UniformInit(BaseInit):
     r"""Initialize module parameters with values drawn from the uniform

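Because the class is registered under the name 'TruncNormal', it can also be
driven through mmcv's config-based `initialize` API. A minimal sketch (the
model and all values here are illustrative, not taken from the patch):

    import torch.nn as nn
    from mmcv.cnn import initialize

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
    init_cfg = dict(
        type='TruncNormal', mean=0., std=0.02, a=-0.04, b=0.04, bias=0.,
        layer=['Conv2d', 'Linear'])
    initialize(model, init_cfg)
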
mmcv/cnn/utils/weight_init.py
@@ -468,3 +532,68 @@ def initialize(module, init_cfg):
     else:
         # All attributes in module have same initialization.
         pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower, upper], then
+        # translate to [2 * lower - 1, 2 * upper - 1].
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

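The sampler draws a uniform value on
[2 * Phi((a - mean) / std) - 1, 2 * Phi((b - mean) / std) - 1] and pushes it
through `erfinv`, which inverts the standard normal CDF up to the sqrt(2)
scaling, so every sample lands in [a, b] by construction. A small sanity-check
sketch (importing `trunc_normal_` from its defining module, since this commit
does not re-export it from `mmcv.cnn`):

    import torch
    from mmcv.cnn.utils.weight_init import trunc_normal_

    t = torch.empty(100000)
    trunc_normal_(t, mean=0., std=1., a=-2., b=2.)
    assert t.min() >= -2. and t.max() <= 2.
    # Truncation pulls the spread below the nominal std:
    print(t.mean().item(), t.std().item())  # roughly 0 and ~0.88
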
(new file)
@@ -0,0 +1 @@
+commit_id = '59b2b1c'

test.txt
@@ -5,4 +5,5 @@ onnxoptimizer
 onnxruntime==1.4.0
 pytest
 PyTurboJPEG
+scipy
 tiffile

setup.cfg
@@ -14,6 +14,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
 known_first_party = mmcv
-known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
+known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

tests/test_cnn/test_weight_init.py
@@ -1,16 +1,18 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import random
 from tempfile import TemporaryDirectory

 import numpy as np
 import pytest
 import torch
+from scipy import stats
 from torch import nn

 from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
-                      PretrainedInit, UniformInit, XavierInit,
-                      bias_init_with_prob, caffe2_xavier_init, constant_init,
-                      initialize, kaiming_init, normal_init, uniform_init,
-                      xavier_init)
+                      PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
+                      bias_init_with_prob, caffe2_xavier_init, constant_init,
+                      initialize, kaiming_init, normal_init, trunc_normal_init,
+                      uniform_init, xavier_init)


 def test_constant_init():

tests/test_cnn/test_weight_init.py
@@ -47,6 +49,35 @@ def test_normal_init():
     # TODO: sanity check distribution, e.g. mean, std


+def test_trunc_normal_init():
+
+    def _random_float(a, b):
+        return (b - a) * random.random() + a
+
+    def _is_trunc_normal(tensor, mean, std, a, b):
+        # scipy's trunc norm is suited for data drawn from N(0, 1),
+        # so we need to transform our data to test it using scipy.
+        z_samples = (tensor.view(-1) - mean) / std
+        z_samples = z_samples.tolist()
+        a0 = (a - mean) / std
+        b0 = (b - mean) / std
+        p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
+        return p_value > 0.0001
+
+    conv_module = nn.Conv2d(3, 16, 3)
+    mean = _random_float(-3, 3)
+    std = _random_float(.01, 1)
+    a = _random_float(mean - 2 * std, mean)
+    b = _random_float(mean, mean + 2 * std)
+    trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
+    assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
+    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
+
+    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
+    trunc_normal_init(conv_module_no_bias)
+    # TODO: sanity check distribution, e.g. mean, std
+
+
 def test_uniform_init():
     conv_module = nn.Conv2d(3, 16, 3)
     uniform_init(conv_module, bias=0.1)

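A note on the test's use of `scipy.stats.truncnorm`: its shape parameters are
the cutoffs expressed in standard-normal units, which is why the samples are
standardized before the KS test. A standalone illustration (values chosen
arbitrarily):

    from scipy import stats

    # A N(mean, std) truncated to [a, b] uses (a - mean) / std and
    # (b - mean) / std as the truncnorm shape parameters.
    mean, std, a, b = 1.0, 0.5, 0.0, 2.0
    rv = stats.truncnorm((a - mean) / std, (b - mean) / std, loc=mean,
                         scale=std)
    samples = rv.rvs(size=1000)
    assert samples.min() >= a and samples.max() <= b
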
tests/test_cnn/test_weight_init.py
@@ -168,6 +199,33 @@ def test_normalinit():
     assert model[2].bias.allclose(torch.tensor(res))


+def test_truncnormalinit():
+    """test TruncNormalInit class."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+
+    func = TruncNormalInit(
+        mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(100.))
+    assert model[2].weight.allclose(torch.tensor(100.))
+    assert model[0].bias.allclose(torch.tensor(200.))
+    assert model[2].bias.allclose(torch.tensor(200.))
+
+    func = TruncNormalInit(
+        mean=300,
+        std=1e-5,
+        a=100,
+        b=400,
+        bias_prob=0.01,
+        layer=['Conv2d', 'Linear'])
+    res = bias_init_with_prob(0.01)
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert model[0].bias.allclose(torch.tensor(res))
+    assert model[2].bias.allclose(torch.tensor(res))
+
+
 def test_uniforminit():
     """test UniformInit class."""
     model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))

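The second case exercises `bias_prob`: `bias_init_with_prob(p)` converts a
prior probability into a bias via the inverse sigmoid, -log((1 - p) / p) in
mmcv, so the expected value the asserts compare against can be reproduced
directly (a sketch; the helper name comes from the imports above, the
arithmetic is standard):

    import math

    def expected_bias(prior_prob):
        # Inverse sigmoid: sigmoid(bias) == prior_prob.
        return -math.log((1 - prior_prob) / prior_prob)

    print(expected_bias(0.01))  # ~ -4.595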