[Refactoring] Add Caffe2Xavier Initializer (#902)

* [Refactoring] Add Caffe2Xavier Initializer

* fix lint
pull/890/head
Miao Zheng 2021-03-24 13:25:36 +08:00 committed by GitHub
parent 933b052d95
commit 5f5e8e83c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 11 deletions

View File

@@ -13,8 +13,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, UniformInit, XavierInit,
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
NormalInit, PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, initialize,
kaiming_init, normal_init, uniform_init, xavier_init)
@@ -33,5 +33,6 @@ __all__ = [
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit'
]

View File

@@ -1,9 +1,9 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
KaimingInit, NormalInit, PretrainedInit, UniformInit,
XavierInit, bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
uniform_init, xavier_init)
@@ -12,5 +12,5 @@ __all__ = [
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
'PretrainedInit'
'PretrainedInit', 'Caffe2XavierInit'
]

View File

@@ -298,6 +298,22 @@ class KaimingInit(BaseInit):
module.apply(init)
@INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
    """Initialize module parameters with the Caffe2 ``XavierFill`` scheme.

    ``XavierFill`` in Caffe2 corresponds to ``kaiming_uniform_`` in PyTorch
    with ``a=1``, ``mode='fan_in'``, ``nonlinearity='leaky_relu'`` and a
    uniform distribution, so this class simply fixes those arguments on
    :class:`KaimingInit`.

    Acknowledgment to FAIR's internal code.

    Args:
        **kwargs: Remaining keyword arguments (e.g. ``bias``, ``layer``)
            forwarded unchanged to :class:`KaimingInit`.
    """

    def __init__(self, **kwargs):
        # Pin the Caffe2-equivalent Kaiming parameters; everything else
        # (bias value, target layer types, ...) comes from the caller.
        super().__init__(
            a=1,
            mode='fan_in',
            nonlinearity='leaky_relu',
            distribution='uniform',
            **kwargs)

    # NOTE: the previous revision also overrode __call__ only to delegate to
    # super().__call__(module); that redundant override has been removed —
    # the inherited KaimingInit.__call__ is used directly, with identical
    # behavior.
@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
"""Initialize module by loading a pretrained model.

View File

@@ -6,10 +6,11 @@ import pytest
import torch
from torch import nn
from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit,
UniformInit, XavierInit, bias_init_with_prob,
caffe2_xavier_init, constant_init, initialize,
kaiming_init, normal_init, uniform_init, xavier_init)
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
initialize, kaiming_init, normal_init, uniform_init,
xavier_init)
def test_constant_init():
@@ -219,6 +220,15 @@ def test_kaiminginit():
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
def test_caffe2xavierinit():
    """Check Caffe2XavierInit touches only the configured layer type.

    The initializer is restricted to ``Conv2d``, so the conv bias must be
    filled with the given constant while the linear bias stays untouched.
    """
    layers = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    init_fn = Caffe2XavierInit(bias=0.1, layer='Conv2d')
    init_fn(layers)
    conv_expected = torch.full(layers[0].bias.shape, 0.1)
    linear_expected = torch.full(layers[2].bias.shape, 0.1)
    assert torch.equal(layers[0].bias, conv_expected)
    assert not torch.equal(layers[2].bias, linear_expected)
class FooModule(nn.Module):
def __init__(self):