# Copyright (c) OpenMMLab. All rights reserved.
import logging
from unittest import TestCase

import torch
from torch import nn

from mmengine.logging.logger import MMLogger
from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
from mmengine.registry import Registry, build_from_cfg

COMPONENTS = Registry('component')
FOOMODELS = Registry('model')

Logger = MMLogger.get_current_instance()
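
# The two registries let the configs below refer to the Foo* classes by name:
# roughly, build_from_cfg(dict(type='FooConv1d'), COMPONENTS) looks up
# 'FooConv1d' in COMPONENTS and calls it with the remaining dict keys as
# keyword arguments.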


@COMPONENTS.register_module()
class FooConv1d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)


@COMPONENTS.register_module()
class FooConv2d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 1, 3)

    def forward(self, x):
        return self.conv2d(x)


@COMPONENTS.register_module()
class FooLinear(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)


@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super().__init__(init_cfg)
        if linear is not None:
            self.linear = build_from_cfg(linear, COMPONENTS)
        if conv1d is not None:
            self.conv1d = build_from_cfg(conv1d, COMPONENTS)

    def forward(self, x):
        x = self.linear(x)
        return self.conv1d(x)


@FOOMODELS.register_module()
class FooModel(BaseModule):

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        if component1 is not None:
            self.component1 = build_from_cfg(component1, COMPONENTS)
        if component2 is not None:
            self.component2 = build_from_cfg(component2, COMPONENTS)
        if component3 is not None:
            self.component3 = build_from_cfg(component3, COMPONENTS)
        if component4 is not None:
            self.component4 = build_from_cfg(component4, COMPONENTS)
        # `reg` is a plain nn.Linear rather than a BaseModule, so it can
        # also be initialized through the `override` key of `init_cfg`.
        self.reg = nn.Linear(3, 4)


class TestBaseModule(TestCase):

    def setUp(self) -> None:
        self.BaseModule = BaseModule()
        self.model_cfg = dict(
            type='FooModel',
            init_cfg=[
                dict(type='Constant', val=1, bias=2, layer='Linear'),
                dict(type='Constant', val=3, bias=4, layer='Conv1d'),
                dict(type='Constant', val=5, bias=6, layer='Conv2d')
            ],
            component1=dict(type='FooConv1d'),
            component2=dict(type='FooConv2d'),
            component3=dict(type='FooLinear'),
            component4=dict(
                type='FooLinearConv1d',
                linear=dict(type='FooLinear'),
                conv1d=dict(type='FooConv1d')))
        self.model = build_from_cfg(self.model_cfg, FOOMODELS)
        self.logger = MMLogger.get_instance(self._testMethodName)

    def tearDown(self) -> None:
        logging.shutdown()
        MMLogger._instance_dict.clear()
        return super().tearDown()

    def test_is_init(self):
        assert self.BaseModule.is_init is False

    def test_init_weights(self):
        """
        Config
        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
               Conv2d: weight=5, bias=6)
            component1 (FooConv1d)
            component2 (FooConv2d)
            component3 (FooLinear)
            component4 (FooLinearConv1d)
                linear (FooLinear)
                conv1d (FooConv1d)
            reg (nn.Linear)

        Parameters after initialization
        model (FooModel)
            component1 (FooConv1d, weight=3, bias=4)
            component2 (FooConv2d, weight=5, bias=6)
            component3 (FooLinear, weight=1, bias=2)
            component4 (FooLinearConv1d)
                linear (FooLinear, weight=1, bias=2)
                conv1d (FooConv1d, weight=3, bias=4)
            reg (nn.Linear, weight=1, bias=2)
        """
        self.model.init_weights()
        assert torch.equal(
            self.model.component1.conv1d.weight,
            torch.full(self.model.component1.conv1d.weight.shape, 3.0))
        assert torch.equal(
            self.model.component1.conv1d.bias,
            torch.full(self.model.component1.conv1d.bias.shape, 4.0))
        assert torch.equal(
            self.model.component2.conv2d.weight,
            torch.full(self.model.component2.conv2d.weight.shape, 5.0))
        assert torch.equal(
            self.model.component2.conv2d.bias,
            torch.full(self.model.component2.conv2d.bias.shape, 6.0))
        assert torch.equal(
            self.model.component3.linear.weight,
            torch.full(self.model.component3.linear.weight.shape, 1.0))
        assert torch.equal(
            self.model.component3.linear.bias,
            torch.full(self.model.component3.linear.bias.shape, 2.0))
        assert torch.equal(
            self.model.component4.linear.linear.weight,
            torch.full(self.model.component4.linear.linear.weight.shape, 1.0))
        assert torch.equal(
            self.model.component4.linear.linear.bias,
            torch.full(self.model.component4.linear.linear.bias.shape, 2.0))
        assert torch.equal(
            self.model.component4.conv1d.conv1d.weight,
            torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0))
        assert torch.equal(
            self.model.component4.conv1d.conv1d.bias,
            torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0))
        assert torch.equal(self.model.reg.weight,
                           torch.full(self.model.reg.weight.shape, 1.0))
        assert torch.equal(self.model.reg.bias,
                           torch.full(self.model.reg.bias.shape, 2.0))
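
    def test_override_init(self):
        # A minimal sketch (not part of the original suite) of the `override`
        # key mentioned in FooModel: `reg` is targeted by name and receives
        # its own constant values instead of the layer-wide Linear ones. The
        # exact `override` semantics are assumed from mmengine's weight-init
        # conventions.
        model_cfg = dict(
            type='FooModel',
            init_cfg=dict(
                type='Constant',
                val=1,
                bias=2,
                layer='Linear',
                override=dict(name='reg', type='Constant', val=7, bias=8)),
            component3=dict(type='FooLinear'))
        model = build_from_cfg(model_cfg, FOOMODELS)
        model.init_weights()
        assert torch.equal(model.reg.weight,
                           torch.full(model.reg.weight.shape, 7.0))
        assert torch.equal(model.reg.bias,
                           torch.full(model.reg.bias.shape, 8.0))
        # the non-overridden Linear still follows the layer-wide entry
        assert torch.equal(
            model.component3.linear.weight,
            torch.full(model.component3.linear.weight.shape, 1.0))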

    def test_dump_init_info(self):
        import os
        import shutil
        dump_dir = 'tests/test_model/test_dump_info'
        if not (os.path.exists(dump_dir) and os.path.isdir(dump_dir)):
            os.makedirs(dump_dir)
        for filename in os.listdir(dump_dir):
            file_path = os.path.join(dump_dir, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        MMLogger.get_instance('logger1')  # add logger without FileHandler
        model1 = build_from_cfg(self.model_cfg, FOOMODELS)
        model1.init_weights()
        assert len(os.listdir(dump_dir)) == 0
        log_path = os.path.join(dump_dir, 'out.log')
        MMLogger.get_instance(
            'logger2', log_file=log_path)  # add logger with FileHandler
        model2 = build_from_cfg(self.model_cfg, FOOMODELS)
        model2.init_weights()
        assert len(os.listdir(dump_dir)) == 1
        assert os.stat(log_path).st_size != 0
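        # `init_weights` dumps the recorded initialization info through the
        # current logger's `FileHandler`, which is why a non-empty file only
        # appears once `logger2` exists.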

        # `FileHandler` should be closed on Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(dump_dir)


class TestModuleList(TestCase):

    def test_modulelist_weight_init(self):
        models_cfg = [
            dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        ]
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
        modellist = ModuleList(layers)
        modellist.init_weights()
        self.assertTrue(
            torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(modellist[0].conv1d.bias,
                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.weight,
                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
        # inner init_cfg has higher priority
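        # (the outer init_cfg below only applies to children whose own
        # init_cfg is unset; both children define one, so the asserted
        # values are unchanged)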
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
        modellist = ModuleList(
            layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        modellist.init_weights()
        self.assertTrue(
            torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(modellist[0].conv1d.bias,
                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.weight,
                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.)))


class TestModuleDict(TestCase):

    def test_moduledict_weight_init(self):
        models_cfg = dict(
            foo_conv_1d=dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            foo_conv_2d=dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        )
        layers = {
            name: build_from_cfg(cfg, COMPONENTS)
            for name, cfg in models_cfg.items()
        }
        modeldict = ModuleDict(layers)
        modeldict.init_weights()
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.weight,
                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.bias,
                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.weight,
                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.bias,
                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
        # inner init_cfg has higher priority
        layers = {
            name: build_from_cfg(cfg, COMPONENTS)
            for name, cfg in models_cfg.items()
        }
        modeldict = ModuleDict(
            layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        modeldict.init_weights()
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.weight,
                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.bias,
                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.weight,
                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.bias,
                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))


class TestSequential(TestCase):

    def test_sequential_model_weight_init(self):
        seq_model_cfg = [
            dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        ]
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
        seq_model = Sequential(*layers)
        seq_model.init_weights()
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.bias,
                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.weight,
                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
        # inner init_cfg has higher priority
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
        seq_model = Sequential(
            *layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        seq_model.init_weights()
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.bias,
                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.weight,
                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))