From ac92a1116f027c4f8d019ace4fd3fb29d1c4223b Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Wed, 22 Dec 2021 12:35:06 +0100 Subject: [PATCH] `DvcliveLoggerHook` updates to work with `DVC` (#1208) * Updates to work with DVC * Update docstrings * Updated test * Updated DVCLiveLoggerHook * Fix name * Added missing next_step call * Fix expected call * Implicit next_step * Suggestions from review * Update test_hooks.py * Updated to last dvclive version * Cleaned docstring * Update mmcv/runner/hooks/logger/dvclive.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update dvclive.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmcv/runner/hooks/logger/dvclive.py | 52 +++++++++++++++++++---------- tests/test_runner/test_hooks.py | 28 ++++++++++++---- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/mmcv/runner/hooks/logger/dvclive.py b/mmcv/runner/hooks/logger/dvclive.py index 687cdc58c..536b1dd68 100644 --- a/mmcv/runner/hooks/logger/dvclive.py +++ b/mmcv/runner/hooks/logger/dvclive.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from pathlib import Path + from ...dist_utils import master_only from ..hook import HOOKS from .base import LoggerHook @@ -11,48 +13,62 @@ class DvcliveLoggerHook(LoggerHook): It requires `dvclive`_ to be installed. Args: - path (str): Directory where dvclive will write TSV log files. + model_file (str): + Default None. + If not None, after each epoch the model will + be saved to {model_file}. interval (int): Logging interval (every k iterations). Default 10. ignore_last (bool): Ignore the log of last iterations in each epoch if less than `interval`. Default: True. reset_flag (bool): Whether to clear the output buffer after logging. - Default: True. + Default: False. by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + kwargs: + Arguments for instantiating `Live`_ .. _dvclive: https://dvc.org/doc/dvclive + + .. _Live: + https://dvc.org/doc/dvclive/api-reference/live#parameters """ def __init__(self, - path, + model_file=None, interval=10, ignore_last=True, - reset_flag=True, - by_epoch=True): + reset_flag=False, + by_epoch=True, + **kwargs): + super().__init__(interval, ignore_last, reset_flag, by_epoch) + self.model_file = model_file + self.import_dvclive(**kwargs) - super(DvcliveLoggerHook, self).__init__(interval, ignore_last, - reset_flag, by_epoch) - self.path = path - self.import_dvclive() - - def import_dvclive(self): + def import_dvclive(self, **kwargs): try: - import dvclive + from dvclive import Live except ImportError: raise ImportError( 'Please run "pip install dvclive" to install dvclive') - self.dvclive = dvclive - - @master_only - def before_run(self, runner): - self.dvclive.init(self.path) + self.dvclive = Live(**kwargs) @master_only def log(self, runner): tags = self.get_loggable_tags(runner) if tags: + self.dvclive.set_step(self.get_iter(runner)) for k, v in tags.items(): - self.dvclive.log(k, v, step=self.get_iter(runner)) + self.dvclive.log(k, v) + + @master_only + def after_train_epoch(self, runner): + super().after_train_epoch(runner) + if self.model_file is not None: + runner.save_checkpoint( + Path(self.model_file).parent, + filename_tmpl=Path(self.model_file).name, + create_symlink=False, + ) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index be2970740..c628fa78d 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -1226,21 +1226,37 @@ def test_neptune_hook(): hook.run.stop.assert_called_with() -def test_dvclive_hook(tmp_path): +def test_dvclive_hook(): sys.modules['dvclive'] = MagicMock() runner = _build_demo_runner() - (tmp_path / 'dvclive').mkdir() - hook = DvcliveLoggerHook(str(tmp_path / 'dvclive')) + hook = DvcliveLoggerHook() + dvclive_mock = hook.dvclive loader = DataLoader(torch.ones((5, 2))) runner.register_hook(hook) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) - hook.dvclive.init.assert_called_with(str(tmp_path / 'dvclive')) - hook.dvclive.log.assert_called_with('momentum', 0.95, step=6) - hook.dvclive.log.assert_any_call('learning_rate', 0.02, step=6) + dvclive_mock.set_step.assert_called_with(6) + dvclive_mock.log.assert_called_with('momentum', 0.95) + + +def test_dvclive_hook_model_file(tmp_path): + sys.modules['dvclive'] = MagicMock() + runner = _build_demo_runner() + + hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth')) + runner.register_hook(hook) + + loader = torch.utils.data.DataLoader(torch.ones((5, 2))) + loader = DataLoader(torch.ones((5, 2))) + + runner.run([loader, loader], [('train', 1), ('val', 1)]) + + assert osp.exists(osp.join(runner.work_dir, 'model.pth')) + + shutil.rmtree(runner.work_dir) def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',