[Enhancement] Add a new argument define_metric in wandb hook (#2237)

* wandb define_metric

* add test and some fix based on mmengine PR

* fix test

* add summary warnings

* Update mmcv/runner/hooks/logger/wandb.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/runner/hooks/logger/wandb.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/2356/head
takuoko 2022-10-25 00:13:07 +09:00 committed by GitHub
parent ff18904721
commit 9709ff3f8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 2 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
import warnings
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from mmcv.utils import scandir from mmcv.utils import scandir
@ -43,6 +44,17 @@ class WandbLoggerHook(LoggerHook):
``out_suffix`` will be uploaded to wandb. ``out_suffix`` will be uploaded to wandb.
Default: ('.log.json', '.log', '.py'). Default: ('.log.json', '.log', '.py').
`New in version 1.4.3.` `New in version 1.4.3.`
define_metric_cfg (dict, optional): A dict of metrics and summaries for
wandb.define_metric. The key is metric and the value is summary.
The summary should be in ["min", "max", "mean" ,"best", "last",
"none"].
For example, if setting
``define_metric_cfg={'coco/bbox_mAP': 'max'}``, the maximum value
of ``coco/bbox_mAP`` will be logged on wandb UI. See
`wandb docs <https://docs.wandb.ai/ref/python/run#define_metric>`_
for details.
Defaults to None.
`New in version 1.6.3.`
.. _wandb: .. _wandb:
https://docs.wandb.ai https://docs.wandb.ai
@ -57,7 +69,8 @@ class WandbLoggerHook(LoggerHook):
by_epoch: bool = True, by_epoch: bool = True,
with_step: bool = True, with_step: bool = True,
log_artifact: bool = True, log_artifact: bool = True,
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')): out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
define_metric_cfg: Optional[Dict] = None):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb() self.import_wandb()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
@ -65,6 +78,7 @@ class WandbLoggerHook(LoggerHook):
self.with_step = with_step self.with_step = with_step
self.log_artifact = log_artifact self.log_artifact = log_artifact
self.out_suffix = out_suffix self.out_suffix = out_suffix
self.define_metric_cfg = define_metric_cfg
def import_wandb(self) -> None: def import_wandb(self) -> None:
try: try:
@ -83,6 +97,15 @@ class WandbLoggerHook(LoggerHook):
self.wandb.init(**self.init_kwargs) # type: ignore self.wandb.init(**self.init_kwargs) # type: ignore
else: else:
self.wandb.init() # type: ignore self.wandb.init() # type: ignore
summary_choice = ['min', 'max', 'mean', 'best', 'last', 'none']
if self.define_metric_cfg is not None:
for metric, summary in self.define_metric_cfg.items():
if summary not in summary_choice:
warnings.warn(
f'summary should be in {summary_choice}. '
f'metric={metric}, summary={summary} will be skipped.')
self.wandb.define_metric( # type: ignore
metric, summary=summary)
@master_only @master_only
def log(self, runner) -> None: def log(self, runner) -> None:

View File

@ -1606,7 +1606,8 @@ def test_segmind_hook():
def test_wandb_hook(): def test_wandb_hook():
sys.modules['wandb'] = MagicMock() sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner() runner = _build_demo_runner()
hook = WandbLoggerHook(log_artifact=True) hook = WandbLoggerHook(
log_artifact=True, define_metric_cfg={'val/loss': 'min'})
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook) runner.register_hook(hook)
@ -1615,6 +1616,7 @@ def test_wandb_hook():
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
hook.wandb.init.assert_called_with() hook.wandb.init.assert_called_with()
hook.wandb.define_metric.assert_called_with('val/loss', summary='min')
hook.wandb.log.assert_called_with({ hook.wandb.log.assert_called_with({
'learning_rate': 0.02, 'learning_rate': 0.02,
'momentum': 0.95 'momentum': 0.95