[Fix] Fix weight initializing in test and refine registry logging. (#367)

* [Fix] Fix weight initializing and registry logging.

* sync params

* resolve comments
This commit is contained in:
RangiLyu 2022-07-19 18:28:57 +08:00 committed by GitHub
parent 3da66d1f87
commit 1241c21296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 29 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect import inspect
import logging
from typing import Any, Optional, Union from typing import Any, Optional, Union
from ..config import Config, ConfigDict from ..config import Config, ConfigDict
@ -116,7 +117,8 @@ def build_from_cfg(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in ' 'registry, its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore f'{obj_cls.__module__}', # type: ignore
logger='current') logger='current',
level=logging.DEBUG)
return obj return obj
except Exception as e: except Exception as e:
@ -188,7 +190,8 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in' 'registry, its implementation can be found in'
f'{runner_cls.__module__}', # type: ignore f'{runner_cls.__module__}', # type: ignore
logger='current') logger='current',
level=logging.DEBUG)
return runner return runner
except Exception as e: except Exception as e:

View File

@ -318,7 +318,7 @@ class Registry:
>>> mobilenet_cls = DETECTORS.get('cls.MobileNet') >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
""" """
# Avoid circular import # Avoid circular import
from ..logging.logger import MMLogger from ..logging import print_log
scope, real_key = self.split_scope_key(key) scope, real_key = self.split_scope_key(key)
obj_cls = None obj_cls = None
@ -356,10 +356,11 @@ class Registry:
obj_cls = root.get(key) obj_cls = root.get(key)
if obj_cls is not None: if obj_cls is not None:
logger: MMLogger = MMLogger.get_current_instance() print_log(
logger.info(
f'Get class `{obj_cls.__name__}` from "{registry_name}"' f'Get class `{obj_cls.__name__}` from "{registry_name}"'
f' registry in "{scope_name}"') f' registry in "{scope_name}"',
logger='current',
level=logging.DEBUG)
return obj_cls return obj_cls
def _search_child(self, scope: str) -> Optional['Registry']: def _search_child(self, scope: str) -> Optional['Registry']:

View File

@ -797,8 +797,6 @@ class Runner:
elif isinstance(model, dict): elif isinstance(model, dict):
model = MODELS.build(model) model = MODELS.build(model)
# init weights # init weights
if hasattr(model, 'init_weights'): # type: ignore
model.init_weights() # type: ignore
return model # type: ignore return model # type: ignore
else: else:
raise TypeError('model should be a nn.Module object or dict, ' raise TypeError('model should be a nn.Module object or dict, '
@ -870,6 +868,17 @@ class Runner:
model_wrapper_cfg, default_args=default_args) model_wrapper_cfg, default_args=default_args)
return model return model
def _init_model_weights(self) -> None:
"""Initialize the model weights if the model has
:meth:`init_weights`"""
model = self.model.module if is_model_wrapper(
self.model) else self.model
if hasattr(model, 'init_weights'):
model.init_weights()
# sync params and buffers
for name, params in model.state_dict().items():
broadcast(params)
def scale_lr(self, def scale_lr(self,
optim_wrapper: OptimWrapper, optim_wrapper: OptimWrapper,
auto_scale_lr: Optional[Dict] = None) -> None: auto_scale_lr: Optional[Dict] = None) -> None:
@ -1606,10 +1615,11 @@ class Runner:
if self._val_loop is not None: if self._val_loop is not None:
self._val_loop = self.build_val_loop( self._val_loop = self.build_val_loop(
self._val_loop) # type: ignore self._val_loop) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
self.call_hook('before_run') self.call_hook('before_run')
# TODO: add a contextmanager to avoid calling `before_run` many times # initialize the model weights
self._init_model_weights()
# make sure checkpoint-related hooks are triggered after `before_run` # make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume() self.load_or_resume()

View File

@ -701,25 +701,6 @@ class TestRunner(TestCase):
model = runner.build_model(dict(type='ToyModel1')) model = runner.build_model(dict(type='ToyModel1'))
self.assertIsInstance(model, ToyModel1) self.assertIsInstance(model, ToyModel1)
# test init weights
@MODELS.register_module()
class ToyModel2(ToyModel):
def __init__(self):
super().__init__()
self.initiailzed = False
def init_weights(self):
self.initiailzed = True
model = runner.build_model(dict(type='ToyModel2'))
self.assertTrue(model.initiailzed)
# test init weights with model object
_model = ToyModel2()
model = runner.build_model(_model)
self.assertFalse(model.initiailzed)
def test_wrap_model(self): def test_wrap_model(self):
# revert sync batchnorm # revert sync batchnorm
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
@ -1390,6 +1371,25 @@ class TestRunner(TestCase):
for result, target, in zip(val_interval_results, val_interval_targets): for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
# 7. test init weights
@MODELS.register_module()
class ToyModel2(ToyModel):
def __init__(self):
super().__init__()
self.initiailzed = False
def init_weights(self):
self.initiailzed = True
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train7'
runner = Runner.from_cfg(cfg)
model = ToyModel2()
runner.model = model
runner.train()
self.assertTrue(model.initiailzed)
def test_val(self): def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1' cfg.experiment_name = 'test_val1'