# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import tempfile
from unittest import TestCase

import torch
from torch import nn
from torch.nn.init import constant_

from mmengine.logging.logger import MMLogger
from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
from mmengine.registry import Registry, build_from_cfg

COMPONENTS = Registry('component')
FOOMODELS = Registry('model')

Logger = MMLogger.get_current_instance()


@COMPONENTS.register_module()
class FooConv1d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)


@COMPONENTS.register_module()
class FooConv2d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv2d = nn.Conv2d(3, 1, 3)

    def forward(self, x):
        return self.conv2d(x)


@COMPONENTS.register_module()
class FooLinear(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)


@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super().__init__(init_cfg)
        if linear is not None:
            self.linear = build_from_cfg(linear, COMPONENTS)
        if conv1d is not None:
            self.conv1d = build_from_cfg(conv1d, COMPONENTS)

    def forward(self, x):
        x = self.linear(x)
        return self.conv1d(x)


@FOOMODELS.register_module()
class FooModel(BaseModule):

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        if component1 is not None:
            self.component1 = build_from_cfg(component1, COMPONENTS)
        if component2 is not None:
            self.component2 = build_from_cfg(component2, COMPONENTS)
        if component3 is not None:
            self.component3 = build_from_cfg(component3, COMPONENTS)
        if component4 is not None:
            self.component4 = build_from_cfg(component4, COMPONENTS)

        # ``reg`` is a plain ``nn.Linear`` rather than a ``BaseModule``, so
        # it can be initialized via the ``override`` key of ``init_cfg``.
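        # A minimal sketch of such an override (hypothetical values, not
        # exercised in this test file):
        #   init_cfg=dict(type='Constant', val=1, bias=2, layer='Linear',
        #                 override=dict(name='reg', val=3, bias=4))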
        self.reg = nn.Linear(3, 4)


class TestBaseModule(TestCase):

    def setUp(self) -> None:
        self.temp_dir = tempfile.TemporaryDirectory()
        self.BaseModule = BaseModule()
        self.model_cfg = dict(
            type='FooModel',
            init_cfg=[
                dict(type='Constant', val=1, bias=2, layer='Linear'),
                dict(type='Constant', val=3, bias=4, layer='Conv1d'),
                dict(type='Constant', val=5, bias=6, layer='Conv2d')
            ],
            component1=dict(type='FooConv1d'),
            component2=dict(type='FooConv2d'),
            component3=dict(type='FooLinear'),
            component4=dict(
                type='FooLinearConv1d',
                linear=dict(type='FooLinear'),
                conv1d=dict(type='FooConv1d')))

        self.model = build_from_cfg(self.model_cfg, FOOMODELS)
        self.logger = MMLogger.get_instance(self._testMethodName)

    def tearDown(self) -> None:
        self.temp_dir.cleanup()
        logging.shutdown()
        MMLogger._instance_dict.clear()
        return super().tearDown()

    def test_is_init(self):
        assert self.BaseModule.is_init is False

    def test_init_weights(self):
        """
        Config
        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
               Conv2d: weight=5, bias=6)
        ├──component1 (FooConv1d)
        ├──component2 (FooConv2d)
        ├──component3 (FooLinear)
        ├──component4 (FooLinearConv1d)
            ├──linear (FooLinear)
            ├──conv1d (FooConv1d)
        ├──reg (nn.Linear)

        Parameters after initialization
        model (FooModel)
        ├──component1 (FooConv1d, weight=3, bias=4)
        ├──component2 (FooConv2d, weight=5, bias=6)
        ├──component3 (FooLinear, weight=1, bias=2)
        ├──component4 (FooLinearConv1d)
            ├──linear (FooLinear, weight=1, bias=2)
            ├──conv1d (FooConv1d, weight=3, bias=4)
        ├──reg (nn.Linear, weight=1, bias=2)
        """
        self.model.init_weights()
        assert torch.equal(
            self.model.component1.conv1d.weight,
            torch.full(self.model.component1.conv1d.weight.shape, 3.0))
        assert torch.equal(
            self.model.component1.conv1d.bias,
            torch.full(self.model.component1.conv1d.bias.shape, 4.0))
        assert torch.equal(
            self.model.component2.conv2d.weight,
            torch.full(self.model.component2.conv2d.weight.shape, 5.0))
        assert torch.equal(
            self.model.component2.conv2d.bias,
            torch.full(self.model.component2.conv2d.bias.shape, 6.0))
        assert torch.equal(
            self.model.component3.linear.weight,
            torch.full(self.model.component3.linear.weight.shape, 1.0))
        assert torch.equal(
            self.model.component3.linear.bias,
            torch.full(self.model.component3.linear.bias.shape, 2.0))
        assert torch.equal(
            self.model.component4.linear.linear.weight,
            torch.full(self.model.component4.linear.linear.weight.shape, 1.0))
        assert torch.equal(
            self.model.component4.linear.linear.bias,
            torch.full(self.model.component4.linear.linear.bias.shape, 2.0))
        assert torch.equal(
            self.model.component4.conv1d.conv1d.weight,
            torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0))
        assert torch.equal(
            self.model.component4.conv1d.conv1d.bias,
            torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0))
        assert torch.equal(self.model.reg.weight,
                           torch.full(self.model.reg.weight.shape, 1.0))
        assert torch.equal(self.model.reg.bias,
                           torch.full(self.model.reg.bias.shape, 2.0))

        # Test building a model from pretrained weights

        class CustomLinear(BaseModule):

            def __init__(self, init_cfg=None):
                super().__init__(init_cfg)
                self.linear = nn.Linear(1, 1)

            def init_weights(self):
                constant_(self.linear.weight, 1)
                constant_(self.linear.bias, 2)

        @FOOMODELS.register_module()
        class PretrainedModel(FooModel):

            def __init__(self,
                         component1=None,
                         component2=None,
                         component3=None,
                         component4=None,
                         init_cfg=None) -> None:
                super().__init__(component1, component2, component3,
                                 component4, init_cfg)
                self.linear = CustomLinear()

        checkpoint_path = osp.join(self.temp_dir.name, 'test.pth')
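        # Persist the constant-initialized weights of ``self.model`` so the
        # ``Pretrained`` init_cfg below can load them as a checkpoint.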
        torch.save(self.model.state_dict(), checkpoint_path)
        model_cfg = copy.deepcopy(self.model_cfg)
        model_cfg['type'] = 'PretrainedModel'
        model_cfg['init_cfg'] = dict(
            type='Pretrained', checkpoint=checkpoint_path)
        model = FOOMODELS.build(model_cfg)
        ori_layer_weight = model.linear.linear.weight.clone()
        ori_layer_bias = model.linear.linear.bias.clone()
        model.init_weights()

        self.assertTrue(
            (ori_layer_weight != model.linear.linear.weight).any())
        self.assertTrue((ori_layer_bias != model.linear.linear.bias).any())

    def test_dump_init_info(self):
        import os
        import shutil
        dump_dir = 'tests/test_model/test_dump_info'
        if not (os.path.exists(dump_dir) and os.path.isdir(dump_dir)):
            os.makedirs(dump_dir)
        for filename in os.listdir(dump_dir):
            file_path = os.path.join(dump_dir, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)

        MMLogger.get_instance('logger1')  # add a logger without a FileHandler
        model1 = build_from_cfg(self.model_cfg, FOOMODELS)
        model1.init_weights()
        assert len(os.listdir(dump_dir)) == 0
        log_path = os.path.join(dump_dir, 'out.log')
        # add a logger with a FileHandler
        MMLogger.get_instance('logger2', log_file=log_path)
        model2 = build_from_cfg(self.model_cfg, FOOMODELS)
        model2.init_weights()
        assert len(os.listdir(dump_dir)) == 1
        assert os.stat(log_path).st_size != 0
        # `FileHandler` should be closed on Windows, otherwise the temporary
        # directory cannot be deleted
        logging.shutdown()
        MMLogger._instance_dict.clear()
        shutil.rmtree(dump_dir)


class TestModuleList(TestCase):

    def test_modulelist_weight_init(self):
        models_cfg = [
            dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        ]
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
        modellist = ModuleList(layers)
        modellist.init_weights()
        self.assertTrue(
            torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(modellist[0].conv1d.bias,
                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.weight,
                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
        # The inner init_cfg has higher priority, so the outer one is ignored
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
        modellist = ModuleList(
            layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        modellist.init_weights()
        self.assertTrue(
            torch.equal(modellist[0].conv1d.weight,
                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(modellist[0].conv1d.bias,
                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.weight,
                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(modellist[1].conv2d.bias,
                        torch.full(modellist[1].conv2d.bias.shape, 3.)))


class TestModuleDict(TestCase):

    def test_moduledict_weight_init(self):
        models_cfg = dict(
            foo_conv_1d=dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            foo_conv_2d=dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        )
        layers = {
            name: build_from_cfg(cfg, COMPONENTS)
            for name, cfg in models_cfg.items()
        }
        modeldict = ModuleDict(layers)
        modeldict.init_weights()
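        # ``ModuleDict.init_weights`` recurses into each child BaseModule,
        # so every component is expected to follow its own ``init_cfg``.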
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.weight,
                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.bias,
                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.weight,
                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.bias,
                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
        # The inner init_cfg has higher priority, so the outer one is ignored
        layers = {
            name: build_from_cfg(cfg, COMPONENTS)
            for name, cfg in models_cfg.items()
        }
        modeldict = ModuleDict(
            layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        modeldict.init_weights()
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.weight,
                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_1d'].conv1d.bias,
                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.weight,
                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(
                modeldict['foo_conv_2d'].conv2d.bias,
                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))


class TestSequential(TestCase):

    def test_sequential_model_weight_init(self):
        seq_model_cfg = [
            dict(
                type='FooConv1d',
                init_cfg=dict(
                    type='Constant', layer='Conv1d', val=0., bias=1.)),
            dict(
                type='FooConv2d',
                init_cfg=dict(
                    type='Constant', layer='Conv2d', val=2., bias=3.)),
        ]
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
        seq_model = Sequential(*layers)
        seq_model.init_weights()
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.bias,
                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.weight,
                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
        # The inner init_cfg has higher priority, so the outer one is ignored
        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
        seq_model = Sequential(
            *layers,
            init_cfg=dict(
                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
        seq_model.init_weights()
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.weight,
                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
        self.assertTrue(
            torch.equal(seq_model[0].conv1d.bias,
                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.weight,
                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
        self.assertTrue(
            torch.equal(seq_model[1].conv2d.bias,
                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))