`DvcliveLoggerHook` updates to work with `DVC` (#1208)

* Updates to work with DVC

* Update docstrings

* Updated test

* Updated DVCLiveLoggerHook

* Fix name

* Added missing next_step call

* Fix expected call

* Implicit next_step

* Suggestions from review

* Update test_hooks.py

* Updated to latest dvclive version

* Cleaned docstring

* Update mmcv/runner/hooks/logger/dvclive.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update dvclive.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/1599/head
David de la Iglesia Castro 2021-12-22 12:35:06 +01:00 committed by GitHub
parent fb486b96fd
commit ac92a1116f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 24 deletions

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@ -11,48 +13,62 @@ class DvcliveLoggerHook(LoggerHook):
It requires `dvclive`_ to be installed.
Args:
path (str): Directory where dvclive will write TSV log files.
model_file (str):
Default None.
If not None, after each epoch the model will
be saved to {model_file}.
interval (int): Logging interval (every k iterations).
Default 10.
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
Default: True.
reset_flag (bool): Whether to clear the output buffer after logging.
Default: True.
Default: False.
by_epoch (bool): Whether EpochBasedRunner is used.
Default: True.
kwargs:
Arguments for instantiating `Live`_
.. _dvclive:
https://dvc.org/doc/dvclive
.. _Live:
https://dvc.org/doc/dvclive/api-reference/live#parameters
"""
def __init__(self,
path,
model_file=None,
interval=10,
ignore_last=True,
reset_flag=True,
by_epoch=True):
reset_flag=False,
by_epoch=True,
**kwargs):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.model_file = model_file
self.import_dvclive(**kwargs)
super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
self.path = path
self.import_dvclive()
def import_dvclive(self):
def import_dvclive(self, **kwargs):
try:
import dvclive
from dvclive import Live
except ImportError:
raise ImportError(
'Please run "pip install dvclive" to install dvclive')
self.dvclive = dvclive
@master_only
def before_run(self, runner):
self.dvclive.init(self.path)
self.dvclive = Live(**kwargs)
@master_only
def log(self, runner):
tags = self.get_loggable_tags(runner)
if tags:
self.dvclive.set_step(self.get_iter(runner))
for k, v in tags.items():
self.dvclive.log(k, v, step=self.get_iter(runner))
self.dvclive.log(k, v)
@master_only
def after_train_epoch(self, runner):
super().after_train_epoch(runner)
if self.model_file is not None:
runner.save_checkpoint(
Path(self.model_file).parent,
filename_tmpl=Path(self.model_file).name,
create_symlink=False,
)

View File

@ -1226,21 +1226,37 @@ def test_neptune_hook():
hook.run.stop.assert_called_with()
def test_dvclive_hook(tmp_path):
def test_dvclive_hook():
sys.modules['dvclive'] = MagicMock()
runner = _build_demo_runner()
(tmp_path / 'dvclive').mkdir()
hook = DvcliveLoggerHook(str(tmp_path / 'dvclive'))
hook = DvcliveLoggerHook()
dvclive_mock = hook.dvclive
loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)
hook.dvclive.init.assert_called_with(str(tmp_path / 'dvclive'))
hook.dvclive.log.assert_called_with('momentum', 0.95, step=6)
hook.dvclive.log.assert_any_call('learning_rate', 0.02, step=6)
dvclive_mock.set_step.assert_called_with(6)
dvclive_mock.log.assert_called_with('momentum', 0.95)
def test_dvclive_hook_model_file(tmp_path):
sys.modules['dvclive'] = MagicMock()
runner = _build_demo_runner()
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
runner.register_hook(hook)
loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
loader = DataLoader(torch.ones((5, 2)))
runner.run([loader, loader], [('train', 1), ('val', 1)])
assert osp.exists(osp.join(runner.work_dir, 'model.pth'))
shutil.rmtree(runner.work_dir)
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',