mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Add a new argument define_metric in wandb hook (#2237)
* wandb define_metric * add test and some fix based on mmengine PR * fix test * add summary warnings * Update mmcv/runner/hooks/logger/wandb.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/runner/hooks/logger/wandb.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/2356/head
parent
ff18904721
commit
9709ff3f8c
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from mmcv.utils import scandir
|
||||
|
@ -43,6 +44,17 @@ class WandbLoggerHook(LoggerHook):
|
|||
``out_suffix`` will be uploaded to wandb.
|
||||
Default: ('.log.json', '.log', '.py').
|
||||
`New in version 1.4.3.`
|
||||
define_metric_cfg (dict, optional): A dict of metrics and summaries for
|
||||
wandb.define_metric. The key is metric and the value is summary.
|
||||
The summary should be in ["min", "max", "mean" ,"best", "last",
|
||||
"none"].
|
||||
For example, if setting
|
||||
``define_metric_cfg={'coco/bbox_mAP': 'max'}``, the maximum value
|
||||
of ``coco/bbox_mAP`` will be logged on wandb UI. See
|
||||
`wandb docs <https://docs.wandb.ai/ref/python/run#define_metric>`_
|
||||
for details.
|
||||
Defaults to None.
|
||||
`New in version 1.6.3.`
|
||||
|
||||
.. _wandb:
|
||||
https://docs.wandb.ai
|
||||
|
@ -57,7 +69,8 @@ class WandbLoggerHook(LoggerHook):
|
|||
by_epoch: bool = True,
|
||||
with_step: bool = True,
|
||||
log_artifact: bool = True,
|
||||
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py')):
|
||||
out_suffix: Union[str, tuple] = ('.log.json', '.log', '.py'),
|
||||
define_metric_cfg: Optional[Dict] = None):
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.import_wandb()
|
||||
self.init_kwargs = init_kwargs
|
||||
|
@ -65,6 +78,7 @@ class WandbLoggerHook(LoggerHook):
|
|||
self.with_step = with_step
|
||||
self.log_artifact = log_artifact
|
||||
self.out_suffix = out_suffix
|
||||
self.define_metric_cfg = define_metric_cfg
|
||||
|
||||
def import_wandb(self) -> None:
|
||||
try:
|
||||
|
@ -83,6 +97,15 @@ class WandbLoggerHook(LoggerHook):
|
|||
self.wandb.init(**self.init_kwargs) # type: ignore
|
||||
else:
|
||||
self.wandb.init() # type: ignore
|
||||
summary_choice = ['min', 'max', 'mean', 'best', 'last', 'none']
|
||||
if self.define_metric_cfg is not None:
|
||||
for metric, summary in self.define_metric_cfg.items():
|
||||
if summary not in summary_choice:
|
||||
warnings.warn(
|
||||
f'summary should be in {summary_choice}. '
|
||||
f'metric={metric}, summary={summary} will be skipped.')
|
||||
self.wandb.define_metric( # type: ignore
|
||||
metric, summary=summary)
|
||||
|
||||
@master_only
|
||||
def log(self, runner) -> None:
|
||||
|
|
|
@ -1606,7 +1606,8 @@ def test_segmind_hook():
|
|||
def test_wandb_hook():
|
||||
sys.modules['wandb'] = MagicMock()
|
||||
runner = _build_demo_runner()
|
||||
hook = WandbLoggerHook(log_artifact=True)
|
||||
hook = WandbLoggerHook(
|
||||
log_artifact=True, define_metric_cfg={'val/loss': 'min'})
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
|
||||
runner.register_hook(hook)
|
||||
|
@ -1615,6 +1616,7 @@ def test_wandb_hook():
|
|||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
hook.wandb.init.assert_called_with()
|
||||
hook.wandb.define_metric.assert_called_with('val/loss', summary='min')
|
||||
hook.wandb.log.assert_called_with({
|
||||
'learning_rate': 0.02,
|
||||
'momentum': 0.95
|
||||
|
|
Loading…
Reference in New Issue