mirror of https://github.com/open-mmlab/mmcv.git
229 lines
8.4 KiB
Python
229 lines
8.4 KiB
Python
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
from mmcv.runner import BaseModule
|
||
|
from mmcv.utils import Registry, build_from_cfg
|
||
|
|
||
|
COMPONENTS = Registry('component')
|
||
|
FOOMODELS = Registry('model')
|
||
|
|
||
|
|
||
|
@COMPONENTS.register_module()
|
||
|
class FooConv1d(BaseModule):
|
||
|
|
||
|
def __init__(self, init_cfg=None):
|
||
|
super().__init__(init_cfg)
|
||
|
self.conv1d = nn.Conv1d(4, 1, 4)
|
||
|
|
||
|
def forward(self, x):
|
||
|
return self.conv1d(x)
|
||
|
|
||
|
|
||
|
@COMPONENTS.register_module()
|
||
|
class FooConv2d(BaseModule):
|
||
|
|
||
|
def __init__(self, init_cfg=None):
|
||
|
super().__init__(init_cfg)
|
||
|
self.conv2d = nn.Conv2d(3, 1, 3)
|
||
|
|
||
|
def forward(self, x):
|
||
|
return self.conv2d(x)
|
||
|
|
||
|
|
||
|
@COMPONENTS.register_module()
|
||
|
class FooLinear(BaseModule):
|
||
|
|
||
|
def __init__(self, init_cfg=None):
|
||
|
super().__init__(init_cfg)
|
||
|
self.linear = nn.Linear(3, 4)
|
||
|
|
||
|
def forward(self, x):
|
||
|
return self.linear(x)
|
||
|
|
||
|
|
||
|
@COMPONENTS.register_module()
|
||
|
class FooLinearConv1d(BaseModule):
|
||
|
|
||
|
def __init__(self, linear=None, conv1d=None, init_cfg=None):
|
||
|
super().__init__(init_cfg)
|
||
|
if linear is not None:
|
||
|
self.linear = build_from_cfg(linear, COMPONENTS)
|
||
|
if conv1d is not None:
|
||
|
self.conv1d = build_from_cfg(conv1d, COMPONENTS)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.linear(x)
|
||
|
return self.conv1d(x)
|
||
|
|
||
|
|
||
|
@FOOMODELS.register_module()
|
||
|
class FooModel(BaseModule):
|
||
|
|
||
|
def __init__(self,
|
||
|
component1=None,
|
||
|
component2=None,
|
||
|
component3=None,
|
||
|
component4=None,
|
||
|
init_cfg=None) -> None:
|
||
|
super().__init__(init_cfg)
|
||
|
if component1 is not None:
|
||
|
self.component1 = build_from_cfg(component1, COMPONENTS)
|
||
|
if component2 is not None:
|
||
|
self.component2 = build_from_cfg(component2, COMPONENTS)
|
||
|
if component3 is not None:
|
||
|
self.component3 = build_from_cfg(component3, COMPONENTS)
|
||
|
if component4 is not None:
|
||
|
self.component4 = build_from_cfg(component4, COMPONENTS)
|
||
|
|
||
|
# its type is not BaseModule, it can be initialized
|
||
|
# with "override" key.
|
||
|
self.reg = nn.Linear(3, 4)
|
||
|
|
||
|
|
||
|
def test_model_weight_init():
|
||
|
"""
|
||
|
Config
|
||
|
model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
|
||
|
Conv2d: weight=5, bias=6)
|
||
|
├──component1 (FooConv1d)
|
||
|
├──component2 (FooConv2d)
|
||
|
├──component3 (FooLinear)
|
||
|
├──component4 (FooLinearConv1d)
|
||
|
├──linear (FooLinear)
|
||
|
├──conv1d (FooConv1d)
|
||
|
├──reg (nn.Linear)
|
||
|
|
||
|
Parameters after initialization
|
||
|
model (FooModel)
|
||
|
├──component1 (FooConv1d, weight=3, bias=4)
|
||
|
├──component2 (FooConv2d, weight=5, bias=6)
|
||
|
├──component3 (FooLinear, weight=1, bias=2)
|
||
|
├──component4 (FooLinearConv1d)
|
||
|
├──linear (FooLinear, weight=1, bias=2)
|
||
|
├──conv1d (FooConv1d, weight=3, bias=4)
|
||
|
├──reg (nn.Linear, weight=1, bias=2)
|
||
|
"""
|
||
|
model_cfg = dict(
|
||
|
type='FooModel',
|
||
|
init_cfg=[
|
||
|
dict(type='Constant', val=1, bias=2, layer='Linear'),
|
||
|
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
|
||
|
dict(type='Constant', val=5, bias=6, layer='Conv2d')
|
||
|
],
|
||
|
component1=dict(type='FooConv1d'),
|
||
|
component2=dict(type='FooConv2d'),
|
||
|
component3=dict(type='FooLinear'),
|
||
|
component4=dict(
|
||
|
type='FooLinearConv1d',
|
||
|
linear=dict(type='FooLinear'),
|
||
|
conv1d=dict(type='FooConv1d')))
|
||
|
|
||
|
model = build_from_cfg(model_cfg, FOOMODELS)
|
||
|
model.init_weight()
|
||
|
|
||
|
assert torch.equal(model.component1.conv1d.weight,
|
||
|
torch.full(model.component1.conv1d.weight.shape, 3.0))
|
||
|
assert torch.equal(model.component1.conv1d.bias,
|
||
|
torch.full(model.component1.conv1d.bias.shape, 4.0))
|
||
|
assert torch.equal(model.component2.conv2d.weight,
|
||
|
torch.full(model.component2.conv2d.weight.shape, 5.0))
|
||
|
assert torch.equal(model.component2.conv2d.bias,
|
||
|
torch.full(model.component2.conv2d.bias.shape, 6.0))
|
||
|
assert torch.equal(model.component3.linear.weight,
|
||
|
torch.full(model.component3.linear.weight.shape, 1.0))
|
||
|
assert torch.equal(model.component3.linear.bias,
|
||
|
torch.full(model.component3.linear.bias.shape, 2.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.linear.linear.weight,
|
||
|
torch.full(model.component4.linear.linear.weight.shape, 1.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.linear.linear.bias,
|
||
|
torch.full(model.component4.linear.linear.bias.shape, 2.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.conv1d.conv1d.weight,
|
||
|
torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.conv1d.conv1d.bias,
|
||
|
torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
|
||
|
assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
|
||
|
1.0))
|
||
|
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))
|
||
|
|
||
|
|
||
|
def test_nest_components_weight_init():
|
||
|
"""
|
||
|
Config
|
||
|
model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
|
||
|
Conv2d: weight=5, bias=6)
|
||
|
├──component1 (FooConv1d, Conv1d: weight=7, bias=8)
|
||
|
├──component2 (FooConv2d, Conv2d: weight=9, bias=10)
|
||
|
├──component3 (FooLinear)
|
||
|
├──component4 (FooLinearConv1d, Linear: weight=11, bias=12)
|
||
|
├──linear (FooLinear, Linear: weight=11, bias=12)
|
||
|
├──conv1d (FooConv1d)
|
||
|
├──reg (nn.Linear, weight=13, bias=14)
|
||
|
|
||
|
Parameters after initialization
|
||
|
model (FooModel)
|
||
|
├──component1 (FooConv1d, weight=7, bias=8)
|
||
|
├──component2 (FooConv2d, weight=9, bias=10)
|
||
|
├──component3 (FooLinear, weight=1, bias=2)
|
||
|
├──component4 (FooLinearConv1d)
|
||
|
├──linear (FooLinear, weight=1, bias=2)
|
||
|
├──conv1d (FooConv1d, weight=3, bias=4)
|
||
|
├──reg (nn.Linear, weight=13, bias=14)
|
||
|
"""
|
||
|
|
||
|
model_cfg = dict(
|
||
|
type='FooModel',
|
||
|
init_cfg=[
|
||
|
dict(
|
||
|
type='Constant',
|
||
|
val=1,
|
||
|
bias=2,
|
||
|
layer='Linear',
|
||
|
override=dict(type='Constant', name='reg', val=13, bias=14)),
|
||
|
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
|
||
|
dict(type='Constant', val=5, bias=6, layer='Conv2d'),
|
||
|
],
|
||
|
component1=dict(
|
||
|
type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
|
||
|
component2=dict(
|
||
|
type='FooConv2d', init_cfg=dict(type='Constant', val=9, bias=10)),
|
||
|
component3=dict(type='FooLinear'),
|
||
|
component4=dict(
|
||
|
type='FooLinearConv1d',
|
||
|
linear=dict(type='FooLinear'),
|
||
|
conv1d=dict(type='FooConv1d')))
|
||
|
|
||
|
model = build_from_cfg(model_cfg, FOOMODELS)
|
||
|
model.init_weight()
|
||
|
|
||
|
assert torch.equal(model.component1.conv1d.weight,
|
||
|
torch.full(model.component1.conv1d.weight.shape, 7.0))
|
||
|
assert torch.equal(model.component1.conv1d.bias,
|
||
|
torch.full(model.component1.conv1d.bias.shape, 8.0))
|
||
|
assert torch.equal(model.component2.conv2d.weight,
|
||
|
torch.full(model.component2.conv2d.weight.shape, 9.0))
|
||
|
assert torch.equal(model.component2.conv2d.bias,
|
||
|
torch.full(model.component2.conv2d.bias.shape, 10.0))
|
||
|
assert torch.equal(model.component3.linear.weight,
|
||
|
torch.full(model.component3.linear.weight.shape, 1.0))
|
||
|
assert torch.equal(model.component3.linear.bias,
|
||
|
torch.full(model.component3.linear.bias.shape, 2.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.linear.linear.weight,
|
||
|
torch.full(model.component4.linear.linear.weight.shape, 1.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.linear.linear.bias,
|
||
|
torch.full(model.component4.linear.linear.bias.shape, 2.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.conv1d.conv1d.weight,
|
||
|
torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
|
||
|
assert torch.equal(
|
||
|
model.component4.conv1d.conv1d.bias,
|
||
|
torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
|
||
|
assert torch.equal(model.reg.weight,
|
||
|
torch.full(model.reg.weight.shape, 13.0))
|
||
|
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))
|