mmcv/tests/test_cnn/test_weight_init.py

362 lines
14 KiB
Python

# Copyright (c) Open-MMLab. All rights reserved.
import os
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import torch
from torch import nn

from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit,
                      UniformInit, XavierInit, bias_init_with_prob,
                      caffe2_xavier_init, constant_init, initialize,
                      kaiming_init, normal_init, uniform_init, xavier_init)
def test_constant_init():
    """Check ``constant_init`` on conv layers with and without a bias.

    The weight must be filled with the requested value.  When no ``bias``
    argument is given, the test expects an existing bias to be all zeros.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))

    # The bias-free layer must be initialized without error.  Assert on the
    # bias-free module itself; re-checking ``conv_module`` here would leave
    # this case unverified.
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    assert conv_module_no_bias.weight.allclose(
        torch.full_like(conv_module_no_bias.weight, 0.1))
def test_xavier_init():
    """Exercise ``xavier_init``: bias fill, distributions, bias-free layer."""
    layer = nn.Conv2d(3, 16, 3)

    # An explicit bias value must be written to every bias element.
    xavier_init(layer, bias=0.1)
    expected_bias = torch.full_like(layer.bias, 0.1)
    assert layer.bias.allclose(expected_bias)

    xavier_init(layer, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std

    # An unsupported distribution name is expected to trip an assertion.
    with pytest.raises(AssertionError):
        xavier_init(layer, distribution='student-t')

    # A layer created with ``bias=False`` must not break the initializer.
    bias_free_layer = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(bias_free_layer)
def test_normal_init():
    """Exercise ``normal_init`` on conv layers with and without a bias."""
    layer = nn.Conv2d(3, 16, 3)
    normal_init(layer, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    expected_bias = torch.full_like(layer.bias, 0.1)
    assert layer.bias.allclose(expected_bias)

    # A layer created with ``bias=False`` must not break the initializer.
    bias_free_layer = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(bias_free_layer)
    # TODO: sanity check distribution, e.g. mean, std
def test_uniform_init():
    """Exercise ``uniform_init`` on conv layers with and without a bias."""
    layer = nn.Conv2d(3, 16, 3)
    uniform_init(layer, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    expected_bias = torch.full_like(layer.bias, 0.1)
    assert layer.bias.allclose(expected_bias)

    # A layer created with ``bias=False`` must not break the initializer.
    bias_free_layer = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(bias_free_layer)
def test_kaiming_init():
    """Exercise ``kaiming_init``: bias fill, distributions, bias-free layer."""
    layer = nn.Conv2d(3, 16, 3)

    # An explicit bias value must be written to every bias element.
    kaiming_init(layer, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    expected_bias = torch.full_like(layer.bias, 0.1)
    assert layer.bias.allclose(expected_bias)

    kaiming_init(layer, distribution='uniform')

    # An unsupported distribution name is expected to trip an assertion.
    with pytest.raises(AssertionError):
        kaiming_init(layer, distribution='student-t')

    # A layer created with ``bias=False`` must not break the initializer.
    bias_free_layer = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(bias_free_layer)
def test_caffe_xavier_init():
    """Smoke test: ``caffe2_xavier_init`` runs on a conv layer."""
    layer = nn.Conv2d(3, 16, 3)
    caffe2_xavier_init(layer)
def test_bias_init_with_prob():
    """Check ``bias_init_with_prob`` against ``-log((1 - p) / p)``."""
    prior_prob = 0.1
    layer = nn.Conv2d(3, 16, 3)
    normal_init(layer, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std

    # Reference value computed independently with numpy.
    expected = float(-np.log((1 - prior_prob) / prior_prob))
    assert layer.bias.allclose(torch.full_like(layer.bias, expected))
def test_constaninit():
    """Verify the ``ConstantInit`` initializer class.

    Covers restricting the initializer to one layer type, deriving the bias
    from ``bias_prob``, applying it to all layers, and argument type checks.
    """

    def filled(param, value):
        # Tensor with the shape of ``param`` where every element is ``value``.
        return torch.full(param.shape, value)

    net = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    conv, linear = net[0], net[2]

    # Only the Conv2d layer is targeted; the Linear layer must stay untouched.
    init = ConstantInit(val=1, bias=2, layer='Conv2d')
    init(net)
    assert torch.equal(conv.weight, filled(conv.weight, 1.))
    assert torch.equal(conv.bias, filled(conv.bias, 2.))
    assert not torch.equal(linear.weight, filled(linear.weight, 1.))
    assert not torch.equal(linear.bias, filled(linear.bias, 2.))

    # Only the Linear layer is targeted, with the bias derived from a
    # probability; the Conv2d layer keeps the values from the previous step.
    init = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
    init(net)
    prob_bias = bias_init_with_prob(0.01)
    assert torch.equal(conv.weight, filled(conv.weight, 1.))
    assert torch.equal(linear.weight, filled(linear.weight, 3.))
    assert torch.equal(conv.bias, filled(conv.bias, 2.))
    assert torch.equal(linear.bias, filled(linear.bias, prob_bias))

    # Without a ``layer`` argument both layers are expected to be filled.
    init = ConstantInit(val=4, bias=5)
    init(net)
    assert torch.equal(conv.weight, filled(conv.weight, 4.))
    assert torch.equal(linear.weight, filled(linear.weight, 4.))
    assert torch.equal(conv.bias, filled(conv.bias, 5.))
    assert torch.equal(linear.bias, filled(linear.bias, 5.))

    # test bias input type
    with pytest.raises(TypeError):
        ConstantInit(val=1, bias='1')
    # test bias_prob type
    with pytest.raises(TypeError):
        ConstantInit(val=1, bias_prob='1')
    # test layer input type
    with pytest.raises(TypeError):
        ConstantInit(val=1, layer=1)
def test_xavierinit():
    """Verify the ``XavierInit`` initializer class.

    Covers restricting the initializer to one layer type, deriving the bias
    from ``bias_prob`` when applied to all layers, and argument type checks.
    """
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))

    # Only the Conv2d layer is targeted.  Each expected tensor is built from
    # the bias being checked, so the comparison does not rely on broadcasting
    # between the differently-shaped conv and linear biases.
    func = XavierInit(bias=0.1, layer='Conv2d')
    func(model)
    assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
    assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))

    # Zero every parameter first so the following Xavier pass is observable.
    constant_func = ConstantInit(val=0, bias=0)
    func = XavierInit(gain=100, bias_prob=0.01)
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    # Both layers must now have non-zero weights and the probability-derived
    # bias.
    res = bias_init_with_prob(0.01)
    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

    # test bias input type
    with pytest.raises(TypeError):
        XavierInit(bias='0.1', layer='Conv2d')
    # test layer input type
    with pytest.raises(TypeError):
        XavierInit(bias=0.1, layer=1)
def test_normalinit():
    """Verify the ``NormalInit`` initializer class.

    A very small ``std`` (1e-5) is used so the sampled weights can be
    compared against the mean with ``allclose``.
    """
    net = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    targets = (net[0], net[2])

    # No ``layer`` argument: both layers are expected to be initialized.
    init = NormalInit(mean=100, std=1e-5, bias=200)
    init(net)
    for layer in targets:
        assert layer.weight.allclose(torch.tensor(100.))
    for layer in targets:
        assert layer.bias.allclose(torch.tensor(200.))

    # Explicit layer list, with the bias derived from a probability.
    init = NormalInit(
        mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
    prob_bias = bias_init_with_prob(0.01)
    init(net)
    for layer in targets:
        assert layer.weight.allclose(torch.tensor(300.))
    for layer in targets:
        assert layer.bias.allclose(torch.tensor(prob_bias))
def test_uniforminit():
    """Verify the ``UniformInit`` initializer class.

    The lower and upper bounds are set equal (``a == b``) so the expected
    weight values can be compared exactly.
    """

    def filled(param, value):
        # Tensor with the shape of ``param`` where every element is ``value``.
        return torch.full(param.shape, value)

    net = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    conv, linear = net[0], net[2]

    # No ``layer`` argument: both layers are expected to be initialized.
    init = UniformInit(a=1, b=1, bias=2)
    init(net)
    assert torch.equal(conv.weight, filled(conv.weight, 1.))
    assert torch.equal(linear.weight, filled(linear.weight, 1.))
    assert torch.equal(conv.bias, filled(conv.bias, 2.))
    assert torch.equal(linear.bias, filled(linear.bias, 2.))

    # Explicit layer list.
    init = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
    init(net)
    assert torch.equal(conv.weight, filled(conv.weight, 100.))
    assert torch.equal(linear.weight, filled(linear.weight, 100.))
    assert torch.equal(conv.bias, filled(conv.bias, 10.))
    assert torch.equal(linear.bias, filled(linear.bias, 10.))
def test_kaiminginit():
    """Verify the ``KaimingInit`` initializer class.

    Covers restricting the initializer to one layer type and applying it to
    all layers after the parameters were reset to zero.
    """

    def filled(param, value):
        # Tensor with the shape of ``param`` where every element is ``value``.
        return torch.full(param.shape, value)

    net = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    conv, linear = net[0], net[2]

    # Only the Conv2d layer is targeted; the Linear bias must stay untouched.
    init = KaimingInit(bias=0.1, layer='Conv2d')
    init(net)
    assert torch.equal(conv.bias, filled(conv.bias, 0.1))
    assert not torch.equal(linear.bias, filled(linear.bias, 0.1))

    # Zero every parameter first so the following Kaiming pass is observable.
    init = KaimingInit(a=100, bias=10)
    reset = ConstantInit(val=0, bias=0)
    net.apply(reset)
    assert torch.equal(conv.weight, filled(conv.weight, 0.))
    assert torch.equal(linear.weight, filled(linear.weight, 0.))
    assert torch.equal(conv.bias, filled(conv.bias, 0.))
    assert torch.equal(linear.bias, filled(linear.bias, 0.))

    # Both layers must now have non-zero weights and the requested bias.
    init(net)
    assert not torch.equal(conv.weight, filled(conv.weight, 0.))
    assert not torch.equal(linear.weight, filled(linear.weight, 0.))
    assert torch.equal(conv.bias, filled(conv.bias, 10.))
    assert torch.equal(linear.bias, filled(linear.bias, 10.))
class FooModule(nn.Module):
    """Small module with one linear and two conv layers, used as a fixture.

    It defines no ``forward``; the tests in this file only read and write
    the parameters of its sub-layers.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stable: it determines the key order of
        # ``state_dict()``.
        self.linear = nn.Linear(in_features=1, out_features=2)
        self.conv2d = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
        self.conv2d_2 = nn.Conv2d(
            in_channels=3, out_channels=2, kernel_size=3)
