[Enhancement] Add a new argument define_metric in wandb hook (#2237)

* wandb define_metric

* add test and some fix based on mmengine PR

* fix test

* add summary warnings

* Update mmcv/runner/hooks/logger/wandb.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/runner/hooks/logger/wandb.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/2356/head
takuoko 2022-10-25 00:13:07 +09:00 committed by GitHub
parent ff18904721
commit 9709ff3f8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 2 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Dict, Optional, Union
from mmcv.utils import scandir
@ -43,6 +44,17 @@ class WandbLoggerHook(LoggerHook):
``out_suffix`` will be uploaded to wandb.
Default: ('.log.json', '.log', '.py').
`New in version 1.4.3.`
define_metric_cfg (dict, optional): A dict of metrics and summaries for
wandb.define_metric. The key is metric and the value is summary.
The summary should be in ["min", "max", "mean", "best", "last",
"none"].
For example, if setting
``define_metric_cfg={'coco/bbox_mAP': 'max'}``, the maximum value
of ``coco/bbox_mAP`` will be logged on wandb UI. See
`wandb docs <https://docs.wandb.ai/ref/python/run#define_metric>`_
for details.
Defaults to None.
`New in version 1.6.3.`
.. _wandb:
https://docs.wandb.ai
@ -57,7 +69,8 @@ class WandbLoggerHook(LoggerHook):
by_epoch: bool = True,
with_step: bool = True,
log_artifact: bool = True,
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')):
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
define_metric_cfg: Optional[Dict] = None):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.import_wandb()
self.init_kwargs = init_kwargs
@ -65,6 +78,7 @@ class WandbLoggerHook(LoggerHook):
self.with_step = with_step
self.log_artifact = log_artifact
self.out_suffix = out_suffix
self.define_metric_cfg = define_metric_cfg
def import_wandb(self) -> None:
try:
@ -83,6 +97,15 @@ class WandbLoggerHook(LoggerHook):
self.wandb.init(**self.init_kwargs) # type: ignore
else:
self.wandb.init() # type: ignore
summary_choice = ['min', 'max', 'mean', 'best', 'last', 'none']
if self.define_metric_cfg is not None:
for metric, summary in self.define_metric_cfg.items():
if summary not in summary_choice:
warnings.warn(
f'summary should be in {summary_choice}. '
f'metric={metric}, summary={summary} will be skipped.')
self.wandb.define_metric( # type: ignore
metric, summary=summary)
@master_only
def log(self, runner) -> None:

View File

@ -1606,7 +1606,8 @@ def test_segmind_hook():
def test_wandb_hook():
sys.modules['wandb'] = MagicMock()
runner = _build_demo_runner()
hook = WandbLoggerHook(log_artifact=True)
hook = WandbLoggerHook(
log_artifact=True, define_metric_cfg={'val/loss': 'min'})
loader = DataLoader(torch.ones((5, 2)))
runner.register_hook(hook)
@ -1615,6 +1616,7 @@ def test_wandb_hook():
shutil.rmtree(runner.work_dir)
hook.wandb.init.assert_called_with()
hook.wandb.define_metric.assert_called_with('val/loss', summary='min')
hook.wandb.log.assert_called_with({
'learning_rate': 0.02,
'momentum': 0.95