[Enhance] Support skipping initialization in `BaseModule` (#1263)
parent
6187595677
commit
3871881ef6
|
@ -59,6 +59,10 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
|
|||
def is_init(self):
|
||||
return self._is_init
|
||||
|
||||
@is_init.setter
|
||||
def is_init(self, value):
|
||||
self._is_init = value
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the weights."""
|
||||
|
||||
|
@ -127,7 +131,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
|
|||
for m in self.children():
|
||||
if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
|
||||
m = m.module
|
||||
if hasattr(m, 'init_weights'):
|
||||
if hasattr(m, 'init_weights') and not getattr(
|
||||
m, 'is_init', False):
|
||||
m.init_weights()
|
||||
# users may overload the `init_weights`
|
||||
update_init_info(
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -238,6 +238,30 @@ class TestBaseModule(TestCase):
|
|||
self.assertTrue((model.ddp.module.linear.weight == 1).all())
|
||||
self.assertTrue((model.ddp.module.linear.bias == 2).all())
|
||||
|
||||
# Test submodule.init_weights will be skipped if `is_init` is set
|
||||
# to True in root model
|
||||
model: FooModel = FOOMODELS.build(copy.deepcopy(self.model_cfg))
|
||||
for child in model.children():
|
||||
child.init_weights = Mock()
|
||||
model.is_init = True
|
||||
model.init_weights()
|
||||
for child in model.children():
|
||||
child.init_weights.assert_not_called()
|
||||
|
||||
# Test submodule.init_weights will be skipped if submodule's `is_init`
|
||||
# is set to True
|
||||
model: FooModel = FOOMODELS.build(copy.deepcopy(self.model_cfg))
|
||||
for child in model.children():
|
||||
child.init_weights = Mock()
|
||||
model.component1.is_init = True
|
||||
model.reg.is_init = True
|
||||
model.init_weights()
|
||||
model.component1.init_weights.assert_not_called()
|
||||
model.component2.init_weights.assert_called_once()
|
||||
model.component3.init_weights.assert_called_once()
|
||||
model.component4.init_weights.assert_called_once()
|
||||
model.reg.init_weights.assert_not_called()
|
||||
|
||||
def test_dump_init_info(self):
|
||||
import os
|
||||
import shutil
|
||||
|
|
Loading…
Reference in New Issue