mmcv/tests/test_runner/test_basemodule.py
Shilong Zhang 17fa6670eb
[Features] Add logger for initialization of parameters (#1150)
* add logger for init

* change init_info of overload init_weight

* add judgement for params_init_info

* add delete comments for params_init_info

* add docstr and more comments

* add docstr and more comments

* resolve comments

* dump to a file

* add unitest

* fix unitest

* fix unitest

* write to ori log

* fix typo

* resolve comments
2021-07-23 21:07:54 +08:00

483 lines
18 KiB
Python

import tempfile
import torch
from torch import nn
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.runner.base_module import update_init_info
from mmcv.utils import Registry, build_from_cfg
# Registry holding the small ``Foo*`` building-block modules registered below;
# the test models are assembled from it via ``build_from_cfg``.
COMPONENTS = Registry('component')
# Registry holding the top-level test model (``FooModel``).
FOOMODELS = Registry('model')
@COMPONENTS.register_module()
class FooConv1d(BaseModule):
    """Test component that wraps a single ``nn.Conv1d`` layer.

    Args:
        init_cfg (dict | list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(in_channels=4, out_channels=1, kernel_size=4)

    def forward(self, x):
        """Run ``x`` through the wrapped 1D convolution."""
        output = self.conv1d(x)
        return output
@COMPONENTS.register_module()
class FooConv2d(BaseModule):
    """Test component that wraps a single ``nn.Conv2d`` layer.

    Args:
        init_cfg (dict | list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3)

    def forward(self, x):
        """Run ``x`` through the wrapped 2D convolution."""
        output = self.conv2d(x)
        return output
@COMPONENTS.register_module()
class FooLinear(BaseModule):
    """Test component that wraps a single ``nn.Linear`` layer.

    Args:
        init_cfg (dict | list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(in_features=3, out_features=4)

    def forward(self, x):
        """Run ``x`` through the wrapped linear layer."""
        output = self.linear(x)
        return output
@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):
    """Test component composed of an optional linear and conv1d sub-component.

    Args:
        linear (dict, optional): Config built from ``COMPONENTS`` and stored
            as ``self.linear``. The attribute is not created when None.
        conv1d (dict, optional): Config built from ``COMPONENTS`` and stored
            as ``self.conv1d``. The attribute is not created when None.
        init_cfg (dict | list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super().__init__(init_cfg)
        # Only sub-components that were actually configured get attached,
        # in the same order as the constructor arguments.
        for attr_name, sub_cfg in (('linear', linear), ('conv1d', conv1d)):
            if sub_cfg is None:
                continue
            setattr(self, attr_name, build_from_cfg(sub_cfg, COMPONENTS))

    def forward(self, x):
        """Apply ``self.linear`` and then ``self.conv1d`` to ``x``."""
        return self.conv1d(self.linear(x))
@FOOMODELS.register_module()
class FooModel(BaseModule):
    """Top-level test model made of up to four configurable components.

    Args:
        component1 (dict, optional): Config built from ``COMPONENTS`` and
            stored as ``self.component1``; skipped when None.
        component2 (dict, optional): Same, stored as ``self.component2``.
        component3 (dict, optional): Same, stored as ``self.component3``.
        component4 (dict, optional): Same, stored as ``self.component4``.
        init_cfg (dict | list[dict], optional): Initialization config
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        configured = (
            ('component1', component1),
            ('component2', component2),
            ('component3', component3),
            ('component4', component4),
        )
        # Attach only the components that were given a config, in order.
        for attr_name, component_cfg in configured:
            if component_cfg is None:
                continue
            setattr(self, attr_name,
                    build_from_cfg(component_cfg, COMPONENTS))

        # ``reg`` is a plain ``nn.Linear`` rather than a ``BaseModule``; it
        # can be initialized with the "override" key.
        self.reg = nn.Linear(3, 4)
def test_initilization_info_logger():
    """Check that ``init_weights`` dumps initialization info to a log file.

    A model with two plain convs, one conv overriding ``init_weights`` and
    one linear layer is initialized while a logger with a file handler
    exists. The test asserts that ``_params_init_info`` is present before
    ``init_weights`` and gone afterwards, that the log file exists, and
    that any line mentioning a checked parameter names the expected
    initializer. The temporary work directory is removed when the test ends.
    """
    import os
    import shutil

    import torch.nn as nn

    from mmcv.utils.logging import get_logger

    class OverloadInitConv(nn.Conv2d, BaseModule):
        """Conv layer whose own ``init_weights`` fills every parameter with 1."""

        def init_weights(self):
            for p in self.parameters():
                with torch.no_grad():
                    p.fill_(1)

    class CheckLoggerModel(BaseModule):
        """Small model mixing plain layers with ``OverloadInitConv``."""

        def __init__(self, init_cfg=None):
            super(CheckLoggerModel, self).__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv2 = OverloadInitConv(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
        dict(type='Constant', layer='Linear', val=0., bias=1.)
    ]

    model = CheckLoggerModel(init_cfg=init_cfg)

    train_log = '20210720_132454.log'
    workdir = tempfile.mkdtemp()
    try:
        log_file = os.path.join(workdir, train_log)
        # create a logger that writes to ``log_file``
        get_logger('init_logger', log_file=log_file)
        assert hasattr(model, '_params_init_info')
        model.init_weights()
        # assert `_params_init_info` would be deleted after `init_weights`
        assert not hasattr(model, '_params_init_info')
        # assert initialization information has been dumped
        assert os.path.exists(log_file)

        with open(log_file) as f:
            lines = f.readlines()

        # check initialization information is right; each check only fires
        # on lines that mention the corresponding parameter name
        for line in lines:
            if 'conv1.weight' in line:
                assert 'NormalInit' in line
            if 'conv2.weight' in line:
                assert 'OverloadInitConv' in line
            if 'fc1.weight' in line:
                assert 'ConstantInit' in line
    finally:
        # Best-effort cleanup: the logger created above may still hold the
        # log file open, which can block deletion on some platforms, so
        # errors are ignored instead of failing the test.
        shutil.rmtree(workdir, ignore_errors=True)
def test_update_init_info():
    """Check that ``update_init_info`` refreshes the per-parameter records.

    Every parameter gets a record with ``init_info='init'`` and its current
    mean. All parameters are then filled with 1 and ``update_init_info`` is
    called; afterwards every record must carry the new ``init_info`` and a
    mean value of 1.
    """
    from collections import defaultdict

    class DummyModel(BaseModule):
        """Minimal model with two convs and one linear layer."""

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    model = DummyModel()

    # Seed the bookkeeping dict with the pre-fill state of each parameter.
    model._params_init_info = defaultdict(dict)
    records = model._params_init_info
    for param_name, parameter in model.named_parameters():
        records[parameter].update(
            param_name=param_name,
            init_info='init',
            tmp_mean_value=parameter.data.mean())

    # Change every parameter so that its mean differs from the recorded one.
    for parameter in model.parameters():
        with torch.no_grad():
            parameter.fill_(1)

    update_init_info(model, init_info='fill_1')

    for record in model._params_init_info.values():
        assert record['init_info'] == 'fill_1'
        assert record['tmp_mean_value'] == 1
def test_model_weight_init():
    """Top-level ``init_cfg`` is applied to every matching layer in the tree.

    Config::

        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                         Conv2d: weight=5, bias=6)
        ├──component1 (FooConv1d)
        ├──component2 (FooConv2d)
        ├──component3 (FooLinear)
        ├──component4 (FooLinearConv1d)
            ├──linear (FooLinear)
            ├──conv1d (FooConv1d)
        ├──reg (nn.Linear)

    Parameters after initialization::

        model (FooModel)
        ├──component1 (FooConv1d, weight=3, bias=4)
        ├──component2 (FooConv2d, weight=5, bias=6)
        ├──component3 (FooLinear, weight=1, bias=2)
        ├──component4 (FooLinearConv1d)
            ├──linear (FooLinear, weight=1, bias=2)
            ├──conv1d (FooConv1d, weight=3, bias=4)
        ├──reg (nn.Linear, weight=1, bias=2)
    """
    layer_rules = [
        dict(type='Constant', val=1, bias=2, layer='Linear'),
        dict(type='Constant', val=3, bias=4, layer='Conv1d'),
        dict(type='Constant', val=5, bias=6, layer='Conv2d')
    ]
    nested_component = dict(
        type='FooLinearConv1d',
        linear=dict(type='FooLinear'),
        conv1d=dict(type='FooConv1d'))
    model_cfg = dict(
        type='FooModel',
        init_cfg=layer_rules,
        component1=dict(type='FooConv1d'),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'),
        component4=nested_component)

    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    # (layer, expected constant weight, expected constant bias)
    expected = [
        (model.component1.conv1d, 3.0, 4.0),
        (model.component2.conv2d, 5.0, 6.0),
        (model.component3.linear, 1.0, 2.0),
        (model.component4.linear.linear, 1.0, 2.0),
        (model.component4.conv1d.conv1d, 3.0, 4.0),
        (model.reg, 1.0, 2.0),
    ]
    for layer, weight_value, bias_value in expected:
        assert torch.equal(layer.weight,
                           torch.full(layer.weight.shape, weight_value))
        assert torch.equal(layer.bias,
                           torch.full(layer.bias.shape, bias_value))
def test_nest_components_weight_init():
    """A component's own ``init_cfg`` wins over the top-level ``init_cfg``.

    Config::

        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                         Conv2d: weight=5, bias=6,
                         override 'reg': weight=13, bias=14)
        ├──component1 (FooConv1d, Conv1d: weight=7, bias=8)
        ├──component2 (FooConv2d, Conv2d: weight=9, bias=10)
        ├──component3 (FooLinear)
        ├──component4 (FooLinearConv1d)
            ├──linear (FooLinear)
            ├──conv1d (FooConv1d)
        ├──reg (nn.Linear)

    Parameters after initialization::

        model (FooModel)
        ├──component1 (FooConv1d, weight=7, bias=8)
        ├──component2 (FooConv2d, weight=9, bias=10)
        ├──component3 (FooLinear, weight=1, bias=2)
        ├──component4 (FooLinearConv1d)
            ├──linear (FooLinear, weight=1, bias=2)
            ├──conv1d (FooConv1d, weight=3, bias=4)
        ├──reg (nn.Linear, weight=13, bias=14)
    """
    top_level_rules = [
        dict(
            type='Constant',
            val=1,
            bias=2,
            layer='Linear',
            override=dict(type='Constant', name='reg', val=13, bias=14)),
        dict(type='Constant', val=3, bias=4, layer='Conv1d'),
        dict(type='Constant', val=5, bias=6, layer='Conv2d'),
    ]
    conv1d_component = dict(
        type='FooConv1d',
        init_cfg=dict(type='Constant', layer='Conv1d', val=7, bias=8))
    conv2d_component = dict(
        type='FooConv2d',
        init_cfg=dict(type='Constant', layer='Conv2d', val=9, bias=10))
    nested_component = dict(
        type='FooLinearConv1d',
        linear=dict(type='FooLinear'),
        conv1d=dict(type='FooConv1d'))
    model_cfg = dict(
        type='FooModel',
        init_cfg=top_level_rules,
        component1=conv1d_component,
        component2=conv2d_component,
        component3=dict(type='FooLinear'),
        component4=nested_component)

    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    # (layer, expected constant weight, expected constant bias)
    expected = [
        (model.component1.conv1d, 7.0, 8.0),
        (model.component2.conv2d, 9.0, 10.0),
        (model.component3.linear, 1.0, 2.0),
        (model.component4.linear.linear, 1.0, 2.0),
        (model.component4.conv1d.conv1d, 3.0, 4.0),
        (model.reg, 13.0, 14.0),
    ]
    for layer, weight_value, bias_value in expected:
        assert torch.equal(layer.weight,
                           torch.full(layer.weight.shape, weight_value))
        assert torch.equal(layer.bias,
                           torch.full(layer.bias.shape, bias_value))
def test_without_layer_weight_init():
    """A nested ``init_cfg`` without a ``layer`` key has no effect.

    ``component1`` carries ``init_cfg=dict(type='Constant', val=7, bias=8)``
    with no ``layer`` key, and the test expects its conv to end up with the
    top-level Conv1d values (3 / 4) instead of 7 / 8.
    """
    top_level_rules = [
        dict(type='Constant', val=1, bias=2, layer='Linear'),
        dict(type='Constant', val=3, bias=4, layer='Conv1d'),
        dict(type='Constant', val=5, bias=6, layer='Conv2d')
    ]
    layerless_component = dict(
        type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8))
    model_cfg = dict(
        type='FooModel',
        init_cfg=top_level_rules,
        component1=layerless_component,
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'))

    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    # (layer, expected constant weight, expected constant bias)
    expected = [
        # init_cfg in component1 does not have layer key, so it does nothing
        (model.component1.conv1d, 3.0, 4.0),
        (model.component2.conv2d, 5.0, 6.0),
        (model.component3.linear, 1.0, 2.0),
        (model.reg, 1.0, 2.0),
    ]
    for layer, weight_value, bias_value in expected:
        assert torch.equal(layer.weight,
                           torch.full(layer.weight.shape, weight_value))
        assert torch.equal(layer.bias,
                           torch.full(layer.bias.shape, bias_value))
def test_override_weight_init():
    """Exercise the ``override`` key of ``init_cfg`` in two scenarios.

    1. ``init_cfg`` has no ``layer`` key and ``override`` only names ``reg``:
       ``reg`` receives the outer constant values (10 / 20) while the other
       checked layers are not filled with them.
    2. ``override`` carries its own initializer: ``reg`` receives the
       override values (30 / 40) rather than the outer ones (1 / 2).
    """

    def is_filled(tensor, value):
        # True when every element of ``tensor`` equals ``value``.
        return torch.equal(tensor, torch.full(tensor.shape, value))

    # only initialize 'override'
    override_only_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=10, bias=20, override=dict(name='reg'))
        ],
        component1=dict(type='FooConv1d'),
        component3=dict(type='FooLinear'))
    model = build_from_cfg(override_only_cfg, FOOMODELS)
    model.init_weights()

    assert is_filled(model.reg.weight, 10.0)
    assert is_filled(model.reg.bias, 20.0)
    # do not initialize others
    for untouched in (model.component1.conv1d, model.component3.linear):
        assert not is_filled(untouched.weight, 10.0)
        assert not is_filled(untouched.bias, 20.0)

    # 'override' has higher priority
    override_priority_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(
                type='Constant',
                val=1,
                bias=2,
                override=dict(name='reg', type='Constant', val=30, bias=40))
        ],
        component1=dict(type='FooConv1d'),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'))
    model = build_from_cfg(override_priority_cfg, FOOMODELS)
    model.init_weights()

    assert is_filled(model.reg.weight, 30.0)
    assert is_filled(model.reg.bias, 40.0)
def test_sequential_model_weight_init():
    """``Sequential.init_weights`` respects each child's own ``init_cfg``.

    The children are checked twice: once in a plain ``Sequential`` and once
    in a ``Sequential`` that has its own ``init_cfg``. Both times the
    children must end up with the values from their own configs.
    """
    seq_model_cfg = [
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    ]

    def assert_child_values(container):
        # Children must carry the constants from their own init_cfg.
        conv1d = container[0].conv1d
        conv2d = container[1].conv2d
        checks = (
            (conv1d.weight, 0.),
            (conv1d.bias, 1.),
            (conv2d.weight, 2.),
            (conv2d.bias, 3.),
        )
        for tensor, value in checks:
            assert torch.equal(tensor, torch.full(tensor.shape, value))

    children = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(*children)
    seq_model.init_weights()
    assert_child_values(seq_model)

    # inner init_cfg has higher priority
    children = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    outer_init_cfg = dict(
        type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
    seq_model = Sequential(*children, init_cfg=outer_init_cfg)
    seq_model.init_weights()
    assert_child_values(seq_model)
def test_modulelist_weight_init():
    """``ModuleList.init_weights`` respects each child's own ``init_cfg``.

    The children are checked twice: once in a plain ``ModuleList`` and once
    in a ``ModuleList`` that has its own ``init_cfg``. Both times the
    children must end up with the values from their own configs.
    """
    models_cfg = [
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    ]

    def assert_child_values(container):
        # Children must carry the constants from their own init_cfg.
        conv1d = container[0].conv1d
        conv2d = container[1].conv2d
        checks = (
            (conv1d.weight, 0.),
            (conv1d.bias, 1.),
            (conv2d.weight, 2.),
            (conv2d.bias, 3.),
        )
        for tensor, value in checks:
            assert torch.equal(tensor, torch.full(tensor.shape, value))

    children = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(children)
    modellist.init_weights()
    assert_child_values(modellist)

    # inner init_cfg has higher priority
    children = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    outer_init_cfg = dict(
        type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
    modellist = ModuleList(children, init_cfg=outer_init_cfg)
    modellist.init_weights()
    assert_child_values(modellist)