mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix weight initializing in test and refine registry logging. (#367)
* [Fix] Fix weight initializing and registry logging. * sync params * resolve comments
This commit is contained in:
parent
3da66d1f87
commit
1241c21296
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from ..config import Config, ConfigDict
|
from ..config import Config, ConfigDict
|
||||||
@ -116,7 +117,8 @@ def build_from_cfg(
|
|||||||
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
||||||
'registry, its implementation can be found in '
|
'registry, its implementation can be found in '
|
||||||
f'{obj_cls.__module__}', # type: ignore
|
f'{obj_cls.__module__}', # type: ignore
|
||||||
logger='current')
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -188,7 +190,8 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
|||||||
f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
||||||
'registry, its implementation can be found in'
|
'registry, its implementation can be found in'
|
||||||
f'{runner_cls.__module__}', # type: ignore
|
f'{runner_cls.__module__}', # type: ignore
|
||||||
logger='current')
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -318,7 +318,7 @@ class Registry:
|
|||||||
>>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
|
>>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
|
||||||
"""
|
"""
|
||||||
# Avoid circular import
|
# Avoid circular import
|
||||||
from ..logging.logger import MMLogger
|
from ..logging import print_log
|
||||||
|
|
||||||
scope, real_key = self.split_scope_key(key)
|
scope, real_key = self.split_scope_key(key)
|
||||||
obj_cls = None
|
obj_cls = None
|
||||||
@ -356,10 +356,11 @@ class Registry:
|
|||||||
obj_cls = root.get(key)
|
obj_cls = root.get(key)
|
||||||
|
|
||||||
if obj_cls is not None:
|
if obj_cls is not None:
|
||||||
logger: MMLogger = MMLogger.get_current_instance()
|
print_log(
|
||||||
logger.info(
|
|
||||||
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
|
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
|
||||||
f' registry in "{scope_name}"')
|
f' registry in "{scope_name}"',
|
||||||
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
return obj_cls
|
return obj_cls
|
||||||
|
|
||||||
def _search_child(self, scope: str) -> Optional['Registry']:
|
def _search_child(self, scope: str) -> Optional['Registry']:
|
||||||
|
@ -797,8 +797,6 @@ class Runner:
|
|||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
model = MODELS.build(model)
|
model = MODELS.build(model)
|
||||||
# init weights
|
# init weights
|
||||||
if hasattr(model, 'init_weights'): # type: ignore
|
|
||||||
model.init_weights() # type: ignore
|
|
||||||
return model # type: ignore
|
return model # type: ignore
|
||||||
else:
|
else:
|
||||||
raise TypeError('model should be a nn.Module object or dict, '
|
raise TypeError('model should be a nn.Module object or dict, '
|
||||||
@ -870,6 +868,17 @@ class Runner:
|
|||||||
model_wrapper_cfg, default_args=default_args)
|
model_wrapper_cfg, default_args=default_args)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _init_model_weights(self) -> None:
|
||||||
|
"""Initialize the model weights if the model has
|
||||||
|
:meth:`init_weights`"""
|
||||||
|
model = self.model.module if is_model_wrapper(
|
||||||
|
self.model) else self.model
|
||||||
|
if hasattr(model, 'init_weights'):
|
||||||
|
model.init_weights()
|
||||||
|
# sync params and buffers
|
||||||
|
for name, params in model.state_dict().items():
|
||||||
|
broadcast(params)
|
||||||
|
|
||||||
def scale_lr(self,
|
def scale_lr(self,
|
||||||
optim_wrapper: OptimWrapper,
|
optim_wrapper: OptimWrapper,
|
||||||
auto_scale_lr: Optional[Dict] = None) -> None:
|
auto_scale_lr: Optional[Dict] = None) -> None:
|
||||||
@ -1606,10 +1615,11 @@ class Runner:
|
|||||||
if self._val_loop is not None:
|
if self._val_loop is not None:
|
||||||
self._val_loop = self.build_val_loop(
|
self._val_loop = self.build_val_loop(
|
||||||
self._val_loop) # type: ignore
|
self._val_loop) # type: ignore
|
||||||
|
# TODO: add a contextmanager to avoid calling `before_run` many times
|
||||||
self.call_hook('before_run')
|
self.call_hook('before_run')
|
||||||
|
|
||||||
# TODO: add a contextmanager to avoid calling `before_run` many times
|
# initialize the model weights
|
||||||
|
self._init_model_weights()
|
||||||
# make sure checkpoint-related hooks are triggered after `before_run`
|
# make sure checkpoint-related hooks are triggered after `before_run`
|
||||||
self.load_or_resume()
|
self.load_or_resume()
|
||||||
|
|
||||||
|
@ -701,25 +701,6 @@ class TestRunner(TestCase):
|
|||||||
model = runner.build_model(dict(type='ToyModel1'))
|
model = runner.build_model(dict(type='ToyModel1'))
|
||||||
self.assertIsInstance(model, ToyModel1)
|
self.assertIsInstance(model, ToyModel1)
|
||||||
|
|
||||||
# test init weights
|
|
||||||
@MODELS.register_module()
|
|
||||||
class ToyModel2(ToyModel):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.initiailzed = False
|
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
self.initiailzed = True
|
|
||||||
|
|
||||||
model = runner.build_model(dict(type='ToyModel2'))
|
|
||||||
self.assertTrue(model.initiailzed)
|
|
||||||
|
|
||||||
# test init weights with model object
|
|
||||||
_model = ToyModel2()
|
|
||||||
model = runner.build_model(_model)
|
|
||||||
self.assertFalse(model.initiailzed)
|
|
||||||
|
|
||||||
def test_wrap_model(self):
|
def test_wrap_model(self):
|
||||||
# revert sync batchnorm
|
# revert sync batchnorm
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
@ -1390,6 +1371,25 @@ class TestRunner(TestCase):
|
|||||||
for result, target, in zip(val_interval_results, val_interval_targets):
|
for result, target, in zip(val_interval_results, val_interval_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
|
# 7. test init weights
|
||||||
|
@MODELS.register_module()
|
||||||
|
class ToyModel2(ToyModel):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.initiailzed = False
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.initiailzed = True
|
||||||
|
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_train7'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
model = ToyModel2()
|
||||||
|
runner.model = model
|
||||||
|
runner.train()
|
||||||
|
self.assertTrue(model.initiailzed)
|
||||||
|
|
||||||
def test_val(self):
|
def test_val(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_val1'
|
cfg.experiment_name = 'test_val1'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user