def test_pretrainedinit():
    """Verify the ``PretrainedInit`` initializer class.

    A ``FooModule`` filled with constants is saved as a checkpoint, which is
    then loaded into a second ``FooModule`` and, using ``prefix='linear.'``,
    into a standalone ``nn.Linear``.
    """
    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2)
    modelA.apply(constant_func)
    modelB = FooModule()
    modelC = nn.Linear(1, 2)

    # Write the checkpoint inside the temporary directory so it is removed
    # on exit; a bare relative file name would be created in the current
    # working directory and left behind.
    with TemporaryDirectory() as tmp_dir:
        checkpoint = os.path.join(tmp_dir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)

        # Both loads must happen while the checkpoint file still exists.
        funcB = PretrainedInit(checkpoint=checkpoint)
        funcB(modelB)
        funcC = PretrainedInit(checkpoint=checkpoint, prefix='linear.')
        funcC(modelC)

    # The whole module must carry the constants stored in the checkpoint.
    assert torch.equal(modelB.linear.weight,
                       torch.full(modelB.linear.weight.shape, 1.))
    assert torch.equal(modelB.linear.bias,
                       torch.full(modelB.linear.bias.shape, 2.))
    assert torch.equal(modelB.conv2d.weight,
                       torch.full(modelB.conv2d.weight.shape, 1.))
    assert torch.equal(modelB.conv2d.bias,
                       torch.full(modelB.conv2d.bias.shape, 2.))
    assert torch.equal(modelB.conv2d_2.weight,
                       torch.full(modelB.conv2d_2.weight.shape, 1.))
    assert torch.equal(modelB.conv2d_2.bias,
                       torch.full(modelB.conv2d_2.bias.shape, 2.))

    # The standalone layer must carry the values stored under ``linear.``.
    assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
    assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
