# Copyright (c) OpenMMLab. All rights reserved. import logging from unittest import TestCase import torch from torch import nn from mmengine.logging.logger import MMLogger from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential from mmengine.registry import Registry, build_from_cfg COMPONENTS = Registry('component') FOOMODELS = Registry('model') Logger = MMLogger.get_current_instance() @COMPONENTS.register_module() class FooConv1d(BaseModule): def __init__(self, init_cfg=None): super().__init__(init_cfg) self.conv1d = nn.Conv1d(4, 1, 4) def forward(self, x): return self.conv1d(x) @COMPONENTS.register_module() class FooConv2d(BaseModule): def __init__(self, init_cfg=None): super().__init__(init_cfg) self.conv2d = nn.Conv2d(3, 1, 3) def forward(self, x): return self.conv2d(x) @COMPONENTS.register_module() class FooLinear(BaseModule): def __init__(self, init_cfg=None): super().__init__(init_cfg) self.linear = nn.Linear(3, 4) def forward(self, x): return self.linear(x) @COMPONENTS.register_module() class FooLinearConv1d(BaseModule): def __init__(self, linear=None, conv1d=None, init_cfg=None): super().__init__(init_cfg) if linear is not None: self.linear = build_from_cfg(linear, COMPONENTS) if conv1d is not None: self.conv1d = build_from_cfg(conv1d, COMPONENTS) def forward(self, x): x = self.linear(x) return self.conv1d(x) @FOOMODELS.register_module() class FooModel(BaseModule): def __init__(self, component1=None, component2=None, component3=None, component4=None, init_cfg=None) -> None: super().__init__(init_cfg) if component1 is not None: self.component1 = build_from_cfg(component1, COMPONENTS) if component2 is not None: self.component2 = build_from_cfg(component2, COMPONENTS) if component3 is not None: self.component3 = build_from_cfg(component3, COMPONENTS) if component4 is not None: self.component4 = build_from_cfg(component4, COMPONENTS) # its type is not BaseModule, it can be initialized # with "override" key. self.reg = nn.Linear(3, 4) class TestBaseModule(TestCase): def setUp(self) -> None: self.BaseModule = BaseModule() self.model_cfg = dict( type='FooModel', init_cfg=[ dict(type='Constant', val=1, bias=2, layer='Linear'), dict(type='Constant', val=3, bias=4, layer='Conv1d'), dict(type='Constant', val=5, bias=6, layer='Conv2d') ], component1=dict(type='FooConv1d'), component2=dict(type='FooConv2d'), component3=dict(type='FooLinear'), component4=dict( type='FooLinearConv1d', linear=dict(type='FooLinear'), conv1d=dict(type='FooConv1d'))) self.model = build_from_cfg(self.model_cfg, FOOMODELS) self.logger = MMLogger.get_instance(self._testMethodName) def tearDown(self) -> None: logging.shutdown() MMLogger._instance_dict.clear() return super().tearDown() def test_is_init(self): assert self.BaseModule.is_init is False def test_init_weights(self): """ Config model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4, Conv2d: weight=5, bias=6) ├──component1 (FooConv1d) ├──component2 (FooConv2d) ├──component3 (FooLinear) ├──component4 (FooLinearConv1d) ├──linear (FooLinear) ├──conv1d (FooConv1d) ├──reg (nn.Linear) Parameters after initialization model (FooModel) ├──component1 (FooConv1d, weight=3, bias=4) ├──component2 (FooConv2d, weight=5, bias=6) ├──component3 (FooLinear, weight=1, bias=2) ├──component4 (FooLinearConv1d) ├──linear (FooLinear, weight=1, bias=2) ├──conv1d (FooConv1d, weight=3, bias=4) ├──reg (nn.Linear, weight=1, bias=2) """ self.model.init_weights() assert torch.equal( self.model.component1.conv1d.weight, torch.full(self.model.component1.conv1d.weight.shape, 3.0)) assert torch.equal( self.model.component1.conv1d.bias, torch.full(self.model.component1.conv1d.bias.shape, 4.0)) assert torch.equal( self.model.component2.conv2d.weight, torch.full(self.model.component2.conv2d.weight.shape, 5.0)) assert torch.equal( self.model.component2.conv2d.bias, torch.full(self.model.component2.conv2d.bias.shape, 6.0)) assert torch.equal( self.model.component3.linear.weight, torch.full(self.model.component3.linear.weight.shape, 1.0)) assert torch.equal( self.model.component3.linear.bias, torch.full(self.model.component3.linear.bias.shape, 2.0)) assert torch.equal( self.model.component4.linear.linear.weight, torch.full(self.model.component4.linear.linear.weight.shape, 1.0)) assert torch.equal( self.model.component4.linear.linear.bias, torch.full(self.model.component4.linear.linear.bias.shape, 2.0)) assert torch.equal( self.model.component4.conv1d.conv1d.weight, torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0)) assert torch.equal( self.model.component4.conv1d.conv1d.bias, torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0)) assert torch.equal(self.model.reg.weight, torch.full(self.model.reg.weight.shape, 1.0)) assert torch.equal(self.model.reg.bias, torch.full(self.model.reg.bias.shape, 2.0)) def test_dump_init_info(self): import os import shutil dump_dir = 'tests/test_model/test_dump_info' if not (os.path.exists(dump_dir) and os.path.isdir(dump_dir)): os.makedirs(dump_dir) for filename in os.listdir(dump_dir): file_path = os.path.join(dump_dir, filename) if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) MMLogger.get_instance('logger1') # add logger without FileHandler model1 = build_from_cfg(self.model_cfg, FOOMODELS) model1.init_weights() assert len(os.listdir(dump_dir)) == 0 log_path = os.path.join(dump_dir, 'out.log') MMLogger.get_instance( 'logger2', log_file=log_path) # add logger with FileHandler model2 = build_from_cfg(self.model_cfg, FOOMODELS) model2.init_weights() assert len(os.listdir(dump_dir)) == 1 assert os.stat(log_path).st_size != 0 # `FileHandler` should be closed in Windows, otherwise we cannot # delete the temporary directory logging.shutdown() MMLogger._instance_dict.clear() shutil.rmtree(dump_dir) class TestModuleList(TestCase): def test_modulelist_weight_init(self): models_cfg = [ dict( type='FooConv1d', init_cfg=dict( type='Constant', layer='Conv1d', val=0., bias=1.)), dict( type='FooConv2d', init_cfg=dict( type='Constant', layer='Conv2d', val=2., bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList(layers) modellist.init_weights() self.assertTrue( torch.equal(modellist[0].conv1d.weight, torch.full(modellist[0].conv1d.weight.shape, 0.))) self.assertTrue( torch.equal(modellist[0].conv1d.bias, torch.full(modellist[0].conv1d.bias.shape, 1.))) self.assertTrue( torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.))) self.assertTrue( torch.equal(modellist[1].conv2d.bias, torch.full(modellist[1].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList( layers, init_cfg=dict( type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) modellist.init_weights() self.assertTrue( torch.equal(modellist[0].conv1d.weight, torch.full(modellist[0].conv1d.weight.shape, 0.))) self.assertTrue( torch.equal(modellist[0].conv1d.bias, torch.full(modellist[0].conv1d.bias.shape, 1.))) self.assertTrue( torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.))) self.assertTrue( torch.equal(modellist[1].conv2d.bias, torch.full(modellist[1].conv2d.bias.shape, 3.))) class TestModuleDict(TestCase): def test_moduledict_weight_init(self): models_cfg = dict( foo_conv_1d=dict( type='FooConv1d', init_cfg=dict( type='Constant', layer='Conv1d', val=0., bias=1.)), foo_conv_2d=dict( type='FooConv2d', init_cfg=dict( type='Constant', layer='Conv2d', val=2., bias=3.)), ) layers = { name: build_from_cfg(cfg, COMPONENTS) for name, cfg in models_cfg.items() } modeldict = ModuleDict(layers) modeldict.init_weights() self.assertTrue( torch.equal( modeldict['foo_conv_1d'].conv1d.weight, torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))) self.assertTrue( torch.equal( modeldict['foo_conv_1d'].conv1d.bias, torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))) self.assertTrue( torch.equal( modeldict['foo_conv_2d'].conv2d.weight, torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))) self.assertTrue( torch.equal( modeldict['foo_conv_2d'].conv2d.bias, torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = { name: build_from_cfg(cfg, COMPONENTS) for name, cfg in models_cfg.items() } modeldict = ModuleDict( layers, init_cfg=dict( type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) modeldict.init_weights() self.assertTrue( torch.equal( modeldict['foo_conv_1d'].conv1d.weight, torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))) self.assertTrue( torch.equal( modeldict['foo_conv_1d'].conv1d.bias, torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))) self.assertTrue( torch.equal( modeldict['foo_conv_2d'].conv2d.weight, torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))) self.assertTrue( torch.equal( modeldict['foo_conv_2d'].conv2d.bias, torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))) class TestSequential(TestCase): def test_sequential_model_weight_init(self): seq_model_cfg = [ dict( type='FooConv1d', init_cfg=dict( type='Constant', layer='Conv1d', val=0., bias=1.)), dict( type='FooConv2d', init_cfg=dict( type='Constant', layer='Conv2d', val=2., bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential(*layers) seq_model.init_weights() self.assertTrue( torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.))) self.assertTrue( torch.equal(seq_model[0].conv1d.bias, torch.full(seq_model[0].conv1d.bias.shape, 1.))) self.assertTrue( torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.))) self.assertTrue( torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential( *layers, init_cfg=dict( type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) seq_model.init_weights() self.assertTrue( torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.))) self.assertTrue( torch.equal(seq_model[0].conv1d.bias, torch.full(seq_model[0].conv1d.bias.shape, 1.))) self.assertTrue( torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.))) self.assertTrue( torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)))