# mmengine/tests/test_model/test_base_module.py

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from unittest import TestCase

import torch
from torch import nn

from mmengine.logging.logger import MMLogger
from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
from mmengine.registry import Registry, build_from_cfg

COMPONENTS = Registry('component')
FOOMODELS = Registry('model')

Logger = MMLogger.get_current_instance()
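
# A `Registry` maps the string stored under a config's 'type' key to a class;
# `build_from_cfg(cfg, registry)` looks up cfg['type'] and instantiates that
# class with the remaining keys of `cfg` as keyword arguments.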


@COMPONENTS.register_module()
class FooConv1d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)


@COMPONENTS.register_module()
class FooConv2d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 1, 3)

    def forward(self, x):
        return self.conv2d(x)


@COMPONENTS.register_module()
class FooLinear(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)


@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super().__init__(init_cfg)
        if linear is not None:
            self.linear = build_from_cfg(linear, COMPONENTS)
        if conv1d is not None:
            self.conv1d = build_from_cfg(conv1d, COMPONENTS)

    def forward(self, x):
        x = self.linear(x)
        return self.conv1d(x)
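
# A minimal usage sketch (illustrative only, not executed by the tests):
#
#     cfg = dict(type='FooLinearConv1d',
#                linear=dict(type='FooLinear'),
#                conv1d=dict(type='FooConv1d'))
#     module = build_from_cfg(cfg, COMPONENTS)  # FooLinearConv1d instance
#     out = module(torch.rand(2, 4, 3))  # Linear: (2, 4, 3) -> (2, 4, 4),
#                                        # Conv1d: (2, 4, 4) -> (2, 1, 1)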


@FOOMODELS.register_module()
class FooModel(BaseModule):

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        if component1 is not None:
            self.component1 = build_from_cfg(component1, COMPONENTS)
        if component2 is not None:
            self.component2 = build_from_cfg(component2, COMPONENTS)
        if component3 is not None:
            self.component3 = build_from_cfg(component3, COMPONENTS)
        if component4 is not None:
            self.component4 = build_from_cfg(component4, COMPONENTS)

        # `reg` is a plain nn.Module rather than a BaseModule; it can still
        # be initialized, e.g. through an `override` key in `init_cfg`.
        self.reg = nn.Linear(3, 4)
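
# Each dict in FooModel's `init_cfg` list targets layers by class name
# (`layer='Linear'`, `layer='Conv1d'`, ...), so a single `init_weights()`
# pass can assign different constants to different layer types.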


class TestBaseModule(TestCase):

    def setUp(self) -> None:
        self.BaseModule = BaseModule()
        self.model_cfg = dict(
            type='FooModel',
            init_cfg=[
                dict(type='Constant', val=1, bias=2, layer='Linear'),
                dict(type='Constant', val=3, bias=4, layer='Conv1d'),
                dict(type='Constant', val=5, bias=6, layer='Conv2d')
            ],
            component1=dict(type='FooConv1d'),
            component2=dict(type='FooConv2d'),
            component3=dict(type='FooLinear'),
            component4=dict(
                type='FooLinearConv1d',
                linear=dict(type='FooLinear'),
                conv1d=dict(type='FooConv1d')))

        self.model = build_from_cfg(self.model_cfg, FOOMODELS)
        self.logger = MMLogger.get_instance(self._testMethodName)

    def tearDown(self) -> None:
        logging.shutdown()
        MMLogger._instance_dict.clear()
        return super().tearDown()

    def test_is_init(self):
        assert self.BaseModule.is_init is False
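
    # `is_init` is backed by an internal flag that `init_weights()` sets to
    # True after the first initialization pass, which lets BaseModule detect
    # and warn about repeated calls.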

    def test_init_weights(self):
        """
        Config
        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
        Conv2d: weight=5, bias=6)
        ├──component1 (FooConv1d)
        ├──component2 (FooConv2d)
        ├──component3 (FooLinear)
        ├──component4 (FooLinearConv1d)
        │   ├──linear (FooLinear)
        │   └──conv1d (FooConv1d)
        └──reg (nn.Linear)

        Parameters after initialization
        model (FooModel)
        ├──component1 (FooConv1d, weight=3, bias=4)
        ├──component2 (FooConv2d, weight=5, bias=6)
        ├──component3 (FooLinear, weight=1, bias=2)
        ├──component4 (FooLinearConv1d)
        │   ├──linear (FooLinear, weight=1, bias=2)
        │   └──conv1d (FooConv1d, weight=3, bias=4)
        └──reg (nn.Linear, weight=1, bias=2)
        """
        self.model.init_weights()

        assert torch.equal(
            self.model.component1.conv1d.weight,
            torch.full(self.model.component1.conv1d.weight.shape, 3.0))
        assert torch.equal(
            self.model.component1.conv1d.bias,
            torch.full(self.model.component1.conv1d.bias.shape, 4.0))
        assert torch.equal(
            self.model.component2.conv2d.weight,
            torch.full(self.model.component2.conv2d.weight.shape, 5.0))
        assert torch.equal(
            self.model.component2.conv2d.bias,
            torch.full(self.model.component2.conv2d.bias.shape, 6.0))
        assert torch.equal(
            self.model.component3.linear.weight,
            torch.full(self.model.component3.linear.weight.shape, 1.0))
        assert torch.equal(
            self.model.component3.linear.bias,
            torch.full(self.model.component3.linear.bias.shape, 2.0))
        assert torch.equal(
            self.model.component4.linear.linear.weight,
            torch.full(self.model.component4.linear.linear.weight.shape, 1.0))
        assert torch.equal(
            self.model.component4.linear.linear.bias,
            torch.full(self.model.component4.linear.linear.bias.shape, 2.0))
        assert torch.equal(
            self.model.component4.conv1d.conv1d.weight,
            torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0))
        assert torch.equal(
            self.model.component4.conv1d.conv1d.bias,
            torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0))
        assert torch.equal(self.model.reg.weight,
                           torch.full(self.model.reg.weight.shape, 1.0))
        assert torch.equal(self.model.reg.bias,
                           torch.full(self.model.reg.bias.shape, 2.0))
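
    # The `override` key mentioned in FooModel (a hedged sketch, not
    # asserted in this suite): an init_cfg such as
    #
    #     init_cfg = dict(type='Constant', val=1, bias=2, layer='Linear',
    #                     override=dict(name='reg', val=9, bias=8))
    #
    # would set every Linear layer to (weight=1, bias=2) and then
    # re-initialize the submodule named 'reg' with (weight=9, bias=8).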

    def test_dump_init_info(self):
        import os
        import shutil

        dump_dir = 'tests/test_model/test_dump_info'
        if not (os.path.exists(dump_dir) and os.path.isdir(dump_dir)):
            os.makedirs(dump_dir)
        for filename in os.listdir(dump_dir):
            file_path = os.path.join(dump_dir, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
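
        # Init info is written to disk only when the active MMLogger owns a
        # `FileHandler`; `logger1` below has none, so `dump_dir` should stay
        # empty after `init_weights()`.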
        MMLogger.get_instance('logger1')  # logger without a FileHandler
        model1 = build_from_cfg(self.model_cfg, FOOMODELS)
        model1.init_weights()
        assert len(os.listdir(dump_dir)) == 0

        log_path = os.path.join(dump_dir, 'out.log')
        MMLogger.get_instance(
            'logger2', log_file=log_path)  # logger with a FileHandler
        model2 = build_from_cfg(self.model_cfg, FOOMODELS)
        model2.init_weights()
        assert len(os.listdir(dump_dir)) == 1
        assert os.stat(log_path).st_size != 0

        # The `FileHandler` must be closed on Windows; otherwise the
        # temporary directory cannot be deleted.
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(dump_dir)


class TestModuleList(TestCase):

    def test_modulelist_weight_init(self):
        models_cfg = [
            dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        ]
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
        modellist = ModuleList(layers)
        modellist.init_weights()
        self.assertTrue(
            torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(modellist[0].conv1d.bias,
                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.weight,
                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.)))

        # The inner init_cfg takes priority: the outer Constant(val=4.,
        # bias=5.) must not override children that define their own init_cfg.
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
        modellist = ModuleList(
            layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        modellist.init_weights()
        self.assertTrue(
            torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(modellist[0].conv1d.bias,
                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.weight,
                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
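
    # Conversely (a hedged sketch, not asserted above): if the children were
    # built without their own init_cfg, the outer config would apply, e.g.
    #
    #     modellist = ModuleList([FooConv1d(), FooConv2d()],
    #                            init_cfg=dict(type='Constant',
    #                                          layer=['Conv1d', 'Conv2d'],
    #                                          val=4., bias=5.))
    #     modellist.init_weights()  # weights become 4., biases become 5.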


class TestModuleDict(TestCase):

    def test_moduledict_weight_init(self):
        models_cfg = dict(
            foo_conv_1d=dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            foo_conv_2d=dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        )
        layers = {
            name: build_from_cfg(cfg, COMPONENTS)
            for name, cfg in models_cfg.items()
        }
        modeldict = ModuleDict(layers)
        modeldict.init_weights()
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.weight,
                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.bias,
                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.weight,
                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.bias,
                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))

        # inner init_cfg has higher priority
        layers = {
            name: build_from_cfg(cfg, COMPONENTS)
            for name, cfg in models_cfg.items()
        }
        modeldict = ModuleDict(
            layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        modeldict.init_weights()
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.weight,
                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.bias,
                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.weight,
                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.bias,
                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))


class TestSequential(TestCase):

    def test_sequential_model_weight_init(self):
        seq_model_cfg = [
            dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        ]
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
        seq_model = Sequential(*layers)
        seq_model.init_weights()
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.bias,
                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.weight,
                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))

        # inner init_cfg has higher priority
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
        seq_model = Sequential(
            *layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        seq_model.init_weights()
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.bias,
                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.weight,
                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))