Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00

[Fix] Fix weight initializing in test and refine registry logging. (#367)
* [Fix] Fix weight initializing and registry logging.
* sync params
* resolve comments
This commit is contained in:
parent: 3da66d1f87
commit: 1241c21296
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
+import logging
 from typing import Any, Optional, Union
 
 from ..config import Config, ConfigDict
@@ -116,7 +117,8 @@ def build_from_cfg(
             f'An `{obj_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
             'registry, its implementation can be found in '
             f'{obj_cls.__module__}',  # type: ignore
-            logger='current')
+            logger='current',
+            level=logging.DEBUG)
         return obj
 
     except Exception as e:
@@ -188,7 +190,8 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
             f'An `{runner_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
             'registry, its implementation can be found in'
             f'{runner_cls.__module__}',  # type: ignore
-            logger='current')
+            logger='current',
+            level=logging.DEBUG)
         return runner
 
     except Exception as e:
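
The two hunks above only change the severity of the build messages. For context (not part of the commit): mmengine's `print_log` routes a message through the `MMLogger` selected by `logger`, and `logger='current'` means the instance returned by `MMLogger.get_current_instance()`, so demoting the call to `logging.DEBUG` keeps the per-object build messages out of a default INFO-level training log. A minimal sketch of that behaviour, assuming mmengine is installed; the logger name 'demo' and the message strings are illustrative choices, not taken from the diff:

import logging

from mmengine.logging import MMLogger, print_log

# Create an MMLogger instance; name and level here are arbitrary.
MMLogger.get_instance('demo', log_level='INFO')

# After this commit the registry emits its build messages like this, at DEBUG
# level, so an INFO-level logger silently drops them.
print_log('build message demoted to DEBUG by this commit',
          logger='current', level=logging.DEBUG)

# Messages at INFO and above still reach the log as before.
print_log('this message is still visible', logger='current', level=logging.INFO)

Configuring the logger with log_level='DEBUG' (for example through the runner's log_level setting) brings the build and `Registry.get` messages back when they are needed for debugging.
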
@@ -318,7 +318,7 @@ class Registry:
             >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
         """
         # Avoid circular import
-        from ..logging.logger import MMLogger
+        from ..logging import print_log
 
         scope, real_key = self.split_scope_key(key)
         obj_cls = None
@@ -356,10 +356,11 @@ class Registry:
                 obj_cls = root.get(key)
 
         if obj_cls is not None:
-            logger: MMLogger = MMLogger.get_current_instance()
-            logger.info(
+            print_log(
                 f'Get class `{obj_cls.__name__}` from "{registry_name}"'
-                f' registry in "{scope_name}"')
+                f' registry in "{scope_name}"',
+                logger='current',
+                level=logging.DEBUG)
         return obj_cls
 
     def _search_child(self, scope: str) -> Optional['Registry']:
@@ -797,8 +797,6 @@ class Runner:
         elif isinstance(model, dict):
             model = MODELS.build(model)
             # init weights
-            if hasattr(model, 'init_weights'):  # type: ignore
-                model.init_weights()  # type: ignore
             return model  # type: ignore
         else:
             raise TypeError('model should be a nn.Module object or dict, '
@@ -870,6 +868,17 @@ class Runner:
                 model_wrapper_cfg, default_args=default_args)
         return model
 
+    def _init_model_weights(self) -> None:
+        """Initialize the model weights if the model has
+        :meth:`init_weights`"""
+        model = self.model.module if is_model_wrapper(
+            self.model) else self.model
+        if hasattr(model, 'init_weights'):
+            model.init_weights()
+            # sync params and buffers
+            for name, params in model.state_dict().items():
+                broadcast(params)
+
     def scale_lr(self,
                  optim_wrapper: OptimWrapper,
                  auto_scale_lr: Optional[Dict] = None) -> None:
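
A note on why the new `_init_model_weights` broadcasts the state dict (background, not stated in the diff itself): every process in a distributed job calls `init_weights()` locally, and random initialization is generally not identical across ranks, so broadcasting each parameter and buffer from rank 0 afterwards makes all ranks start from the same weights — this is the "sync params" item in the commit message. A small sketch of the same pattern using `mmengine.dist` directly; the toy module is an illustrative stand-in, and the explicit `is_distributed()` guard is only for clarity (mmengine's `broadcast` should already be a no-op in a non-distributed run):

import torch.nn as nn

from mmengine.dist import broadcast, is_distributed


class TinyNet(nn.Module):
    """Illustrative module; any model exposing `init_weights` works the same way."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def init_weights(self):
        # Rank-local random init: without a later sync, ranks would diverge.
        nn.init.normal_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)


model = TinyNet()
if hasattr(model, 'init_weights'):
    model.init_weights()
    # Same idea as Runner._init_model_weights: broadcast every tensor in the
    # state dict from rank 0 so all processes hold identical weights/buffers.
    if is_distributed():
        for _, tensor in model.state_dict().items():
            broadcast(tensor)
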
@@ -1606,10 +1615,11 @@ class Runner:
         if self._val_loop is not None:
             self._val_loop = self.build_val_loop(
                 self._val_loop)  # type: ignore
-
         # TODO: add a contextmanager to avoid calling `before_run` many times
         self.call_hook('before_run')
 
-        # TODO: add a contextmanager to avoid calling `before_run` many times
+        # initialize the model weights
+        self._init_model_weights()
         # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()
+
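
The placement in the hunk above also encodes an ordering constraint: weights are initialized after `before_run` fires but before `load_or_resume`, so a loaded or resumed checkpoint overwrites the fresh (and freshly synced) initialization rather than the other way around. A hypothetical, stripped-down outline of that order; the class and print statements are stand-ins, only the call sequence mirrors the diff:

class OutlineRunner:
    """Stand-in runner used only to illustrate the call order in train()."""

    def call_hook(self, name: str) -> None:
        print(f'hook: {name}')

    def _init_model_weights(self) -> None:
        print('init_weights() on every rank, then broadcast from rank 0')

    def load_or_resume(self) -> None:
        print('load checkpoint if configured; it overrides the fresh init')

    def train(self) -> None:
        self.call_hook('before_run')   # hooks run before any weight handling
        self._init_model_weights()     # random init + cross-rank sync
        self.load_or_resume()          # checkpoint (if any) wins over the init


OutlineRunner().train()
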
@@ -701,25 +701,6 @@ class TestRunner(TestCase):
         model = runner.build_model(dict(type='ToyModel1'))
         self.assertIsInstance(model, ToyModel1)
 
-        # test init weights
-        @MODELS.register_module()
-        class ToyModel2(ToyModel):
-
-            def __init__(self):
-                super().__init__()
-                self.initiailzed = False
-
-            def init_weights(self):
-                self.initiailzed = True
-
-        model = runner.build_model(dict(type='ToyModel2'))
-        self.assertTrue(model.initiailzed)
-
-        # test init weights with model object
-        _model = ToyModel2()
-        model = runner.build_model(_model)
-        self.assertFalse(model.initiailzed)
-
     def test_wrap_model(self):
         # revert sync batchnorm
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -1390,6 +1371,25 @@ class TestRunner(TestCase):
         for result, target, in zip(val_interval_results, val_interval_targets):
             self.assertEqual(result, target)
 
+        # 7. test init weights
+        @MODELS.register_module()
+        class ToyModel2(ToyModel):
+
+            def __init__(self):
+                super().__init__()
+                self.initiailzed = False
+
+            def init_weights(self):
+                self.initiailzed = True
+
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_train7'
+        runner = Runner.from_cfg(cfg)
+        model = ToyModel2()
+        runner.model = model
+        runner.train()
+        self.assertTrue(model.initiailzed)
+
     def test_val(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_val1'