[Enhancement] Add ability to pass logger instance to frameworks (#2317)

* Add ability to pass logger instance to frameworks

* refine docstring

* Update mmcv/runner/hooks/logger/dvclive.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/2337/head
Ivan Shcheklein 2022-10-11 21:09:53 -07:00 committed by GitHub
parent dd5b415d2a
commit e417035f5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 5 deletions

View File

@ -22,7 +22,10 @@ class DvcliveLoggerHook(LoggerHook):
reset_flag (bool): Whether to clear the output buffer after logging.
Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
kwargs: Arguments for instantiating `Live`_.
dvclive (Live, optional): An instance of the `Live`_ logger to use
instead of initializing a new one internally. Defaults to None.
kwargs: Arguments for instantiating `Live`_ (ignored if `dvclive` is
provided).
.. _dvclive:
https://dvc.org/doc/dvclive
@ -37,18 +40,19 @@ class DvcliveLoggerHook(LoggerHook):
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True,
dvclive=None,
**kwargs):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.model_file = model_file
self.import_dvclive(**kwargs)
self._import_dvclive(dvclive, **kwargs)
def import_dvclive(self, **kwargs) -> None:
def _import_dvclive(self, dvclive=None, **kwargs) -> None:
try:
from dvclive import Live
except ImportError:
raise ImportError(
'Please run "pip install dvclive" to install dvclive')
self.dvclive = Live(**kwargs)
self.dvclive = dvclive if dvclive is not None else Live(**kwargs)
@master_only
def log(self, runner) -> None:

View File

@ -1665,7 +1665,6 @@ def test_dvclive_hook_model_file(tmp_path):
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
runner.register_hook(hook)
loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
loader = DataLoader(torch.ones((5, 2)))
runner.run([loader, loader], [('train', 1), ('val', 1)])
@ -1675,6 +1674,16 @@ def test_dvclive_hook_model_file(tmp_path):
shutil.rmtree(runner.work_dir)
def test_dvclive_hook_pass_logger(tmp_path):
sys.modules['dvclive'] = MagicMock()
from dvclive import Live
logger = Live()
sys.modules['dvclive'] = MagicMock()
assert DvcliveLoggerHook().dvclive is not logger
assert DvcliveLoggerHook(dvclive=logger).dvclive is logger
def test_clearml_hook():
sys.modules['clearml'] = MagicMock()
runner = _build_demo_runner()