mirror of https://github.com/open-mmlab/mmcv.git
`DvcliveLoggerHook` updates to work with `DVC` (#1208)
* Updates to work with DVC * Update docstrings * Updated test * Updated DVCLiveLoggerHook * Fix name * Added missing next_step call * Fix expected call * Implicit next_step * Suggestions from review * Update test_hooks.py * Updated to last dvclive version * Cleaned docstring * Update mmcv/runner/hooks/logger/dvclive.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update dvclive.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/1599/head
parent
fb486b96fd
commit
ac92a1116f
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from pathlib import Path
|
||||
|
||||
from ...dist_utils import master_only
|
||||
from ..hook import HOOKS
|
||||
from .base import LoggerHook
|
||||
|
@ -11,48 +13,62 @@ class DvcliveLoggerHook(LoggerHook):
|
|||
It requires `dvclive`_ to be installed.
|
||||
|
||||
Args:
|
||||
path (str): Directory where dvclive will write TSV log files.
|
||||
model_file (str):
|
||||
Default None.
|
||||
If not None, after each epoch the model will
|
||||
be saved to {model_file}.
|
||||
interval (int): Logging interval (every k iterations).
|
||||
Default 10.
|
||||
ignore_last (bool): Ignore the log of last iterations in each epoch
|
||||
if less than `interval`.
|
||||
Default: True.
|
||||
reset_flag (bool): Whether to clear the output buffer after logging.
|
||||
Default: True.
|
||||
Default: False.
|
||||
by_epoch (bool): Whether EpochBasedRunner is used.
|
||||
Default: True.
|
||||
kwargs:
|
||||
Arguments for instantiating `Live`_
|
||||
|
||||
.. _dvclive:
|
||||
https://dvc.org/doc/dvclive
|
||||
|
||||
.. _Live:
|
||||
https://dvc.org/doc/dvclive/api-reference/live#parameters
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
path,
|
||||
model_file=None,
|
||||
interval=10,
|
||||
ignore_last=True,
|
||||
reset_flag=True,
|
||||
by_epoch=True):
|
||||
reset_flag=False,
|
||||
by_epoch=True,
|
||||
**kwargs):
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.model_file = model_file
|
||||
self.import_dvclive(**kwargs)
|
||||
|
||||
super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
|
||||
reset_flag, by_epoch)
|
||||
self.path = path
|
||||
self.import_dvclive()
|
||||
|
||||
def import_dvclive(self):
|
||||
def import_dvclive(self, **kwargs):
|
||||
try:
|
||||
import dvclive
|
||||
from dvclive import Live
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install dvclive" to install dvclive')
|
||||
self.dvclive = dvclive
|
||||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
self.dvclive.init(self.path)
|
||||
self.dvclive = Live(**kwargs)
|
||||
|
||||
@master_only
|
||||
def log(self, runner):
|
||||
tags = self.get_loggable_tags(runner)
|
||||
if tags:
|
||||
self.dvclive.set_step(self.get_iter(runner))
|
||||
for k, v in tags.items():
|
||||
self.dvclive.log(k, v, step=self.get_iter(runner))
|
||||
self.dvclive.log(k, v)
|
||||
|
||||
@master_only
|
||||
def after_train_epoch(self, runner):
|
||||
super().after_train_epoch(runner)
|
||||
if self.model_file is not None:
|
||||
runner.save_checkpoint(
|
||||
Path(self.model_file).parent,
|
||||
filename_tmpl=Path(self.model_file).name,
|
||||
create_symlink=False,
|
||||
)
|
||||
|
|
|
@ -1226,21 +1226,37 @@ def test_neptune_hook():
|
|||
hook.run.stop.assert_called_with()
|
||||
|
||||
|
||||
def test_dvclive_hook(tmp_path):
|
||||
def test_dvclive_hook():
|
||||
sys.modules['dvclive'] = MagicMock()
|
||||
runner = _build_demo_runner()
|
||||
|
||||
(tmp_path / 'dvclive').mkdir()
|
||||
hook = DvcliveLoggerHook(str(tmp_path / 'dvclive'))
|
||||
hook = DvcliveLoggerHook()
|
||||
dvclive_mock = hook.dvclive
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)])
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
hook.dvclive.init.assert_called_with(str(tmp_path / 'dvclive'))
|
||||
hook.dvclive.log.assert_called_with('momentum', 0.95, step=6)
|
||||
hook.dvclive.log.assert_any_call('learning_rate', 0.02, step=6)
|
||||
dvclive_mock.set_step.assert_called_with(6)
|
||||
dvclive_mock.log.assert_called_with('momentum', 0.95)
|
||||
|
||||
|
||||
def test_dvclive_hook_model_file(tmp_path):
|
||||
sys.modules['dvclive'] = MagicMock()
|
||||
runner = _build_demo_runner()
|
||||
|
||||
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
|
||||
runner.register_hook(hook)
|
||||
|
||||
loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)])
|
||||
|
||||
assert osp.exists(osp.join(runner.work_dir, 'model.pth'))
|
||||
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
|
||||
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
|
||||
|
|
Loading…
Reference in New Issue