[Feature] Support to use name of the base classes in init_cfg (#1057)

* [Fix] Support names of base classes matching in init_cfg

* revise bool to len
pull/1071/head
Miao Zheng 2021-06-01 22:35:19 +08:00 committed by GitHub
parent bf2c9fa8d2
commit 50c255bc2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 11 deletions

View File

@ -93,6 +93,10 @@ def bias_init_with_prob(prior_prob):
return bias_init
def _get_bases_name(m):
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit(object):
def __init__(self, *, bias=0, bias_prob=None, layer=None):
@ -146,7 +150,8 @@ class ConstantInit(BaseInit):
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
if layername in self.layer:
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
constant_init(m, self.val, self.bias)
module.apply(init)
@ -183,7 +188,8 @@ class XavierInit(BaseInit):
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
if layername in self.layer:
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
@ -219,9 +225,9 @@ class NormalInit(BaseInit):
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
normal_init(m, self.mean, self.std, self.bias)
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
@ -267,10 +273,10 @@ class TruncNormalInit(BaseInit):
self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
trunc_normal_init(m, self.mean, self.std, self.a,
self.b, self.bias)
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
module.apply(init)
@ -305,7 +311,8 @@ class UniformInit(BaseInit):
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
if layername in self.layer:
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
@ -359,7 +366,8 @@ class KaimingInit(BaseInit):
self.bias, self.distribution)
else:
layername = m.__class__.__name__
if layername in self.layer:
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)

View File

@ -134,6 +134,15 @@ def test_constaninit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
func(model)
assert torch.all(model[0].weight == 4.)
assert torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == 5.)
assert torch.all(model[2].bias == 5.)
# test bias input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias='1')
@ -170,6 +179,22 @@ def test_xavierinit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
func(model)
assert torch.all(model[0].weight == 4.)
assert torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == 5.)
assert torch.all(model[2].bias == 5.)
func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
func(model)
assert not torch.all(model[0].weight == 4.)
assert not torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
# test bias input type
with pytest.raises(TypeError):
func = XavierInit(bias='0.1', layer='Conv2d')
@ -198,6 +223,16 @@ def test_normalinit():
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
def test_truncnormalinit():
"""test TruncNormalInit class."""
@ -225,6 +260,17 @@ def test_truncnormalinit():
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = TruncNormalInit(
mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
def test_uniforminit():
""""test UniformInit class."""
@ -245,6 +291,17 @@ def test_uniforminit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
res = bias_init_with_prob(0.01)
func(model)
assert torch.all(model[0].weight == 100.)
assert torch.all(model[2].weight == 100.)
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
def test_kaiminginit():
"""test KaimingInit class."""
@ -270,6 +327,29 @@ def test_kaiminginit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = KaimingInit(bias=0.1, layer='_ConvNd')
func(model)
assert torch.all(model[0].bias == 0.1)
assert torch.all(model[2].bias == 0.1)
func = KaimingInit(a=100, bias=10, layer='_ConvNd')
constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
def test_caffe2xavierinit():
"""test Caffe2XavierInit."""