From 1241c21296fda095afb309e465c43f9dfd10b573 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Tue, 19 Jul 2022 18:28:57 +0800
Subject: [PATCH] [Fix] Fix weight initializing in test and refine registry
 logging. (#367)

* [Fix] Fix weight initializing and registry logging.

* sync params

* resolve comments
---
 mmengine/registry/build_functions.py |  7 +++--
 mmengine/registry/registry.py        |  9 ++++---
 mmengine/runner/runner.py            | 18 ++++++++++---
 tests/test_runner/test_runner.py     | 38 ++++++++++++++--------------
 4 files changed, 43 insertions(+), 29 deletions(-)

diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py
index f2857e0a..0f4dbb96 100644
--- a/mmengine/registry/build_functions.py
+++ b/mmengine/registry/build_functions.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
+import logging
 from typing import Any, Optional, Union
 
 from ..config import Config, ConfigDict
@@ -116,7 +117,8 @@ def build_from_cfg(
             f'An `{obj_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
             'registry, its implementation can be found in '
             f'{obj_cls.__module__}',  # type: ignore
-            logger='current')
+            logger='current',
+            level=logging.DEBUG)
         return obj
 
     except Exception as e:
@@ -188,7 +190,8 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
             f'An `{runner_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
             'registry, its implementation can be found in'
             f'{runner_cls.__module__}',  # type: ignore
-            logger='current')
+            logger='current',
+            level=logging.DEBUG)
         return runner
 
     except Exception as e:
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index 6b42fef4..351aabee 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -318,7 +318,7 @@ class Registry:
             >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
         """
         # Avoid circular import
-        from ..logging.logger import MMLogger
+        from ..logging import print_log
 
         scope, real_key = self.split_scope_key(key)
         obj_cls = None
@@ -356,10 +356,11 @@ class Registry:
                 obj_cls = root.get(key)
 
         if obj_cls is not None:
-            logger: MMLogger = MMLogger.get_current_instance()
-            logger.info(
+            print_log(
                 f'Get class `{obj_cls.__name__}` from "{registry_name}"'
-                f' registry in "{scope_name}"')
+                f' registry in "{scope_name}"',
+                logger='current',
+                level=logging.DEBUG)
         return obj_cls
 
     def _search_child(self, scope: str) -> Optional['Registry']:
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index e3f74887..4ae1c086 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -797,8 +797,6 @@ class Runner:
         elif isinstance(model, dict):
             model = MODELS.build(model)
             # init weights
-            if hasattr(model, 'init_weights'):  # type: ignore
-                model.init_weights()  # type: ignore
             return model  # type: ignore
         else:
             raise TypeError('model should be a nn.Module object or dict, '
@@ -870,6 +868,17 @@ class Runner:
             model_wrapper_cfg, default_args=default_args)
         return model
 
+    def _init_model_weights(self) -> None:
+        """Initialize the model weights if the model has
+        :meth:`init_weights`"""
+        model = self.model.module if is_model_wrapper(
+            self.model) else self.model
+        if hasattr(model, 'init_weights'):
+            model.init_weights()
+            # sync params and buffers
+            for name, params in model.state_dict().items():
+                broadcast(params)
+
     def scale_lr(self,
                  optim_wrapper: OptimWrapper,
                  auto_scale_lr: Optional[Dict] = None) -> None:
@@ -1606,10 +1615,11 @@ class Runner:
         if self._val_loop is not None:
             self._val_loop = self.build_val_loop(
                 self._val_loop)  # type: ignore
-
+        # TODO: add a contextmanager to avoid calling `before_run` many times
         self.call_hook('before_run')
-        # TODO: add a contextmanager to avoid calling `before_run` many times
 
+        # initialize the model weights
+        self._init_model_weights()
         # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()
 
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 2dcc4011..6ede909f 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -701,25 +701,6 @@ class TestRunner(TestCase):
         model = runner.build_model(dict(type='ToyModel1'))
         self.assertIsInstance(model, ToyModel1)
 
-        # test init weights
-        @MODELS.register_module()
-        class ToyModel2(ToyModel):
-
-            def __init__(self):
-                super().__init__()
-                self.initiailzed = False
-
-            def init_weights(self):
-                self.initiailzed = True
-
-        model = runner.build_model(dict(type='ToyModel2'))
-        self.assertTrue(model.initiailzed)
-
-        # test init weights with model object
-        _model = ToyModel2()
-        model = runner.build_model(_model)
-        self.assertFalse(model.initiailzed)
-
     def test_wrap_model(self):
         # revert sync batchnorm
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -1390,6 +1371,25 @@ class TestRunner(TestCase):
             for result, target, in zip(val_interval_results,
                                        val_interval_targets):
                 self.assertEqual(result, target)
 
+        # 7. test init weights
+        @MODELS.register_module()
+        class ToyModel2(ToyModel):
+
+            def __init__(self):
+                super().__init__()
+                self.initiailzed = False
+
+            def init_weights(self):
+                self.initiailzed = True
+
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_train7'
+        runner = Runner.from_cfg(cfg)
+        model = ToyModel2()
+        runner.model = model
+        runner.train()
+        self.assertTrue(model.initiailzed)
+
     def test_val(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_val1'
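
Note for reviewers, outside the patch itself: after this change `Runner.build_model` no
longer calls `init_weights`; initialization now happens once at the start of
`Runner.train` via the new `_init_model_weights`, which then broadcasts every tensor in
`state_dict()` from rank 0 so all processes begin training with identical weights (the
"sync params" commit). Below is a minimal single-process sketch of that init-then-sync
step; `ToyLinear` is a made-up stand-in for any model exposing `init_weights`, and
`mmengine.dist.broadcast` is a no-op when torch.distributed is not initialized, so this
runs as-is:

import torch.nn as nn

from mmengine.dist import broadcast


class ToyLinear(nn.Module):
    """Hypothetical model; any nn.Module with init_weights behaves the same."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)

    def init_weights(self):
        # deterministic here, but init_weights is often random in practice,
        # which is why the broadcast below matters in distributed runs
        nn.init.constant_(self.fc.weight, 1.0)


model = ToyLinear()
# Mirrors Runner._init_model_weights: init locally, then sync across ranks.
if hasattr(model, 'init_weights'):
    model.init_weights()
    # broadcast rank 0's tensors in place; state_dict() shares storage with
    # the parameters, so this updates the model directly
    for params in model.state_dict().values():
        broadcast(params)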