# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from logging import FileHandler
from unittest import TestCase

import torch
from torch import nn

from mmengine.logging.logger import MMLogger
from mmengine.model.base_module import BaseModule
from mmengine.registry import Registry, build_from_cfg

COMPONENTS = Registry('component')
FOOMODELS = Registry('model')

Logger = MMLogger.get_current_instance()


@COMPONENTS.register_module()
class FooConv1d(BaseModule):
    """Toy component wrapping a single ``nn.Conv1d(4, 1, 4)``."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)


@COMPONENTS.register_module()
class FooConv2d(BaseModule):
    """Toy component wrapping a single ``nn.Conv2d(3, 1, 3)``."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 1, 3)

    def forward(self, x):
        return self.conv2d(x)


@COMPONENTS.register_module()
class FooLinear(BaseModule):
    """Toy component wrapping a single ``nn.Linear(3, 4)``."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)


@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):
    """Toy component nesting two other components built from configs.

    Used to check that a parent's ``init_cfg`` reaches BaseModules that are
    themselves children of a BaseModule child.
    """

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super().__init__(init_cfg)
        if linear is not None:
            self.linear = build_from_cfg(linear, COMPONENTS)
        if conv1d is not None:
            self.conv1d = build_from_cfg(conv1d, COMPONENTS)

    def forward(self, x):
        # NOTE(review): requires both ``linear`` and ``conv1d`` to have been
        # configured; otherwise the attribute lookup raises AttributeError.
        x = self.linear(x)
        return self.conv1d(x)


@FOOMODELS.register_module()
class FooModel(BaseModule):
    """Toy model whose optional components are built from COMPONENTS."""

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        if component1 is not None:
            self.component1 = build_from_cfg(component1, COMPONENTS)
        if component2 is not None:
            self.component2 = build_from_cfg(component2, COMPONENTS)
        if component3 is not None:
            self.component3 = build_from_cfg(component3, COMPONENTS)
        if component4 is not None:
            self.component4 = build_from_cfg(component4, COMPONENTS)
        # Its type is not BaseModule; it can be initialized with the
        # "override" key.
        self.reg = nn.Linear(3, 4)


class TestBaseModule(TestCase):

    def setUp(self) -> None:
        self.base_module = BaseModule()
        self.model_cfg = dict(
            type='FooModel',
            init_cfg=[
                dict(type='Constant', val=1, bias=2, layer='Linear'),
                dict(type='Constant', val=3, bias=4, layer='Conv1d'),
                dict(type='Constant', val=5, bias=6, layer='Conv2d')
            ],
            component1=dict(type='FooConv1d'),
            component2=dict(type='FooConv2d'),
            component3=dict(type='FooLinear'),
            component4=dict(
                type='FooLinearConv1d',
                linear=dict(type='FooLinear'),
                conv1d=dict(type='FooConv1d')))
        self.model = build_from_cfg(self.model_cfg, FOOMODELS)

    @staticmethod
    def _assert_filled(tensor, value):
        """Assert that every element of ``tensor`` equals ``value``."""
        assert torch.equal(tensor, torch.full(tensor.shape, value))

    def test_is_init(self):
        assert self.base_module.is_init is False

    def test_init_weights(self):
        """Check that layer-typed ``Constant`` init_cfg reaches all children.

        Config:
        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                         Conv2d: weight=5, bias=6)
        ├──component1 (FooConv1d)
        ├──component2 (FooConv2d)
        ├──component3 (FooLinear)
        ├──component4 (FooLinearConv1d)
        │   ├──linear (FooLinear)
        │   └──conv1d (FooConv1d)
        └──reg (nn.Linear)

        Parameters after initialization:
        model (FooModel)
        ├──component1 (FooConv1d, weight=3, bias=4)
        ├──component2 (FooConv2d, weight=5, bias=6)
        ├──component3 (FooLinear, weight=1, bias=2)
        ├──component4 (FooLinearConv1d)
        │   ├──linear (FooLinear, weight=1, bias=2)
        │   └──conv1d (FooConv1d, weight=3, bias=4)
        └──reg (nn.Linear, weight=1, bias=2)
        """
        self.model.init_weights()
        model = self.model

        self._assert_filled(model.component1.conv1d.weight, 3.0)
        self._assert_filled(model.component1.conv1d.bias, 4.0)

        self._assert_filled(model.component2.conv2d.weight, 5.0)
        self._assert_filled(model.component2.conv2d.bias, 6.0)

        self._assert_filled(model.component3.linear.weight, 1.0)
        self._assert_filled(model.component3.linear.bias, 2.0)

        self._assert_filled(model.component4.linear.linear.weight, 1.0)
        self._assert_filled(model.component4.linear.linear.bias, 2.0)
        self._assert_filled(model.component4.conv1d.conv1d.weight, 3.0)
        self._assert_filled(model.component4.conv1d.conv1d.bias, 4.0)

        self._assert_filled(model.reg.weight, 1.0)
        self._assert_filled(model.reg.bias, 2.0)

    def test_dump_init_info(self):
        """Init info is dumped only when the current logger logs to a file."""
        # A fresh temp dir keeps the test independent of the working
        # directory and guarantees cleanup even if an assertion fails.
        with tempfile.TemporaryDirectory() as dump_dir:
            # Logger instances are looked up by name, so derive names from
            # the unique temp dir instead of reusing fixed ones that could
            # still be bound to a previous run's (deleted) log file.
            suffix = os.path.basename(dump_dir)

            # Current logger has no FileHandler: nothing is written.
            MMLogger.get_instance(f'logger1_{suffix}')
            model1 = build_from_cfg(self.model_cfg, FOOMODELS)
            model1.init_weights()
            assert len(os.listdir(dump_dir)) == 0

            # Current logger has a FileHandler: init info lands in the log.
            log_path = os.path.join(dump_dir, 'out.log')
            logger2 = MMLogger.get_instance(
                f'logger2_{suffix}', log_file=log_path)
            try:
                model2 = build_from_cfg(self.model_cfg, FOOMODELS)
                model2.init_weights()
                assert len(os.listdir(dump_dir)) == 1
                assert os.stat(log_path).st_size != 0
            finally:
                # Release the log file: avoids leaking the descriptor and
                # lets the temp dir be removed (open files block deletion
                # on Windows).
                for handler in list(logger2.handlers):
                    if isinstance(handler, FileHandler):
                        logger2.removeHandler(handler)
                        handler.close()