def test_initialize():
    """Verify ``initialize`` with dict, list, override and pretrained configs.

    Also checks the errors expected for a non-dict/list ``init_cfg``, a
    non-dict ``override`` and an ``override`` naming a missing sub-module.
    """
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    foonet = FooModule()

    # Single config dict without a ``layer`` key.
    init_cfg = dict(type='Constant', val=1, bias=2)
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))

    # List of configs, one per layer type.  The conv entry targets 'Conv2d'
    # because that is the conv type present in ``model``; an entry for a
    # layer type absent from the model would leave these assertions passing
    # only thanks to the values written by the previous step.
    init_cfg = [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))

    # ``override`` replaces the values for the named sub-module only.
    init_cfg = dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
    initialize(foonet, init_cfg)
    assert torch.equal(foonet.linear.weight,
                       torch.full(foonet.linear.weight.shape, 1.))
    assert torch.equal(foonet.linear.bias,
                       torch.full(foonet.linear.bias.shape, 2.))
    assert torch.equal(foonet.conv2d.weight,
                       torch.full(foonet.conv2d.weight.shape, 1.))
    assert torch.equal(foonet.conv2d.bias,
                       torch.full(foonet.conv2d.bias.shape, 2.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 3.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 4.))

    # Pretrained checkpoint combined with an override.
    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2)
    modelA.apply(constant_func)
    # Write the checkpoint inside the temporary directory so it is removed
    # on exit; a bare relative file name would be created in the current
    # working directory and left behind.
    with TemporaryDirectory() as tmp_dir:
        checkpoint = os.path.join(tmp_dir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        init_cfg = dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)
    # NOTE(review): ``foonet`` already holds exactly these values from the
    # previous step, so these assertions do not by themselves prove that the
    # checkpoint was loaded.
    assert torch.equal(foonet.linear.weight,
                       torch.full(foonet.linear.weight.shape, 1.))
    assert torch.equal(foonet.linear.bias,
                       torch.full(foonet.linear.bias.shape, 2.))
    assert torch.equal(foonet.conv2d.weight,
                       torch.full(foonet.conv2d.weight.shape, 1.))
    assert torch.equal(foonet.conv2d.bias,
                       torch.full(foonet.conv2d.bias.shape, 2.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 3.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 4.))

    # test init_cfg type
    with pytest.raises(TypeError):
        init_cfg = 'init_cfg'
        initialize(foonet, init_cfg)

    # test override value type
    with pytest.raises(TypeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override='conv')
        initialize(foonet, init_cfg)

    # test override name
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test list override name
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=[
                dict(type='Constant', name='conv2d', val=3, bias=4),
                dict(type='Constant', name='conv2d_3', val=5, bias=6)
            ])
        initialize(foonet, init_cfg)