import torch from torch import nn from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.utils import Registry, build_from_cfg COMPONENTS = Registry('component') FOOMODELS = Registry('model') @COMPONENTS.register_module() class FooConv1d(BaseModule): def __init__(self, init_cfg=None): super().__init__(init_cfg) self.conv1d = nn.Conv1d(4, 1, 4) def forward(self, x): return self.conv1d(x) @COMPONENTS.register_module() class FooConv2d(BaseModule): def __init__(self, init_cfg=None): super().__init__(init_cfg) self.conv2d = nn.Conv2d(3, 1, 3) def forward(self, x): return self.conv2d(x) @COMPONENTS.register_module() class FooLinear(BaseModule): def __init__(self, init_cfg=None): super().__init__(init_cfg) self.linear = nn.Linear(3, 4) def forward(self, x): return self.linear(x) @COMPONENTS.register_module() class FooLinearConv1d(BaseModule): def __init__(self, linear=None, conv1d=None, init_cfg=None): super().__init__(init_cfg) if linear is not None: self.linear = build_from_cfg(linear, COMPONENTS) if conv1d is not None: self.conv1d = build_from_cfg(conv1d, COMPONENTS) def forward(self, x): x = self.linear(x) return self.conv1d(x) @FOOMODELS.register_module() class FooModel(BaseModule): def __init__(self, component1=None, component2=None, component3=None, component4=None, init_cfg=None) -> None: super().__init__(init_cfg) if component1 is not None: self.component1 = build_from_cfg(component1, COMPONENTS) if component2 is not None: self.component2 = build_from_cfg(component2, COMPONENTS) if component3 is not None: self.component3 = build_from_cfg(component3, COMPONENTS) if component4 is not None: self.component4 = build_from_cfg(component4, COMPONENTS) # its type is not BaseModule, it can be initialized # with "override" key. self.reg = nn.Linear(3, 4) def test_model_weight_init(): """ Config model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4, Conv2d: weight=5, bias=6) ├──component1 (FooConv1d) ├──component2 (FooConv2d) ├──component3 (FooLinear) ├──component4 (FooLinearConv1d) ├──linear (FooLinear) ├──conv1d (FooConv1d) ├──reg (nn.Linear) Parameters after initialization model (FooModel) ├──component1 (FooConv1d, weight=3, bias=4) ├──component2 (FooConv2d, weight=5, bias=6) ├──component3 (FooLinear, weight=1, bias=2) ├──component4 (FooLinearConv1d) ├──linear (FooLinear, weight=1, bias=2) ├──conv1d (FooConv1d, weight=3, bias=4) ├──reg (nn.Linear, weight=1, bias=2) """ model_cfg = dict( type='FooModel', init_cfg=[ dict(type='Constant', val=1, bias=2, layer='Linear'), dict(type='Constant', val=3, bias=4, layer='Conv1d'), dict(type='Constant', val=5, bias=6, layer='Conv2d') ], component1=dict(type='FooConv1d'), component2=dict(type='FooConv2d'), component3=dict(type='FooLinear'), component4=dict( type='FooLinearConv1d', linear=dict(type='FooLinear'), conv1d=dict(type='FooConv1d'))) model = build_from_cfg(model_cfg, FOOMODELS) model.init_weight() assert torch.equal(model.component1.conv1d.weight, torch.full(model.component1.conv1d.weight.shape, 3.0)) assert torch.equal(model.component1.conv1d.bias, torch.full(model.component1.conv1d.bias.shape, 4.0)) assert torch.equal(model.component2.conv2d.weight, torch.full(model.component2.conv2d.weight.shape, 5.0)) assert torch.equal(model.component2.conv2d.bias, torch.full(model.component2.conv2d.bias.shape, 6.0)) assert torch.equal(model.component3.linear.weight, torch.full(model.component3.linear.weight.shape, 1.0)) assert torch.equal(model.component3.linear.bias, torch.full(model.component3.linear.bias.shape, 2.0)) assert torch.equal( model.component4.linear.linear.weight, torch.full(model.component4.linear.linear.weight.shape, 1.0)) assert torch.equal( model.component4.linear.linear.bias, torch.full(model.component4.linear.linear.bias.shape, 2.0)) assert torch.equal( model.component4.conv1d.conv1d.weight, torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0)) assert torch.equal( model.component4.conv1d.conv1d.bias, torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0)) assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, 1.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0)) def test_nest_components_weight_init(): """ Config model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4, Conv2d: weight=5, bias=6) ├──component1 (FooConv1d, Conv1d: weight=7, bias=8) ├──component2 (FooConv2d, Conv2d: weight=9, bias=10) ├──component3 (FooLinear) ├──component4 (FooLinearConv1d, Linear: weight=11, bias=12) ├──linear (FooLinear, Linear: weight=11, bias=12) ├──conv1d (FooConv1d) ├──reg (nn.Linear, weight=13, bias=14) Parameters after initialization model (FooModel) ├──component1 (FooConv1d, weight=7, bias=8) ├──component2 (FooConv2d, weight=9, bias=10) ├──component3 (FooLinear, weight=1, bias=2) ├──component4 (FooLinearConv1d) ├──linear (FooLinear, weight=1, bias=2) ├──conv1d (FooConv1d, weight=3, bias=4) ├──reg (nn.Linear, weight=13, bias=14) """ model_cfg = dict( type='FooModel', init_cfg=[ dict( type='Constant', val=1, bias=2, layer='Linear', override=dict(type='Constant', name='reg', val=13, bias=14)), dict(type='Constant', val=3, bias=4, layer='Conv1d'), dict(type='Constant', val=5, bias=6, layer='Conv2d'), ], component1=dict( type='FooConv1d', init_cfg=dict(type='Constant', layer='Conv1d', val=7, bias=8)), component2=dict( type='FooConv2d', init_cfg=dict(type='Constant', layer='Conv2d', val=9, bias=10)), component3=dict(type='FooLinear'), component4=dict( type='FooLinearConv1d', linear=dict(type='FooLinear'), conv1d=dict(type='FooConv1d'))) model = build_from_cfg(model_cfg, FOOMODELS) model.init_weight() assert torch.equal(model.component1.conv1d.weight, torch.full(model.component1.conv1d.weight.shape, 7.0)) assert torch.equal(model.component1.conv1d.bias, torch.full(model.component1.conv1d.bias.shape, 8.0)) assert torch.equal(model.component2.conv2d.weight, torch.full(model.component2.conv2d.weight.shape, 9.0)) assert torch.equal(model.component2.conv2d.bias, torch.full(model.component2.conv2d.bias.shape, 10.0)) assert torch.equal(model.component3.linear.weight, torch.full(model.component3.linear.weight.shape, 1.0)) assert torch.equal(model.component3.linear.bias, torch.full(model.component3.linear.bias.shape, 2.0)) assert torch.equal( model.component4.linear.linear.weight, torch.full(model.component4.linear.linear.weight.shape, 1.0)) assert torch.equal( model.component4.linear.linear.bias, torch.full(model.component4.linear.linear.bias.shape, 2.0)) assert torch.equal( model.component4.conv1d.conv1d.weight, torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0)) assert torch.equal( model.component4.conv1d.conv1d.bias, torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0)) assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, 13.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0)) def test_without_layer_weight_init(): model_cfg = dict( type='FooModel', init_cfg=[ dict(type='Constant', val=1, bias=2, layer='Linear'), dict(type='Constant', val=3, bias=4, layer='Conv1d'), dict(type='Constant', val=5, bias=6, layer='Conv2d') ], component1=dict( type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)), component2=dict(type='FooConv2d'), component3=dict(type='FooLinear')) model = build_from_cfg(model_cfg, FOOMODELS) model.init_weight() assert torch.equal(model.component1.conv1d.weight, torch.full(model.component1.conv1d.weight.shape, 3.0)) assert torch.equal(model.component1.conv1d.bias, torch.full(model.component1.conv1d.bias.shape, 4.0)) # init_cfg in component1 does not have layer key, so it does nothing assert torch.equal(model.component2.conv2d.weight, torch.full(model.component2.conv2d.weight.shape, 5.0)) assert torch.equal(model.component2.conv2d.bias, torch.full(model.component2.conv2d.bias.shape, 6.0)) assert torch.equal(model.component3.linear.weight, torch.full(model.component3.linear.weight.shape, 1.0)) assert torch.equal(model.component3.linear.bias, torch.full(model.component3.linear.bias.shape, 2.0)) assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, 1.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0)) def test_override_weight_init(): # only initialize 'override' model_cfg = dict( type='FooModel', init_cfg=[ dict(type='Constant', val=10, bias=20, override=dict(name='reg')) ], component1=dict(type='FooConv1d'), component3=dict(type='FooLinear')) model = build_from_cfg(model_cfg, FOOMODELS) model.init_weight() assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, 10.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0)) # do not initialize others assert not torch.equal( model.component1.conv1d.weight, torch.full(model.component1.conv1d.weight.shape, 10.0)) assert not torch.equal( model.component1.conv1d.bias, torch.full(model.component1.conv1d.bias.shape, 20.0)) assert not torch.equal( model.component3.linear.weight, torch.full(model.component3.linear.weight.shape, 10.0)) assert not torch.equal( model.component3.linear.bias, torch.full(model.component3.linear.bias.shape, 20.0)) # 'override' has higher priority model_cfg = dict( type='FooModel', init_cfg=[ dict( type='Constant', val=1, bias=2, override=dict(name='reg', type='Constant', val=30, bias=40)) ], component1=dict(type='FooConv1d'), component2=dict(type='FooConv2d'), component3=dict(type='FooLinear')) model = build_from_cfg(model_cfg, FOOMODELS) model.init_weight() assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, 30.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 40.0)) def test_sequential_model_weight_init(): seq_model_cfg = [ dict( type='FooConv1d', init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)), dict( type='FooConv2d', init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential(*layers) seq_model.init_weight() assert torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.)) assert torch.equal(seq_model[0].conv1d.bias, torch.full(seq_model[0].conv1d.bias.shape, 1.)) assert torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.)) assert torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)) # inner init_cfg has highter priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential( *layers, init_cfg=dict( type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) seq_model.init_weight() assert torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.)) assert torch.equal(seq_model[0].conv1d.bias, torch.full(seq_model[0].conv1d.bias.shape, 1.)) assert torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.)) assert torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)) def test_modulelist_weight_init(): models_cfg = [ dict( type='FooConv1d', init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)), dict( type='FooConv2d', init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList(layers) modellist.init_weight() assert torch.equal(modellist[0].conv1d.weight, torch.full(modellist[0].conv1d.weight.shape, 0.)) assert torch.equal(modellist[0].conv1d.bias, torch.full(modellist[0].conv1d.bias.shape, 1.)) assert torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.)) assert torch.equal(modellist[1].conv2d.bias, torch.full(modellist[1].conv2d.bias.shape, 3.)) # inner init_cfg has highter priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList( layers, init_cfg=dict( type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) modellist.init_weight() assert torch.equal(modellist[0].conv1d.weight, torch.full(modellist[0].conv1d.weight.shape, 0.)) assert torch.equal(modellist[0].conv1d.bias, torch.full(modellist[0].conv1d.bias.shape, 1.)) assert torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.)) assert torch.equal(modellist[1].conv2d.bias, torch.full(modellist[1].conv2d.bias.shape, 3.))