mirror of https://github.com/open-mmlab/mmcv.git
[Refactoring] Add Caffe2Xavier Initializer (#902)
* [Refactoring] Add Caffe2Xavier Initializer * fix lintpull/890/head
parent
933b052d95
commit
5f5e8e83c2
|
@ -13,8 +13,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
|
|||
build_upsample_layer, conv_ws_2d, is_norm)
|
||||
# yapf: enable
|
||||
from .resnet import ResNet, make_res_layer
|
||||
from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
|
||||
PretrainedInit, UniformInit, XavierInit,
|
||||
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
|
||||
NormalInit, PretrainedInit, UniformInit, XavierInit,
|
||||
bias_init_with_prob, caffe2_xavier_init, constant_init,
|
||||
fuse_conv_bn, get_model_complexity_info, initialize,
|
||||
kaiming_init, normal_init, uniform_init, xavier_init)
|
||||
|
@ -33,5 +33,6 @@ __all__ = [
|
|||
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
|
||||
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
|
||||
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
|
||||
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
|
||||
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
|
||||
'Caffe2XavierInit'
|
||||
]
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
from .flops_counter import get_model_complexity_info
|
||||
from .fuse_conv_bn import fuse_conv_bn
|
||||
from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
|
||||
PretrainedInit, UniformInit, XavierInit,
|
||||
bias_init_with_prob, caffe2_xavier_init,
|
||||
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
|
||||
KaimingInit, NormalInit, PretrainedInit, UniformInit,
|
||||
XavierInit, bias_init_with_prob, caffe2_xavier_init,
|
||||
constant_init, initialize, kaiming_init, normal_init,
|
||||
uniform_init, xavier_init)
|
||||
|
||||
|
@ -12,5 +12,5 @@ __all__ = [
|
|||
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
|
||||
'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
|
||||
'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
|
||||
'PretrainedInit'
|
||||
'PretrainedInit', 'Caffe2XavierInit'
|
||||
]
|
||||
|
|
|
@ -298,6 +298,22 @@ class KaimingInit(BaseInit):
|
|||
module.apply(init)
|
||||
|
||||
|
||||
@INITIALIZERS.register_module(name='Caffe2Xavier')
|
||||
class Caffe2XavierInit(KaimingInit):
|
||||
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
|
||||
# Acknowledgment to FAIR's internal code
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
a=1,
|
||||
mode='fan_in',
|
||||
nonlinearity='leaky_relu',
|
||||
distribution='uniform',
|
||||
**kwargs)
|
||||
|
||||
def __call__(self, module):
|
||||
super().__call__(module)
|
||||
|
||||
|
||||
@INITIALIZERS.register_module(name='Pretrained')
|
||||
class PretrainedInit(object):
|
||||
"""Initialize module by loading a pretrained model.
|
||||
|
|
|
@ -6,10 +6,11 @@ import pytest
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit,
|
||||
UniformInit, XavierInit, bias_init_with_prob,
|
||||
caffe2_xavier_init, constant_init, initialize,
|
||||
kaiming_init, normal_init, uniform_init, xavier_init)
|
||||
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
|
||||
PretrainedInit, UniformInit, XavierInit,
|
||||
bias_init_with_prob, caffe2_xavier_init, constant_init,
|
||||
initialize, kaiming_init, normal_init, uniform_init,
|
||||
xavier_init)
|
||||
|
||||
|
||||
def test_constant_init():
|
||||
|
@ -219,6 +220,15 @@ def test_kaiminginit():
|
|||
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
|
||||
|
||||
|
||||
def test_caffe2xavierinit():
|
||||
"""test Caffe2XavierInit."""
|
||||
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
||||
func = Caffe2XavierInit(bias=0.1, layer='Conv2d')
|
||||
func(model)
|
||||
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
|
||||
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
|
||||
|
||||
|
||||
class FooModule(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue