mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Add ability to pass logger instance to frameworks (#2317)
* Add ability to pass logger instance to frameworks * refine docstring * Update mmcv/runner/hooks/logger/dvclive.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/2337/head
parent
dd5b415d2a
commit
e417035f5d
|
@ -22,7 +22,10 @@ class DvcliveLoggerHook(LoggerHook):
|
|||
reset_flag (bool): Whether to clear the output buffer after logging.
|
||||
Default: False.
|
||||
by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
|
||||
kwargs: Arguments for instantiating `Live`_.
|
||||
dvclive (Live, optional): An instance of the `Live`_ logger to use
|
||||
instead of initializing a new one internally. Defaults to None.
|
||||
kwargs: Arguments for instantiating `Live`_ (ignored if `dvclive` is
|
||||
provided).
|
||||
|
||||
.. _dvclive:
|
||||
https://dvc.org/doc/dvclive
|
||||
|
@ -37,18 +40,19 @@ class DvcliveLoggerHook(LoggerHook):
|
|||
ignore_last: bool = True,
|
||||
reset_flag: bool = False,
|
||||
by_epoch: bool = True,
|
||||
dvclive=None,
|
||||
**kwargs):
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.model_file = model_file
|
||||
self.import_dvclive(**kwargs)
|
||||
self._import_dvclive(dvclive, **kwargs)
|
||||
|
||||
def import_dvclive(self, **kwargs) -> None:
|
||||
def _import_dvclive(self, dvclive=None, **kwargs) -> None:
|
||||
try:
|
||||
from dvclive import Live
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install dvclive" to install dvclive')
|
||||
self.dvclive = Live(**kwargs)
|
||||
self.dvclive = dvclive if dvclive is not None else Live(**kwargs)
|
||||
|
||||
@master_only
|
||||
def log(self, runner) -> None:
|
||||
|
|
|
@ -1665,7 +1665,6 @@ def test_dvclive_hook_model_file(tmp_path):
|
|||
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
|
||||
runner.register_hook(hook)
|
||||
|
||||
loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)])
|
||||
|
@ -1675,6 +1674,16 @@ def test_dvclive_hook_model_file(tmp_path):
|
|||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
|
||||
def test_dvclive_hook_pass_logger(tmp_path):
|
||||
sys.modules['dvclive'] = MagicMock()
|
||||
from dvclive import Live
|
||||
logger = Live()
|
||||
|
||||
sys.modules['dvclive'] = MagicMock()
|
||||
assert DvcliveLoggerHook().dvclive is not logger
|
||||
assert DvcliveLoggerHook(dvclive=logger).dvclive is logger
|
||||
|
||||
|
||||
def test_clearml_hook():
|
||||
sys.modules['clearml'] = MagicMock()
|
||||
runner = _build_demo_runner()
|
||||
|
|
Loading…
Reference in New Issue