[Fix] Fix weight initializing in test and refine registry logging. (#367)

* [Fix] Fix weight initializing and registry logging.

* sync params

* resolve comments
This commit is contained in:
RangiLyu 2022-07-19 18:28:57 +08:00 committed by GitHub
parent 3da66d1f87
commit 1241c21296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 29 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
from typing import Any, Optional, Union
from ..config import Config, ConfigDict
@ -116,7 +117,8 @@ def build_from_cfg(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore
logger='current')
logger='current',
level=logging.DEBUG)
return obj
except Exception as e:
@ -188,7 +190,8 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in'
f'{runner_cls.__module__}', # type: ignore
logger='current')
logger='current',
level=logging.DEBUG)
return runner
except Exception as e:

View File

@ -318,7 +318,7 @@ class Registry:
>>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
"""
# Avoid circular import
from ..logging.logger import MMLogger
from ..logging import print_log
scope, real_key = self.split_scope_key(key)
obj_cls = None
@ -356,10 +356,11 @@ class Registry:
obj_cls = root.get(key)
if obj_cls is not None:
logger: MMLogger = MMLogger.get_current_instance()
logger.info(
print_log(
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
f' registry in "{scope_name}"')
f' registry in "{scope_name}"',
logger='current',
level=logging.DEBUG)
return obj_cls
def _search_child(self, scope: str) -> Optional['Registry']:

View File

@ -797,8 +797,6 @@ class Runner:
elif isinstance(model, dict):
model = MODELS.build(model)
# init weights
if hasattr(model, 'init_weights'): # type: ignore
model.init_weights() # type: ignore
return model # type: ignore
else:
raise TypeError('model should be a nn.Module object or dict, '
@ -870,6 +868,17 @@ class Runner:
model_wrapper_cfg, default_args=default_args)
return model
def _init_model_weights(self) -> None:
"""Initialize the model weights if the model has
:meth:`init_weights`"""
model = self.model.module if is_model_wrapper(
self.model) else self.model
if hasattr(model, 'init_weights'):
model.init_weights()
# sync params and buffers
for name, params in model.state_dict().items():
broadcast(params)
def scale_lr(self,
optim_wrapper: OptimWrapper,
auto_scale_lr: Optional[Dict] = None) -> None:
@ -1606,10 +1615,11 @@ class Runner:
if self._val_loop is not None:
self._val_loop = self.build_val_loop(
self._val_loop) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
self.call_hook('before_run')
# TODO: add a contextmanager to avoid calling `before_run` many times
# initialize the model weights
self._init_model_weights()
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()

View File

@ -701,25 +701,6 @@ class TestRunner(TestCase):
model = runner.build_model(dict(type='ToyModel1'))
self.assertIsInstance(model, ToyModel1)
# test init weights
@MODELS.register_module()
class ToyModel2(ToyModel):
def __init__(self):
super().__init__()
self.initiailzed = False
def init_weights(self):
self.initiailzed = True
model = runner.build_model(dict(type='ToyModel2'))
self.assertTrue(model.initiailzed)
# test init weights with model object
_model = ToyModel2()
model = runner.build_model(_model)
self.assertFalse(model.initiailzed)
def test_wrap_model(self):
# revert sync batchnorm
cfg = copy.deepcopy(self.epoch_based_cfg)
@ -1390,6 +1371,25 @@ class TestRunner(TestCase):
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
# 7. test init weights
@MODELS.register_module()
class ToyModel2(ToyModel):
def __init__(self):
super().__init__()
self.initiailzed = False
def init_weights(self):
self.initiailzed = True
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train7'
runner = Runner.from_cfg(cfg)
model = ToyModel2()
runner.model = model
runner.train()
self.assertTrue(model.initiailzed)
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'