mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add truncated normal weight init (#935)
* [Feature] Add truncated normal weight init
* update docstring
* delete modelA.pth
* modify according to comment
* use kstest to check truncated normal
* fix test.txt
parent 4bd3b5027a
commit 55b4847a41
mmcv/cnn/__init__.py
@@ -15,25 +15,27 @@ from .builder import MODELS, build_model_from_cfg
 # yapf: enable
 from .resnet import ResNet, make_res_layer
 from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
-                    NormalInit, PretrainedInit, UniformInit, XavierInit,
-                    bias_init_with_prob, caffe2_xavier_init, constant_init,
-                    fuse_conv_bn, get_model_complexity_info, initialize,
-                    kaiming_init, normal_init, uniform_init, xavier_init)
+                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                    constant_init, fuse_conv_bn, get_model_complexity_info,
+                    initialize, kaiming_init, normal_init, trunc_normal_init,
+                    uniform_init, xavier_init)
 from .vgg import VGG, make_vgg_layer
 
 __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
-    'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
-    'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
-    'build_padding_layer', 'build_upsample_layer', 'build_plugin_layer',
-    'is_norm', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'ContextBlock',
-    'HSigmoid', 'Swish', 'HSwish', 'GeneralizedAttention', 'ACTIVATION_LAYERS',
-    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
-    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
-    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
-    'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
-    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
     'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
 ]
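With these re-exports in place, both the functional and class-based
initializers are part of the public mmcv.cnn API. A minimal usage sketch
(my own illustration, assuming this commit is installed; the values are
arbitrary):

    import torch.nn as nn

    from mmcv.cnn import trunc_normal_init

    conv = nn.Conv2d(3, 16, 3)
    # Weights ~ N(0, 0.02^2), truncated to [-0.04, 0.04]; bias filled with 0.
    trunc_normal_init(conv, mean=0., std=0.02, a=-0.04, b=0.04, bias=0.)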
mmcv/cnn/utils/__init__.py
@@ -2,15 +2,17 @@
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
 from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
-                          KaimingInit, NormalInit, PretrainedInit, UniformInit,
-                          XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                          KaimingInit, NormalInit, PretrainedInit,
+                          TruncNormalInit, UniformInit, XavierInit,
+                          bias_init_with_prob, caffe2_xavier_init,
                           constant_init, initialize, kaiming_init, normal_init,
-                          uniform_init, xavier_init)
+                          trunc_normal_init, uniform_init, xavier_init)
 
 __all__ = [
     'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
-    'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
-    'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
-    'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
-    'PretrainedInit', 'Caffe2XavierInit'
+    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit'
 ]
mmcv/cnn/utils/weight_init.py
@@ -1,9 +1,12 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
+import math
 import warnings
 
 import numpy as np
+import torch
 import torch.nn as nn
+from torch import Tensor
 
 from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
 
@@ -35,6 +38,18 @@ def normal_init(module, mean=0, std=1, bias=0):
         nn.init.constant_(module.bias, bias)
 
 
+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0,
+                      std: float = 1,
+                      a: float = -2,
+                      b: float = 2,
+                      bias: float = 0) -> None:
+    if hasattr(module, 'weight') and module.weight is not None:
+        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)  # type: ignore
+
+
 def uniform_init(module, a=0, b=1, bias=0):
     if hasattr(module, 'weight') and module.weight is not None:
         nn.init.uniform_(module.weight, a, b)
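A quick sanity check of the helper above (my own sketch, not part of the
commit; the module and values are arbitrary): every weight must land inside
[a, b], and the bias is filled with the given constant.

    import torch
    import torch.nn as nn

    linear = nn.Linear(8, 4)
    trunc_normal_init(linear, mean=0., std=1., a=-2., b=2., bias=0.5)
    assert float(linear.weight.min()) >= -2. and float(linear.weight.max()) <= 2.
    assert torch.allclose(linear.bias, torch.full_like(linear.bias, 0.5))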
@@ -211,6 +226,55 @@ class NormalInit(BaseInit):
         module.apply(init)
 
 
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+    outside :math:`[a, b]` redrawn until they are within the bounds.
+
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): the minimum cutoff value. Defaults to -2.
+        b (float): the maximum cutoff value. Defaults to 2.
+        bias (float): the value to fill the bias or define
+            initialization type for bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+
+    def __call__(self, module: nn.Module) -> None:
+
+        def init(m):
+            if self.wholemodule:
+                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                  self.bias)
+            else:
+                layername = m.__class__.__name__
+                for layer_ in self.layer:
+                    if layername == layer_:
+                        trunc_normal_init(m, self.mean, self.std, self.a,
+                                          self.b, self.bias)
+
+        module.apply(init)
+
+
 @INITIALIZERS.register_module(name='Uniform')
 class UniformInit(BaseInit):
     r"""Initialize module parameters with values drawn from the uniform
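Because the class registers under the name 'TruncNormal', it can also be
driven from a config dict through mmcv.cnn.initialize(), like the other
initializers. A minimal sketch (my own example; the values are arbitrary):

    import torch.nn as nn

    from mmcv.cnn import initialize

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
    init_cfg = dict(
        type='TruncNormal', mean=0., std=0.02, a=-0.04, b=0.04, bias=0.,
        layer=['Conv2d', 'Linear'])
    initialize(model, init_cfg)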
@@ -468,3 +532,68 @@ def initialize(module, init_cfg):
         else:
             # All attributes in module have same initialization.
             pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower, upper], then
+        # translate to [2 * lower - 1, 2 * upper - 1].
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
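The construction above samples U ~ Uniform(2*Phi(alpha) - 1, 2*Phi(beta) - 1)
with alpha = (a - mean)/std and beta = (b - mean)/std, then maps it through
erfinv to a truncated standard normal before rescaling. A small numerical
check of the endpoints (my own sketch; trunc_normal_ is not re-exported from
mmcv.cnn, so it is imported from the module directly):

    import torch

    from mmcv.cnn.utils.weight_init import trunc_normal_

    mean, std, a, b = 0.5, 2.0, -1.5, 2.5
    t = trunc_normal_(torch.empty(100000), mean, std, a, b)
    assert a <= float(t.min()) and float(t.max()) <= b
    # Cutoffs are symmetric about `mean`, so the sample mean stays close to it.
    print(float(t.mean()), float(t.std()))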
@@ -0,0 +1 @@
+commit_id = '59b2b1c'
requirements/test.txt
@@ -5,4 +5,5 @@ onnxoptimizer
 onnxruntime==1.4.0
 pytest
 PyTurboJPEG
+scipy
 tiffile
setup.cfg
@@ -14,6 +14,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
 known_first_party = mmcv
-known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,tensorrt,torch,torchvision,yaml,yapf
+known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
tests/test_cnn/test_weight_init.py
@@ -1,16 +1,18 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import random
 from tempfile import TemporaryDirectory
 
 import numpy as np
 import pytest
 import torch
+from scipy import stats
 from torch import nn
 
 from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
-                      PretrainedInit, UniformInit, XavierInit,
+                      PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
                       bias_init_with_prob, caffe2_xavier_init, constant_init,
-                      initialize, kaiming_init, normal_init, uniform_init,
-                      xavier_init)
+                      initialize, kaiming_init, normal_init, trunc_normal_init,
+                      uniform_init, xavier_init)
 
 
 def test_constant_init():
@@ -47,6 +49,35 @@ def test_normal_init():
     # TODO: sanity check distribution, e.g. mean, std
 
 
+def test_trunc_normal_init():
+
+    def _random_float(a, b):
+        return (b - a) * random.random() + a
+
+    def _is_trunc_normal(tensor, mean, std, a, b):
+        # scipy's trunc norm is suited for data drawn from N(0, 1),
+        # so we need to transform our data to test it using scipy.
+        z_samples = (tensor.view(-1) - mean) / std
+        z_samples = z_samples.tolist()
+        a0 = (a - mean) / std
+        b0 = (b - mean) / std
+        p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
+        return p_value > 0.0001
+
+    conv_module = nn.Conv2d(3, 16, 3)
+    mean = _random_float(-3, 3)
+    std = _random_float(.01, 1)
+    a = _random_float(mean - 2 * std, mean)
+    b = _random_float(mean, mean + 2 * std)
+    trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
+    assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
+    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
+
+    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
+    trunc_normal_init(conv_module_no_bias)
+    # TODO: sanity check distribution, e.g. mean, std
+
+
 def test_uniform_init():
     conv_module = nn.Conv2d(3, 16, 3)
     uniform_init(conv_module, bias=0.1)
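The standardization inside _is_trunc_normal is needed because scipy's
truncnorm is parameterized in standard-normal units: the samples are z-scored
and the cutoffs rescaled to (a - mean)/std and (b - mean)/std before the
Kolmogorov-Smirnov test. The same check, standalone (my own sketch with
arbitrary values; trunc_normal_ imported from the weight_init module):

    import torch
    from scipy import stats

    from mmcv.cnn.utils.weight_init import trunc_normal_

    mean, std, a, b = 1.0, 0.5, 0.0, 2.0
    samples = trunc_normal_(torch.empty(10000), mean, std, a, b)
    z = ((samples - mean) / std).tolist()
    a0, b0 = (a - mean) / std, (b - mean) / std
    p_value = stats.kstest(z, 'truncnorm', args=(a0, b0))[1]
    assert p_value > 0.0001  # reject only if clearly not truncated normal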
@@ -168,6 +199,33 @@ def test_normalinit():
     assert model[2].bias.allclose(torch.tensor(res))
 
 
+def test_truncnormalinit():
+    """test TruncNormalInit class."""
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
+
+    func = TruncNormalInit(
+        mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(100.))
+    assert model[2].weight.allclose(torch.tensor(100.))
+    assert model[0].bias.allclose(torch.tensor(200.))
+    assert model[2].bias.allclose(torch.tensor(200.))
+
+    func = TruncNormalInit(
+        mean=300,
+        std=1e-5,
+        a=100,
+        b=400,
+        bias_prob=0.01,
+        layer=['Conv2d', 'Linear'])
+    res = bias_init_with_prob(0.01)
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert model[0].bias.allclose(torch.tensor(res))
+    assert model[2].bias.allclose(torch.tensor(res))
+
+
 def test_uniforminit():
     """test UniformInit class."""
     model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))