mirror of https://github.com/open-mmlab/mmcv.git
563 lines
22 KiB
Python
563 lines
22 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import random
|
|
from tempfile import TemporaryDirectory
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from scipy import stats
|
|
from torch import nn
|
|
|
|
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
|
|
PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
|
|
bias_init_with_prob, caffe2_xavier_init, constant_init,
|
|
initialize, kaiming_init, normal_init, trunc_normal_init,
|
|
uniform_init, xavier_init)
|
|
|
|
if torch.__version__ == 'parrots':
|
|
pytest.skip('not supported in parrots now', allow_module_level=True)
|
|
|
|
|
|
def test_constant_init():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
constant_init(conv_module, 0.1)
|
|
assert conv_module.weight.allclose(
|
|
torch.full_like(conv_module.weight, 0.1))
|
|
assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))
|
|
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
|
|
constant_init(conv_module_no_bias, 0.1)
|
|
assert conv_module.weight.allclose(
|
|
torch.full_like(conv_module.weight, 0.1))
|
|
|
|
|
|
def test_xavier_init():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
xavier_init(conv_module, bias=0.1)
|
|
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
|
|
xavier_init(conv_module, distribution='uniform')
|
|
# TODO: sanity check of weight distribution, e.g. mean, std
|
|
with pytest.raises(AssertionError):
|
|
xavier_init(conv_module, distribution='student-t')
|
|
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
|
|
xavier_init(conv_module_no_bias)
|
|
|
|
|
|
def test_normal_init():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
normal_init(conv_module, bias=0.1)
|
|
# TODO: sanity check of weight distribution, e.g. mean, std
|
|
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
|
|
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
|
|
normal_init(conv_module_no_bias)
|
|
# TODO: sanity check distribution, e.g. mean, std
|
|
|
|
|
|
def test_trunc_normal_init():
|
|
|
|
def _random_float(a, b):
|
|
return (b - a) * random.random() + a
|
|
|
|
def _is_trunc_normal(tensor, mean, std, a, b):
|
|
# scipy's trunc norm is suited for data drawn from N(0, 1),
|
|
# so we need to transform our data to test it using scipy.
|
|
z_samples = (tensor.view(-1) - mean) / std
|
|
z_samples = z_samples.tolist()
|
|
a0 = (a - mean) / std
|
|
b0 = (b - mean) / std
|
|
p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
|
|
return p_value > 0.0001
|
|
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
mean = _random_float(-3, 3)
|
|
std = _random_float(.01, 1)
|
|
a = _random_float(mean - 2 * std, mean)
|
|
b = _random_float(mean, mean + 2 * std)
|
|
trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
|
|
assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
|
|
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
|
|
|
|
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
|
|
trunc_normal_init(conv_module_no_bias)
|
|
# TODO: sanity check distribution, e.g. mean, std
|
|
|
|
|
|
def test_uniform_init():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
uniform_init(conv_module, bias=0.1)
|
|
# TODO: sanity check of weight distribution, e.g. mean, std
|
|
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
|
|
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
|
|
uniform_init(conv_module_no_bias)
|
|
|
|
|
|
def test_kaiming_init():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
kaiming_init(conv_module, bias=0.1)
|
|
# TODO: sanity check of weight distribution, e.g. mean, std
|
|
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
|
|
kaiming_init(conv_module, distribution='uniform')
|
|
with pytest.raises(AssertionError):
|
|
kaiming_init(conv_module, distribution='student-t')
|
|
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
|
|
kaiming_init(conv_module_no_bias)
|
|
|
|
|
|
def test_caffe_xavier_init():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
caffe2_xavier_init(conv_module)
|
|
|
|
|
|
def test_bias_init_with_prob():
|
|
conv_module = nn.Conv2d(3, 16, 3)
|
|
prior_prob = 0.1
|
|
normal_init(conv_module, bias=bias_init_with_prob(0.1))
|
|
# TODO: sanity check of weight distribution, e.g. mean, std
|
|
bias = float(-np.log((1 - prior_prob) / prior_prob))
|
|
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
|
|
|
|
|
|
def test_constaninit():
|
|
"""test ConstantInit class."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
func = ConstantInit(val=1, bias=2, layer='Conv2d')
|
|
func(model)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
|
|
|
|
assert not torch.equal(model[2].weight,
|
|
torch.full(model[2].weight.shape, 1.))
|
|
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
|
|
|
|
func = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
|
|
func(model)
|
|
res = bias_init_with_prob(0.01)
|
|
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
|
|
|
|
# test layer key with base class name
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
|
|
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
|
|
func(model)
|
|
assert torch.all(model[0].weight == 4.)
|
|
assert torch.all(model[2].weight == 4.)
|
|
assert torch.all(model[0].bias == 5.)
|
|
assert torch.all(model[2].bias == 5.)
|
|
|
|
# test bias input type
|
|
with pytest.raises(TypeError):
|
|
func = ConstantInit(val=1, bias='1')
|
|
# test bias_prob type
|
|
with pytest.raises(TypeError):
|
|
func = ConstantInit(val=1, bias_prob='1')
|
|
# test layer input type
|
|
with pytest.raises(TypeError):
|
|
func = ConstantInit(val=1, layer=1)
|
|
|
|
|
|
def test_xavierinit():
|
|
"""test XavierInit class."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
func = XavierInit(bias=0.1, layer='Conv2d')
|
|
func(model)
|
|
assert model[0].bias.allclose(torch.full_like(model[2].bias, 0.1))
|
|
assert not model[2].bias.allclose(torch.full_like(model[0].bias, 0.1))
|
|
|
|
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
|
|
func = XavierInit(gain=100, bias_prob=0.01, layer=['Conv2d', 'Linear'])
|
|
model.apply(constant_func)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
|
|
|
|
res = bias_init_with_prob(0.01)
|
|
func(model)
|
|
assert not torch.equal(model[0].weight,
|
|
torch.full(model[0].weight.shape, 0.))
|
|
assert not torch.equal(model[2].weight,
|
|
torch.full(model[2].weight.shape, 0.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
|
|
|
|
# test layer key with base class name
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
|
|
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
|
|
func(model)
|
|
assert torch.all(model[0].weight == 4.)
|
|
assert torch.all(model[2].weight == 4.)
|
|
assert torch.all(model[0].bias == 5.)
|
|
assert torch.all(model[2].bias == 5.)
|
|
|
|
func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
|
|
func(model)
|
|
assert not torch.all(model[0].weight == 4.)
|
|
assert not torch.all(model[2].weight == 4.)
|
|
assert torch.all(model[0].bias == res)
|
|
assert torch.all(model[2].bias == res)
|
|
|
|
# test bias input type
|
|
with pytest.raises(TypeError):
|
|
func = XavierInit(bias='0.1', layer='Conv2d')
|
|
# test layer inpur type
|
|
with pytest.raises(TypeError):
|
|
func = XavierInit(bias=0.1, layer=1)
|
|
|
|
|
|
def test_normalinit():
|
|
"""test Normalinit class."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
|
|
func = NormalInit(mean=100, std=1e-5, bias=200, layer=['Conv2d', 'Linear'])
|
|
func(model)
|
|
assert model[0].weight.allclose(torch.tensor(100.))
|
|
assert model[2].weight.allclose(torch.tensor(100.))
|
|
assert model[0].bias.allclose(torch.tensor(200.))
|
|
assert model[2].bias.allclose(torch.tensor(200.))
|
|
|
|
func = NormalInit(
|
|
mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
|
|
res = bias_init_with_prob(0.01)
|
|
func(model)
|
|
assert model[0].weight.allclose(torch.tensor(300.))
|
|
assert model[2].weight.allclose(torch.tensor(300.))
|
|
assert model[0].bias.allclose(torch.tensor(res))
|
|
assert model[2].bias.allclose(torch.tensor(res))
|
|
|
|
# test layer key with base class name
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
|
|
|
|
func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
|
|
func(model)
|
|
assert model[0].weight.allclose(torch.tensor(300.))
|
|
assert model[2].weight.allclose(torch.tensor(300.))
|
|
assert torch.all(model[0].bias == res)
|
|
assert torch.all(model[2].bias == res)
|
|
|
|
|
|
def test_truncnormalinit():
|
|
"""test TruncNormalInit class."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
|
|
func = TruncNormalInit(
|
|
mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
|
|
func(model)
|
|
assert model[0].weight.allclose(torch.tensor(100.))
|
|
assert model[2].weight.allclose(torch.tensor(100.))
|
|
assert model[0].bias.allclose(torch.tensor(200.))
|
|
assert model[2].bias.allclose(torch.tensor(200.))
|
|
|
|
func = TruncNormalInit(
|
|
mean=300,
|
|
std=1e-5,
|
|
a=100,
|
|
b=400,
|
|
bias_prob=0.01,
|
|
layer=['Conv2d', 'Linear'])
|
|
res = bias_init_with_prob(0.01)
|
|
func(model)
|
|
assert model[0].weight.allclose(torch.tensor(300.))
|
|
assert model[2].weight.allclose(torch.tensor(300.))
|
|
assert model[0].bias.allclose(torch.tensor(res))
|
|
assert model[2].bias.allclose(torch.tensor(res))
|
|
|
|
# test layer key with base class name
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
|
|
|
|
func = TruncNormalInit(
|
|
mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
|
|
func(model)
|
|
assert model[0].weight.allclose(torch.tensor(300.))
|
|
assert model[2].weight.allclose(torch.tensor(300.))
|
|
assert torch.all(model[0].bias == res)
|
|
assert torch.all(model[2].bias == res)
|
|
|
|
|
|
def test_uniforminit():
|
|
""""test UniformInit class."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
func = UniformInit(a=1, b=1, bias=2, layer=['Conv2d', 'Linear'])
|
|
func(model)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
|
|
|
|
func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
|
|
func(model)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape,
|
|
100.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape,
|
|
100.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
|
|
|
|
# test layer key with base class name
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
|
|
|
|
func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
|
|
res = bias_init_with_prob(0.01)
|
|
func(model)
|
|
assert torch.all(model[0].weight == 100.)
|
|
assert torch.all(model[2].weight == 100.)
|
|
assert torch.all(model[0].bias == res)
|
|
assert torch.all(model[2].bias == res)
|
|
|
|
|
|
def test_kaiminginit():
|
|
"""test KaimingInit class."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
func = KaimingInit(bias=0.1, layer='Conv2d')
|
|
func(model)
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
|
|
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
|
|
|
|
func = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
|
|
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
|
|
model.apply(constant_func)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
|
|
|
|
func(model)
|
|
assert not torch.equal(model[0].weight,
|
|
torch.full(model[0].weight.shape, 0.))
|
|
assert not torch.equal(model[2].weight,
|
|
torch.full(model[2].weight.shape, 0.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
|
|
|
|
# test layer key with base class name
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
|
|
func = KaimingInit(bias=0.1, layer='_ConvNd')
|
|
func(model)
|
|
assert torch.all(model[0].bias == 0.1)
|
|
assert torch.all(model[2].bias == 0.1)
|
|
|
|
func = KaimingInit(a=100, bias=10, layer='_ConvNd')
|
|
constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
|
|
model.apply(constant_func)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
|
|
|
|
func(model)
|
|
assert not torch.equal(model[0].weight,
|
|
torch.full(model[0].weight.shape, 0.))
|
|
assert not torch.equal(model[2].weight,
|
|
torch.full(model[2].weight.shape, 0.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
|
|
|
|
|
|
def test_caffe2xavierinit():
|
|
"""test Caffe2XavierInit."""
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
func = Caffe2XavierInit(bias=0.1, layer='Conv2d')
|
|
func(model)
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
|
|
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
|
|
|
|
|
|
class FooModule(nn.Module):
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = nn.Linear(1, 2)
|
|
self.conv2d = nn.Conv2d(3, 1, 3)
|
|
self.conv2d_2 = nn.Conv2d(3, 2, 3)
|
|
|
|
|
|
def test_pretrainedinit():
|
|
"""test PretrainedInit class."""
|
|
|
|
modelA = FooModule()
|
|
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
|
|
modelA.apply(constant_func)
|
|
modelB = FooModule()
|
|
funcB = PretrainedInit(checkpoint='modelA.pth')
|
|
modelC = nn.Linear(1, 2)
|
|
funcC = PretrainedInit(checkpoint='modelA.pth', prefix='linear.')
|
|
with TemporaryDirectory():
|
|
torch.save(modelA.state_dict(), 'modelA.pth')
|
|
funcB(modelB)
|
|
assert torch.equal(modelB.linear.weight,
|
|
torch.full(modelB.linear.weight.shape, 1.))
|
|
assert torch.equal(modelB.linear.bias,
|
|
torch.full(modelB.linear.bias.shape, 2.))
|
|
assert torch.equal(modelB.conv2d.weight,
|
|
torch.full(modelB.conv2d.weight.shape, 1.))
|
|
assert torch.equal(modelB.conv2d.bias,
|
|
torch.full(modelB.conv2d.bias.shape, 2.))
|
|
assert torch.equal(modelB.conv2d_2.weight,
|
|
torch.full(modelB.conv2d_2.weight.shape, 1.))
|
|
assert torch.equal(modelB.conv2d_2.bias,
|
|
torch.full(modelB.conv2d_2.bias.shape, 2.))
|
|
|
|
funcC(modelC)
|
|
assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
|
|
assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
|
|
|
|
|
|
def test_initialize():
|
|
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
|
|
foonet = FooModule()
|
|
|
|
# test layer key
|
|
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
|
|
initialize(model, init_cfg)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
|
|
assert init_cfg == dict(
|
|
type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
|
|
|
|
# test init_cfg with list type
|
|
init_cfg = [
|
|
dict(type='Constant', layer='Conv2d', val=1, bias=2),
|
|
dict(type='Constant', layer='Linear', val=3, bias=4)
|
|
]
|
|
initialize(model, init_cfg)
|
|
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
|
|
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
|
|
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
|
|
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
|
|
assert init_cfg == [
|
|
dict(type='Constant', layer='Conv2d', val=1, bias=2),
|
|
dict(type='Constant', layer='Linear', val=3, bias=4)
|
|
]
|
|
|
|
# test layer key and override key
|
|
init_cfg = dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
layer=['Conv2d', 'Linear'],
|
|
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
|
|
initialize(foonet, init_cfg)
|
|
assert torch.equal(foonet.linear.weight,
|
|
torch.full(foonet.linear.weight.shape, 1.))
|
|
assert torch.equal(foonet.linear.bias,
|
|
torch.full(foonet.linear.bias.shape, 2.))
|
|
assert torch.equal(foonet.conv2d.weight,
|
|
torch.full(foonet.conv2d.weight.shape, 1.))
|
|
assert torch.equal(foonet.conv2d.bias,
|
|
torch.full(foonet.conv2d.bias.shape, 2.))
|
|
assert torch.equal(foonet.conv2d_2.weight,
|
|
torch.full(foonet.conv2d_2.weight.shape, 3.))
|
|
assert torch.equal(foonet.conv2d_2.bias,
|
|
torch.full(foonet.conv2d_2.bias.shape, 4.))
|
|
assert init_cfg == dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
layer=['Conv2d', 'Linear'],
|
|
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
|
|
|
|
# test override key
|
|
init_cfg = dict(
|
|
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
|
|
initialize(foonet, init_cfg)
|
|
assert not torch.equal(foonet.linear.weight,
|
|
torch.full(foonet.linear.weight.shape, 5.))
|
|
assert not torch.equal(foonet.linear.bias,
|
|
torch.full(foonet.linear.bias.shape, 6.))
|
|
assert not torch.equal(foonet.conv2d.weight,
|
|
torch.full(foonet.conv2d.weight.shape, 5.))
|
|
assert not torch.equal(foonet.conv2d.bias,
|
|
torch.full(foonet.conv2d.bias.shape, 6.))
|
|
assert torch.equal(foonet.conv2d_2.weight,
|
|
torch.full(foonet.conv2d_2.weight.shape, 5.))
|
|
assert torch.equal(foonet.conv2d_2.bias,
|
|
torch.full(foonet.conv2d_2.bias.shape, 6.))
|
|
assert init_cfg == dict(
|
|
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
|
|
|
|
init_cfg = dict(
|
|
type='Pretrained',
|
|
checkpoint='modelA.pth',
|
|
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
|
|
modelA = FooModule()
|
|
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
|
|
modelA.apply(constant_func)
|
|
with TemporaryDirectory():
|
|
torch.save(modelA.state_dict(), 'modelA.pth')
|
|
initialize(foonet, init_cfg)
|
|
assert torch.equal(foonet.linear.weight,
|
|
torch.full(foonet.linear.weight.shape, 1.))
|
|
assert torch.equal(foonet.linear.bias,
|
|
torch.full(foonet.linear.bias.shape, 2.))
|
|
assert torch.equal(foonet.conv2d.weight,
|
|
torch.full(foonet.conv2d.weight.shape, 1.))
|
|
assert torch.equal(foonet.conv2d.bias,
|
|
torch.full(foonet.conv2d.bias.shape, 2.))
|
|
assert torch.equal(foonet.conv2d_2.weight,
|
|
torch.full(foonet.conv2d_2.weight.shape, 3.))
|
|
assert torch.equal(foonet.conv2d_2.bias,
|
|
torch.full(foonet.conv2d_2.bias.shape, 4.))
|
|
assert init_cfg == dict(
|
|
type='Pretrained',
|
|
checkpoint='modelA.pth',
|
|
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
|
|
|
|
# test init_cfg type
|
|
with pytest.raises(TypeError):
|
|
init_cfg = 'init_cfg'
|
|
initialize(foonet, init_cfg)
|
|
|
|
# test override value type
|
|
with pytest.raises(TypeError):
|
|
init_cfg = dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
layer=['Conv2d', 'Linear'],
|
|
override='conv')
|
|
initialize(foonet, init_cfg)
|
|
|
|
# test override name
|
|
with pytest.raises(RuntimeError):
|
|
init_cfg = dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
layer=['Conv2d', 'Linear'],
|
|
override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
|
|
initialize(foonet, init_cfg)
|
|
|
|
# test list override name
|
|
with pytest.raises(RuntimeError):
|
|
init_cfg = dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
layer=['Conv2d', 'Linear'],
|
|
override=[
|
|
dict(type='Constant', name='conv2d', val=3, bias=4),
|
|
dict(type='Constant', name='conv2d_3', val=5, bias=6)
|
|
])
|
|
initialize(foonet, init_cfg)
|
|
|
|
# test override with args except type key
|
|
with pytest.raises(ValueError):
|
|
init_cfg = dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
override=dict(name='conv2d_2', val=3, bias=4))
|
|
initialize(foonet, init_cfg)
|
|
|
|
# test override without name
|
|
with pytest.raises(ValueError):
|
|
init_cfg = dict(
|
|
type='Constant',
|
|
val=1,
|
|
bias=2,
|
|
override=dict(type='Constant', val=3, bias=4))
|
|
initialize(foonet, init_cfg)
|