mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support to use name of the base classes in init_cfg (#1057)
* [Fix] Support names of base classes matching in init_cfg
* revise bool to len

branch: pull/1071/head
parent: bf2c9fa8d2
commit: 50c255bc2d
@@ -93,6 +93,10 @@ def bias_init_with_prob(prior_prob):
     return bias_init
 
 
+def _get_bases_name(m):
+    return [b.__name__ for b in m.__class__.__bases__]
+
+
 class BaseInit(object):
 
     def __init__(self, *, bias=0, bias_prob=None, layer=None):
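For context, a quick sketch of what the new helper returns for common torch.nn modules. The helper body is copied from the hunk above; the surrounding demo script is illustrative only:

import torch.nn as nn

def _get_bases_name(m):
    # Names of the direct base classes of the module's class.
    return [b.__name__ for b in m.__class__.__bases__]

print(_get_bases_name(nn.Conv2d(3, 8, 3)))  # ['_ConvNd'], the shared base of Conv1d/2d/3d
print(_get_bases_name(nn.Linear(4, 4)))     # ['Module']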
@@ -146,7 +150,8 @@ class ConstantInit(BaseInit):
                 constant_init(m, self.val, self.bias)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     constant_init(m, self.val, self.bias)
 
         module.apply(init)
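The new membership test replaces the exact-name check. A minimal sketch of how the matching behaves; the variable values below are illustrative, not from the commit:

layer = ['_ConvNd']            # names requested via init_cfg's `layer` key
layername = 'Conv2d'           # m.__class__.__name__
basesname = ['_ConvNd']        # _get_bases_name(m)

# A non-empty intersection means the module matches, either by its own
# class name or by one of its direct base class names.
if len(set(layer) & set([layername] + basesname)):
    print('initialize this module')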
@@ -183,7 +188,8 @@ class XavierInit(BaseInit):
                 xavier_init(m, self.gain, self.bias, self.distribution)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     xavier_init(m, self.gain, self.bias, self.distribution)
 
         module.apply(init)
@@ -219,9 +225,9 @@ class NormalInit(BaseInit):
                 normal_init(m, self.mean, self.std, self.bias)
             else:
                 layername = m.__class__.__name__
-                for layer_ in self.layer:
-                    if layername == layer_:
-                        normal_init(m, self.mean, self.std, self.bias)
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    normal_init(m, self.mean, self.std, self.bias)
 
         module.apply(init)
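Note that NormalInit (and TruncNormalInit below) previously looped over self.layer and compared names exactly, while the Constant/Xavier/Uniform/Kaiming variants used `in`. The new set intersection unifies both styles and keeps exact-name configs working. A small check under assumed values:

layer = ['Conv2d']                       # exact class name, as before
layername, basesname = 'Conv2d', ['_ConvNd']

old_match = any(layername == layer_ for layer_ in layer)
new_match = bool(set(layer) & set([layername] + basesname))
assert old_match == new_match            # exact-name configs are unaffected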
@@ -267,10 +273,10 @@ class TruncNormalInit(BaseInit):
                                   self.bias)
             else:
                 layername = m.__class__.__name__
-                for layer_ in self.layer:
-                    if layername == layer_:
-                        trunc_normal_init(m, self.mean, self.std, self.a,
-                                          self.b, self.bias)
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                      self.bias)
 
         module.apply(init)
@@ -305,7 +311,8 @@ class UniformInit(BaseInit):
                 uniform_init(m, self.a, self.b, self.bias)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     uniform_init(m, self.a, self.b, self.bias)
 
         module.apply(init)
@@ -359,7 +366,8 @@ class KaimingInit(BaseInit):
                              self.bias, self.distribution)
             else:
                 layername = m.__class__.__name__
-                if layername in self.layer:
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
                     kaiming_init(m, self.a, self.mode, self.nonlinearity,
                                  self.bias, self.distribution)
 
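Taken together, the change lets a single base-class name cover every concrete conv variant in an init_cfg. An end-to-end sketch; the initialize entry point and its import path are assumed from mmcv's weight-init API of this era, not shown in this diff:

import torch.nn as nn
from mmcv.cnn import initialize  # assumed import path

model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))

# '_ConvNd' matches both Conv2d and Conv1d through their shared base class.
initialize(model, dict(type='Constant', layer='_ConvNd', val=1., bias=2.))

assert (model[0].weight == 1.).all() and (model[2].weight == 1.).all()
assert (model[0].bias == 2.).all() and (model[2].bias == 2.).all()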
The commit also extends the weight-init unit tests to cover base-class matching:
@@ -134,6 +134,15 @@ def test_constaninit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = ConstantInit(val=4., bias=5., layer='_ConvNd')
+    func(model)
+    assert torch.all(model[0].weight == 4.)
+    assert torch.all(model[2].weight == 4.)
+    assert torch.all(model[0].bias == 5.)
+    assert torch.all(model[2].bias == 5.)
+
     # test bias input type
     with pytest.raises(TypeError):
         func = ConstantInit(val=1, bias='1')
@@ -170,6 +179,22 @@ def test_xavierinit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = ConstantInit(val=4., bias=5., layer='_ConvNd')
+    func(model)
+    assert torch.all(model[0].weight == 4.)
+    assert torch.all(model[2].weight == 4.)
+    assert torch.all(model[0].bias == 5.)
+    assert torch.all(model[2].bias == 5.)
+
+    func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
+    func(model)
+    assert not torch.all(model[0].weight == 4.)
+    assert not torch.all(model[2].weight == 4.)
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
     # test bias input type
     with pytest.raises(TypeError):
         func = XavierInit(bias='0.1', layer='Conv2d')
@@ -198,6 +223,16 @@ def test_normalinit():
     assert model[0].bias.allclose(torch.tensor(res))
     assert model[2].bias.allclose(torch.tensor(res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+
+    func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
 
 def test_truncnormalinit():
     """test TruncNormalInit class."""
@@ -225,6 +260,17 @@ def test_truncnormalinit():
     assert model[0].bias.allclose(torch.tensor(res))
     assert model[2].bias.allclose(torch.tensor(res))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+
+    func = TruncNormalInit(
+        mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
+    func(model)
+    assert model[0].weight.allclose(torch.tensor(300.))
+    assert model[2].weight.allclose(torch.tensor(300.))
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
 
 def test_uniforminit():
     """"test UniformInit class."""
@@ -245,6 +291,17 @@ def test_uniforminit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+
+    func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
+    res = bias_init_with_prob(0.01)
+    func(model)
+    assert torch.all(model[0].weight == 100.)
+    assert torch.all(model[2].weight == 100.)
+    assert torch.all(model[0].bias == res)
+    assert torch.all(model[2].bias == res)
+
 
 def test_kaiminginit():
     """test KaimingInit class."""
@@ -270,6 +327,29 @@ def test_kaiminginit():
     assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
     assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
 
+    # test layer key with base class name
+    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
+    func = KaimingInit(bias=0.1, layer='_ConvNd')
+    func(model)
+    assert torch.all(model[0].bias == 0.1)
+    assert torch.all(model[2].bias == 0.1)
+
+    func = KaimingInit(a=100, bias=10, layer='_ConvNd')
+    constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
+    model.apply(constant_func)
+    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
+    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
+
+    func(model)
+    assert not torch.equal(model[0].weight,
+                           torch.full(model[0].weight.shape, 0.))
+    assert not torch.equal(model[2].weight,
+                           torch.full(model[2].weight.shape, 0.))
+    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
+    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
+
 
 def test_caffe2xavierinit():
     """test Caffe2XavierInit."""
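Throughout the new test code, biases set via bias_prob are compared against res = bias_init_with_prob(0.01). For reference, a sketch of what that helper computes; the formula follows mmcv's weight-init module, and the standalone script here is illustrative:

import math

# bias_init_with_prob(p) returns the logit of the prior probability p,
# i.e. -log((1 - p) / p); for p = 0.01 this is about -4.5951.
p = 0.01
res = -math.log((1 - p) / p)
print(round(res, 4))  # -4.5951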