diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index a833570e..9d364b34 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -658,6 +658,10 @@ class Runner: def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: """Build model. + If ``model`` is a dict, it will be used to build a nn.Module object + and initialize the weights if it has ``init_weights`` method. + Else, if ``model`` is a nn.Module object it will be returned directly. + An example of ``model``:: model = dict(type='ResNet') @@ -673,7 +677,11 @@ class Runner: if isinstance(model, nn.Module): return model elif isinstance(model, dict): - return MODELS.build(model) + model = MODELS.build(model) + # init weights + if hasattr(model, 'init_weights'): + model.init_weights() + return model else: raise TypeError('model should be a nn.Module object or dict, ' f'but got {model}') diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 25fbdbcf..0fdd8553 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -511,6 +511,25 @@ class TestRunner(TestCase): model = runner.build_model(dict(type='ToyModel1')) self.assertIsInstance(model, ToyModel1) + # test init weights + @MODELS.register_module() + class ToyModel2(ToyModel): + + def __init__(self): + super().__init__() + self.initiailzed = False + + def init_weights(self): + self.initiailzed = True + + model = runner.build_model(dict(type='ToyModel2')) + self.assertTrue(model.initiailzed) + + # test init weights with model object + _model = ToyModel2() + model = runner.build_model(_model) + self.assertFalse(model.initiailzed) + def test_wrap_model(self): # TODO: test on distributed environment # custom model wrapper