# mmcv/tests/test_cnn/test_weight_init.py
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import torch
from torch import nn

from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
                      PretrainedInit, UniformInit, XavierInit,
                      bias_init_with_prob, caffe2_xavier_init, constant_init,
                      initialize, kaiming_init, normal_init, uniform_init,
                      xavier_init)


def test_constant_init():
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    assert conv_module_no_bias.weight.allclose(
        torch.full_like(conv_module_no_bias.weight, 0.1))


def test_xavier_init():
    conv_module = nn.Conv2d(3, 16, 3)
    xavier_init(conv_module, bias=0.1)
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    xavier_init(conv_module, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std
    with pytest.raises(AssertionError):
        xavier_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(conv_module_no_bias)
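

# Hedged sketch addressing the TODO above: with distribution='normal', Xavier
# initialization draws weights with std = gain * sqrt(2 / (fan_in + fan_out)),
# so on a large layer the sample statistics should land near the analytic
# values. The layer size and tolerances are assumptions chosen to keep the
# check robust against sampling noise.
def test_xavier_init_distribution():
    torch.manual_seed(0)
    linear = nn.Linear(256, 256)
    xavier_init(linear, gain=1, distribution='normal')
    expected_std = (2.0 / (256 + 256))**0.5
    assert abs(linear.weight.mean().item()) < 0.01
    assert abs(linear.weight.std().item() - expected_std) < 0.01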


def test_normal_init():
    conv_module = nn.Conv2d(3, 16, 3)
    normal_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(conv_module_no_bias)
    # TODO: sanity check distribution, e.g. mean, std
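

# Hedged sketch addressing the TODOs above: normal_init samples weights from
# N(mean, std**2), so the empirical mean/std of a large layer should match
# the requested values. Layer size and tolerances below are assumptions.
def test_normal_init_distribution():
    torch.manual_seed(0)
    linear = nn.Linear(256, 256)
    normal_init(linear, mean=0.5, std=0.1)
    assert abs(linear.weight.mean().item() - 0.5) < 0.01
    assert abs(linear.weight.std().item() - 0.1) < 0.01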


def test_uniform_init():
    conv_module = nn.Conv2d(3, 16, 3)
    uniform_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(conv_module_no_bias)
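

# Hedged sketch addressing the TODO above: uniform_init fills weights from
# U(a, b), so every element must lie in [a, b] and the sample mean should be
# near (a + b) / 2. Bounds and tolerance below are assumptions.
def test_uniform_init_distribution():
    torch.manual_seed(0)
    linear = nn.Linear(256, 256)
    uniform_init(linear, a=-0.5, b=0.5)
    assert linear.weight.min().item() >= -0.5
    assert linear.weight.max().item() <= 0.5
    assert abs(linear.weight.mean().item()) < 0.01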


def test_kaiming_init():
    conv_module = nn.Conv2d(3, 16, 3)
    kaiming_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    kaiming_init(conv_module, distribution='uniform')
    with pytest.raises(AssertionError):
        kaiming_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(conv_module_no_bias)
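

# Hedged sketch addressing the TODO above: with mode='fan_out' and
# nonlinearity='relu', Kaiming-normal weights have std = sqrt(2 / fan_out),
# so the sample std of a large layer should sit close to that value. Layer
# size and tolerance below are assumptions.
def test_kaiming_init_distribution():
    torch.manual_seed(0)
    linear = nn.Linear(512, 512)
    kaiming_init(linear, mode='fan_out', nonlinearity='relu',
                 distribution='normal')
    expected_std = (2.0 / 512)**0.5
    assert abs(linear.weight.std().item() - expected_std) < 0.01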


def test_caffe2_xavier_init():
    conv_module = nn.Conv2d(3, 16, 3)
    caffe2_xavier_init(conv_module)


def test_bias_init_with_prob():
    conv_module = nn.Conv2d(3, 16, 3)
    prior_prob = 0.1
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std
    bias = float(-np.log((1 - prior_prob) / prior_prob))
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
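

# Hedged consistency check for the helper above: bias_init_with_prob is the
# inverse sigmoid (logit), so applying the sigmoid to the returned bias
# should recover the prior probability up to floating-point error.
def test_bias_init_with_prob_roundtrip():
    for prior_prob in (0.01, 0.1, 0.5):
        bias = bias_init_with_prob(prior_prob)
        recovered = torch.sigmoid(torch.tensor(bias)).item()
        assert abs(recovered - prior_prob) < 1e-6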


def test_constantinit():
    """test ConstantInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = ConstantInit(val=1, bias=2, layer='Conv2d')
    func(model)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    # the Linear layer is not listed in `layer`, so it must stay untouched
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 1.))
    assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))

    func = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
    func(model)
    res = bias_init_with_prob(0.01)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

    # test bias input type
    with pytest.raises(TypeError):
        func = ConstantInit(val=1, bias='1')

    # test bias_prob type
    with pytest.raises(TypeError):
        func = ConstantInit(val=1, bias_prob='1')

    # test layer input type
    with pytest.raises(TypeError):
        func = ConstantInit(val=1, layer=1)


def test_xavierinit():
    """test XavierInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = XavierInit(bias=0.1, layer='Conv2d')
    func(model)
    assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
    assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))

    constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
    func = XavierInit(gain=100, bias_prob=0.01, layer=['Conv2d', 'Linear'])
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    res = bias_init_with_prob(0.01)
    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

    # test bias input type
    with pytest.raises(TypeError):
        func = XavierInit(bias='0.1', layer='Conv2d')

    # test layer input type
    with pytest.raises(TypeError):
        func = XavierInit(bias=0.1, layer=1)


def test_normalinit():
    """test NormalInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = NormalInit(mean=100, std=1e-5, bias=200, layer=['Conv2d', 'Linear'])
    func(model)
    assert model[0].weight.allclose(torch.tensor(100.))
    assert model[2].weight.allclose(torch.tensor(100.))
    assert model[0].bias.allclose(torch.tensor(200.))
    assert model[2].bias.allclose(torch.tensor(200.))

    func = NormalInit(
        mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
    res = bias_init_with_prob(0.01)
    func(model)
    assert model[0].weight.allclose(torch.tensor(300.))
    assert model[2].weight.allclose(torch.tensor(300.))
    assert model[0].bias.allclose(torch.tensor(res))
    assert model[2].bias.allclose(torch.tensor(res))


def test_uniforminit():
    """test UniformInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = UniformInit(a=1, b=1, bias=2, layer=['Conv2d', 'Linear'])
    func(model)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))

    func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
    func(model)
    assert torch.equal(model[0].weight,
                       torch.full(model[0].weight.shape, 100.))
    assert torch.equal(model[2].weight,
                       torch.full(model[2].weight.shape, 100.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))


def test_kaiminginit():
    """test KaimingInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = KaimingInit(bias=0.1, layer='Conv2d')
    func(model)
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
    assert not torch.equal(model[2].bias,
                           torch.full(model[2].bias.shape, 0.1))

    func = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
    constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))


def test_caffe2xavierinit():
    """test Caffe2XavierInit."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = Caffe2XavierInit(bias=0.1, layer='Conv2d')
    func(model)
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
    assert not torch.equal(model[2].bias,
                           torch.full(model[2].bias.shape, 0.1))
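

# Hedged bound check: caffe2_xavier_init corresponds to a Kaiming-uniform
# fill with fan_in scaling, i.e. weights drawn from
# U(-sqrt(3 / fan_in), sqrt(3 / fan_in)); for a Linear(300, 300) that bound
# is exactly 0.1. The derivation and layer size are assumptions for
# illustration, not taken from this test suite.
def test_caffe2_xavier_init_bounds():
    torch.manual_seed(0)
    linear = nn.Linear(300, 300)
    caffe2_xavier_init(linear)
    bound = (3.0 / 300)**0.5  # = 0.1
    assert linear.weight.abs().max().item() <= bound + 1e-6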


class FooModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)
        self.conv2d = nn.Conv2d(3, 1, 3)
        self.conv2d_2 = nn.Conv2d(3, 2, 3)


def test_pretrainedinit():
    """test PretrainedInit class."""
    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)
    modelB = FooModule()
    modelC = nn.Linear(1, 2)
    with TemporaryDirectory() as tmpdir:
        # save the checkpoint inside the temporary directory so the test
        # does not leave files behind in the working directory
        checkpoint = osp.join(tmpdir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        funcB = PretrainedInit(checkpoint=checkpoint)
        funcC = PretrainedInit(checkpoint=checkpoint, prefix='linear.')
        funcB(modelB)
        assert torch.equal(modelB.linear.weight,
                           torch.full(modelB.linear.weight.shape, 1.))
        assert torch.equal(modelB.linear.bias,
                           torch.full(modelB.linear.bias.shape, 2.))
        assert torch.equal(modelB.conv2d.weight,
                           torch.full(modelB.conv2d.weight.shape, 1.))
        assert torch.equal(modelB.conv2d.bias,
                           torch.full(modelB.conv2d.bias.shape, 2.))
        assert torch.equal(modelB.conv2d_2.weight,
                           torch.full(modelB.conv2d_2.weight.shape, 1.))
        assert torch.equal(modelB.conv2d_2.bias,
                           torch.full(modelB.conv2d_2.bias.shape, 2.))
        funcC(modelC)
        assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
        assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
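

# Hedged illustration of the prefix mechanics exercised by funcC above: a
# prefix keeps only the matching state-dict keys and strips the prefix, so
# 'linear.weight' from FooModule loads into 'weight' of a bare nn.Linear.
# This sketch reimplements that key slicing with a plain dict comprehension;
# it is not the mmcv implementation.
def test_pretrained_prefix_key_mapping():
    state_dict = FooModule().state_dict()
    prefix = 'linear.'
    sliced = {
        k[len(prefix):]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }
    assert set(sliced) == {'weight', 'bias'}
    assert sliced['weight'].shape == nn.Linear(1, 2).weight.shape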


def test_initialize():
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    foonet = FooModule()

    # test layer key
    init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
    # initialize should deep-copy the config rather than mutate it in place
    assert init_cfg == dict(
        type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)

    # test init_cfg with list type
    init_cfg = [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
    assert init_cfg == [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]

    # test layer key and override key
    init_cfg = dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
    initialize(foonet, init_cfg)
    assert torch.equal(foonet.linear.weight,
                       torch.full(foonet.linear.weight.shape, 1.))
    assert torch.equal(foonet.linear.bias,
                       torch.full(foonet.linear.bias.shape, 2.))
    assert torch.equal(foonet.conv2d.weight,
                       torch.full(foonet.conv2d.weight.shape, 1.))
    assert torch.equal(foonet.conv2d.bias,
                       torch.full(foonet.conv2d.bias.shape, 2.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 3.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 4.))
    assert init_cfg == dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

    # test override key
    init_cfg = dict(
        type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
    initialize(foonet, init_cfg)
    assert not torch.equal(foonet.linear.weight,
                           torch.full(foonet.linear.weight.shape, 5.))
    assert not torch.equal(foonet.linear.bias,
                           torch.full(foonet.linear.bias.shape, 6.))
    assert not torch.equal(foonet.conv2d.weight,
                           torch.full(foonet.conv2d.weight.shape, 5.))
    assert not torch.equal(foonet.conv2d.bias,
                           torch.full(foonet.conv2d.bias.shape, 6.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 5.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 6.))
    assert init_cfg == dict(
        type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))

    # test "Pretrained" with an override key
    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)
    with TemporaryDirectory() as tmpdir:
        checkpoint = osp.join(tmpdir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        init_cfg = dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)
        assert torch.equal(foonet.linear.weight,
                           torch.full(foonet.linear.weight.shape, 1.))
        assert torch.equal(foonet.linear.bias,
                           torch.full(foonet.linear.bias.shape, 2.))
        assert torch.equal(foonet.conv2d.weight,
                           torch.full(foonet.conv2d.weight.shape, 1.))
        assert torch.equal(foonet.conv2d.bias,
                           torch.full(foonet.conv2d.bias.shape, 2.))
        assert torch.equal(foonet.conv2d_2.weight,
                           torch.full(foonet.conv2d_2.weight.shape, 3.))
        assert torch.equal(foonet.conv2d_2.bias,
                           torch.full(foonet.conv2d_2.bias.shape, 4.))
        assert init_cfg == dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

    # test init_cfg type
    with pytest.raises(TypeError):
        init_cfg = 'init_cfg'
        initialize(foonet, init_cfg)

    # test override value type
    with pytest.raises(TypeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override='conv')
        initialize(foonet, init_cfg)

    # test override name not found in the model
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test list override with a name not found in the model
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=[
                dict(type='Constant', name='conv2d', val=3, bias=4),
                dict(type='Constant', name='conv2d_3', val=5, bias=6)
            ])
        initialize(foonet, init_cfg)

    # test override with extra args but without a type key
    with pytest.raises(ValueError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            override=dict(name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test override without a name key
    with pytest.raises(ValueError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            override=dict(type='Constant', val=3, bias=4))
        initialize(foonet, init_cfg)
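

# Hedged positive counterpart to the list-override failure case above: when
# every name in an override list matches a submodule, each entry should be
# applied independently on top of the outer config. A sketch assuming the
# same FooModule layout and override semantics exercised earlier in this
# file.
def test_initialize_override_list():
    foonet = FooModule()
    init_cfg = dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=[
            dict(type='Constant', name='conv2d', val=3, bias=4),
            dict(type='Constant', name='conv2d_2', val=5, bias=6)
        ])
    initialize(foonet, init_cfg)
    assert torch.equal(foonet.linear.weight,
                       torch.full(foonet.linear.weight.shape, 1.))
    assert torch.equal(foonet.conv2d.weight,
                       torch.full(foonet.conv2d.weight.shape, 3.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 5.))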