mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add DVCLiveVisBackend (#1336)
This commit is contained in:
parent
45ee96d0c4
commit
273fb2b333
@ -36,3 +36,4 @@ visualization Backend
|
||||
WandbVisBackend
|
||||
ClearMLVisBackend
|
||||
NeptuneVisBackend
|
||||
DVCLiveVisBackend
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Visualize Training Logs
|
||||
|
||||
MMEngine integrates experiment management tools such as [TensorBoard](https://www.tensorflow.org/tensorboard), [Weights & Biases (WandB)](https://docs.wandb.ai/), [MLflow](https://mlflow.org/docs/latest/index.html), [ClearML](https://clear.ml/docs/latest/docs) and [Neptune](https://docs.neptune.ai/), making it easy to track and visualize metrics like loss and accuracy.
|
||||
MMEngine integrates experiment management tools such as [TensorBoard](https://www.tensorflow.org/tensorboard), [Weights & Biases (WandB)](https://docs.wandb.ai/), [MLflow](https://mlflow.org/docs/latest/index.html), [ClearML](https://clear.ml/docs/latest/docs), [Neptune](https://docs.neptune.ai/) and [DVCLive](https://dvc.org/doc/dvclive), making it easy to track and visualize metrics like loss and accuracy.
|
||||
|
||||
Below, we'll show you how to configure an experiment management tool in just one line, based on the example from [15 minutes to get started with MMEngine](../get_started/15_minutes.md).
|
||||
|
||||
@ -149,3 +149,44 @@ runner.train()
|
||||
```
|
||||
|
||||
More initialization configuration parameters are available at [neptune.init_run API](https://docs.neptune.ai/api/neptune/#init_run).
|
||||
|
||||
## DVCLive
|
||||
|
||||
Before using DVCLive, you need to install `dvclive` dependency library and refer to [iterative.ai](https://dvc.org/doc/start) for configuration. Common configurations are as follows:
|
||||
|
||||
```bash
|
||||
pip install dvclive
|
||||
cd ${WORK_DIR}
|
||||
git init
|
||||
dvc init
|
||||
git commit -m "DVC init"
|
||||
```
|
||||
|
||||
Configure the `Runner` in the initialization parameters of the Runner, and set `vis_backends` to [DVCLiveVisBackend](mmengine.visualization.DVCLiveVisBackend).
|
||||
|
||||
```python
|
||||
runner = Runner(
|
||||
model=MMResNet50(),
|
||||
work_dir='./work_dir_dvc',
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
|
||||
val_dataloader=val_dataloader,
|
||||
val_cfg=dict(),
|
||||
val_evaluator=dict(type=Accuracy),
|
||||
visualizer=dict(type='Visualizer', vis_backends=[dict(type='DVCLiveVisBackend')]),
|
||||
)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
```{note}
|
||||
Recommend not to set `work_dir` as `work_dirs`. Or DVC will give a warning `WARNING:dvclive:Error in cache: bad DVC file name 'work_dirs\xxx.dvc' is git-ignored` if you run experiments in a OpenMMLab's repo.
|
||||
```
|
||||
|
||||
Open the `report.html` file under `work_dir_dvc`, and you will see the visualization as shown in the following image.
|
||||
|
||||

|
||||
|
||||
You can also configure a VSCode extension of [DVC](https://marketplace.visualstudio.com/items?itemName=Iterative.dvc) to visualize the training process.
|
||||
|
||||
More initialization configuration parameters are available at [DVCLive API Reference](https://dvc.org/doc/dvclive/live).
|
||||
|
@ -36,3 +36,4 @@ visualization Backend
|
||||
WandbVisBackend
|
||||
ClearMLVisBackend
|
||||
NeptuneVisBackend
|
||||
DVCLiveVisBackend
|
||||
|
@ -1,6 +1,6 @@
|
||||
# 可视化训练日志
|
||||
|
||||
MMEngine 集成了 [TensorBoard](https://www.tensorflow.org/tensorboard?hl=zh-cn)、[Weights & Biases (WandB)](https://docs.wandb.ai/)、[MLflow](https://mlflow.org/docs/latest/index.html) 、[ClearML](https://clear.ml/docs/latest/docs) 和 [Neptune](https://docs.neptune.ai/) 实验管理工具,你可以很方便地跟踪和可视化损失及准确率等指标。
|
||||
MMEngine 集成了 [TensorBoard](https://www.tensorflow.org/tensorboard?hl=zh-cn)、[Weights & Biases (WandB)](https://docs.wandb.ai/)、[MLflow](https://mlflow.org/docs/latest/index.html) 、[ClearML](https://clear.ml/docs/latest/docs)、[Neptune](https://docs.neptune.ai/) 和 [DVCLive](https://dvc.org/doc/dvclive) 实验管理工具,你可以很方便地跟踪和可视化损失及准确率等指标。
|
||||
|
||||
下面基于[15 分钟上手 MMENGINE](../get_started/15_minutes.md)中的例子介绍如何一行配置实验管理工具。
|
||||
|
||||
@ -149,3 +149,44 @@ runner.train()
|
||||
```
|
||||
|
||||
更多初始化配置参数可点击 [neptune.init_run API](https://docs.neptune.ai/api/neptune/#init_run) 查询。
|
||||
|
||||
## DVCLive
|
||||
|
||||
使用 DVCLive 前需先安装依赖库 `dvclive` 并参考 [iterative.ai](https://dvc.org/doc/start) 进行配置。常见的配置方式如下:
|
||||
|
||||
```bash
|
||||
pip install dvclive
|
||||
cd ${WORK_DIR}
|
||||
git init
|
||||
dvc init
|
||||
git commit -m "DVC init"
|
||||
```
|
||||
|
||||
设置 `Runner` 初始化参数中的 `visualizer`,并将 `vis_backends` 设置为 [DVCLiveVisBackend](mmengine.visualization.DVCLiveVisBackend)。
|
||||
|
||||
```python
|
||||
runner = Runner(
|
||||
model=MMResNet50(),
|
||||
work_dir='./work_dir_dvc',
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
|
||||
val_dataloader=val_dataloader,
|
||||
val_cfg=dict(),
|
||||
val_evaluator=dict(type=Accuracy),
|
||||
visualizer=dict(type='Visualizer', vis_backends=[dict(type='DVCLiveVisBackend')]),
|
||||
)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
```{note}
|
||||
推荐将 `work_dir` 设置为 `work_dirs`。否则,你在 OpenMMLab 仓库中运行试验时,DVC 会给出警告 `WARNING:dvclive:Error in cache: bad DVC file name 'work_dirs\xxx.dvc' is git-ignored`。
|
||||
```
|
||||
|
||||
打开 `work_dir_dvc` 下面的 `report.html` 文件,即可看到如下图的可视化效果。
|
||||
|
||||

|
||||
|
||||
你还可以安装 VSCode 扩展 [DVC](https://marketplace.visualstudio.com/items?itemName=Iterative.dvc) 进行可视化。
|
||||
|
||||
更多初始化配置参数可点击 [DVCLive API Reference](https://dvc.org/doc/dvclive/live) 查询。
|
||||
|
@ -1,11 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .vis_backend import (BaseVisBackend, ClearMLVisBackend, LocalVisBackend,
|
||||
MLflowVisBackend, NeptuneVisBackend,
|
||||
from .vis_backend import (BaseVisBackend, ClearMLVisBackend, DVCLiveVisBackend,
|
||||
LocalVisBackend, MLflowVisBackend, NeptuneVisBackend,
|
||||
TensorboardVisBackend, WandbVisBackend)
|
||||
from .visualizer import Visualizer
|
||||
|
||||
__all__ = [
|
||||
'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend',
|
||||
'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend',
|
||||
'NeptuneVisBackend'
|
||||
'NeptuneVisBackend', 'DVCLiveVisBackend'
|
||||
]
|
||||
|
@ -4,6 +4,7 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections.abc import MutableMapping
|
||||
@ -13,12 +14,12 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.fileio import dump
|
||||
from mmengine.hooks.logger_hook import SUFFIX_TYPE
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.registry import VISBACKENDS
|
||||
from mmengine.utils import scandir
|
||||
from mmengine.utils import digit_version, scandir
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
@ -1130,3 +1131,181 @@ class NeptuneVisBackend(BaseVisBackend):
|
||||
"""close an opened object."""
|
||||
if hasattr(self, '_neptune'):
|
||||
self._neptune.stop()
|
||||
|
||||
|
||||
@VISBACKENDS.register_module()
|
||||
class DVCLiveVisBackend(BaseVisBackend):
|
||||
"""DVCLive visualization backend class.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine.visualization import DVCLiveVisBackend
|
||||
>>> import numpy as np
|
||||
>>> dvclive_vis_backend = DVCLiveVisBackend(save_dir='temp_dir')
|
||||
>>> img=np.random.randint(0, 256, size=(10, 10, 3))
|
||||
>>> dvclive_vis_backend.add_image('img', img)
|
||||
>>> dvclive_vis_backend.add_scalar('mAP', 0.6)
|
||||
>>> dvclive_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8})
|
||||
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
>>> dvclive_vis_backend.add_config(cfg)
|
||||
|
||||
Note:
|
||||
`New in version 0.8.5.`
|
||||
|
||||
Args:
|
||||
save_dir (str, optional): The root directory to save the files
|
||||
produced by the visualizer.
|
||||
artifact_suffix (Tuple[str] or str, optional): The artifact suffix.
|
||||
Defaults to ('.json', '.py', 'yaml').
|
||||
init_kwargs (dict, optional): DVCLive initialization parameters.
|
||||
See `DVCLive <https://dvc.org/doc/dvclive/live>`_ for details.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
artifact_suffix: SUFFIX_TYPE = ('.json', '.py', 'yaml'),
|
||||
init_kwargs: Optional[dict] = None):
|
||||
super().__init__(save_dir)
|
||||
self._artifact_suffix = artifact_suffix
|
||||
self._init_kwargs = init_kwargs
|
||||
|
||||
def _init_env(self):
|
||||
"""Setup env for dvclive."""
|
||||
if digit_version(platform.python_version()) < digit_version('3.8'):
|
||||
raise RuntimeError('Please use Python 3.8 or higher version '
|
||||
'to use DVCLiveVisBackend.')
|
||||
|
||||
try:
|
||||
import pygit2
|
||||
from dvclive import Live
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install dvclive" to install dvclive')
|
||||
# if no git info, init dvc without git to avoid SCMError
|
||||
try:
|
||||
path = pygit2.discover_repository(os.fspath(os.curdir), True, '')
|
||||
pygit2.Repository(path).default_signature
|
||||
except KeyError:
|
||||
os.system('dvc init -f --no-scm')
|
||||
|
||||
if self._init_kwargs is None:
|
||||
self._init_kwargs = {}
|
||||
self._init_kwargs.setdefault('dir', self._save_dir)
|
||||
self._init_kwargs.setdefault('save_dvc_exp', True)
|
||||
self._init_kwargs.setdefault('cache_images', True)
|
||||
|
||||
self._dvclive = Live(**self._init_kwargs)
|
||||
|
||||
@property # type: ignore
|
||||
@force_init_env
|
||||
def experiment(self):
|
||||
"""Return dvclive object.
|
||||
|
||||
The experiment attribute can get the dvclive backend, If you want to
|
||||
write other data, such as writing a table, you can directly get the
|
||||
dvclive backend through experiment.
|
||||
"""
|
||||
return self._dvclive
|
||||
|
||||
@force_init_env
|
||||
def add_config(self, config: Config, **kwargs) -> None:
|
||||
"""Record the config to dvclive.
|
||||
|
||||
Args:
|
||||
config (Config): The Config object
|
||||
"""
|
||||
assert isinstance(config, Config)
|
||||
self.cfg = config
|
||||
self._dvclive.log_params(self._to_dvc_paramlike(self.cfg))
|
||||
|
||||
@force_init_env
|
||||
def add_image(self,
|
||||
name: str,
|
||||
image: np.ndarray,
|
||||
step: int = 0,
|
||||
**kwargs) -> None:
|
||||
"""Record the image to dvclive.
|
||||
|
||||
Args:
|
||||
name (str): The image identifier.
|
||||
image (np.ndarray): The image to be saved. The format
|
||||
should be RGB.
|
||||
step (int): Useless parameter. Dvclive does not
|
||||
need this parameter. Defaults to 0.
|
||||
"""
|
||||
assert image.dtype == np.uint8
|
||||
save_file_name = f'{name}.png'
|
||||
|
||||
self._dvclive.log_image(save_file_name, image)
|
||||
|
||||
@force_init_env
|
||||
def add_scalar(self,
|
||||
name: str,
|
||||
value: Union[int, float, torch.Tensor, np.ndarray],
|
||||
step: int = 0,
|
||||
**kwargs) -> None:
|
||||
"""Record the scalar data to dvclive.
|
||||
|
||||
Args:
|
||||
name (str): The scalar identifier.
|
||||
value (int, float, torch.Tensor, np.ndarray): Value to save.
|
||||
step (int): Global step value to record. Defaults to 0.
|
||||
"""
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.numpy()
|
||||
self._dvclive.step = step
|
||||
self._dvclive.log_metric(name, value)
|
||||
|
||||
@force_init_env
|
||||
def add_scalars(self,
|
||||
scalar_dict: dict,
|
||||
step: int = 0,
|
||||
file_path: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
"""Record the scalar's data to dvclive.
|
||||
|
||||
Args:
|
||||
scalar_dict (dict): Key-value pair storing the tag and
|
||||
corresponding values.
|
||||
step (int): Global step value to record. Defaults to 0.
|
||||
file_path (str, optional): Useless parameter. Just for
|
||||
interface unification. Defaults to None.
|
||||
"""
|
||||
for key, value in scalar_dict.items():
|
||||
self.add_scalar(key, value, step, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened dvclive object."""
|
||||
if not hasattr(self, '_dvclive'):
|
||||
return
|
||||
|
||||
file_paths = dict()
|
||||
for filename in scandir(self._save_dir, self._artifact_suffix, True):
|
||||
file_path = osp.join(self._save_dir, filename)
|
||||
relative_path = os.path.relpath(file_path, self._save_dir)
|
||||
dir_path = os.path.dirname(relative_path)
|
||||
file_paths[file_path] = dir_path
|
||||
|
||||
for file_path, dir_path in file_paths.items():
|
||||
self._dvclive.log_artifact(file_path, dir_path)
|
||||
|
||||
self._dvclive.end()
|
||||
|
||||
def _to_dvc_paramlike(self,
|
||||
value: Union[int, float, dict, list, tuple, Config,
|
||||
ConfigDict, torch.Tensor, np.ndarray]):
|
||||
"""Convert the input value to a DVC `ParamLike` recursively.
|
||||
|
||||
Or the `log_params` method of dvclive will raise an error.
|
||||
"""
|
||||
|
||||
if isinstance(value, (dict, Config, ConfigDict)):
|
||||
return {k: self._to_dvc_paramlike(v) for k, v in value.items()}
|
||||
elif isinstance(value, (tuple, list)):
|
||||
return [self._to_dvc_paramlike(item) for item in value]
|
||||
elif isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
return value.tolist()
|
||||
elif isinstance(value, np.generic):
|
||||
return value.item()
|
||||
else:
|
||||
return value
|
||||
|
@ -1,6 +1,7 @@
|
||||
clearml
|
||||
coverage
|
||||
dadaptation
|
||||
dvclive
|
||||
lion-pytorch
|
||||
lmdb
|
||||
mlflow
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
@ -12,9 +13,11 @@ import torch
|
||||
from mmengine import Config
|
||||
from mmengine.fileio import load
|
||||
from mmengine.registry import VISBACKENDS
|
||||
from mmengine.visualization import (ClearMLVisBackend, LocalVisBackend,
|
||||
MLflowVisBackend, NeptuneVisBackend,
|
||||
TensorboardVisBackend, WandbVisBackend)
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.visualization import (ClearMLVisBackend, DVCLiveVisBackend,
|
||||
LocalVisBackend, MLflowVisBackend,
|
||||
NeptuneVisBackend, TensorboardVisBackend,
|
||||
WandbVisBackend)
|
||||
|
||||
|
||||
class TestLocalVisBackend:
|
||||
@ -391,3 +394,54 @@ class TestNeptuneVisBackend:
|
||||
neptune_vis_backend = NeptuneVisBackend()
|
||||
neptune_vis_backend._init_env()
|
||||
neptune_vis_backend.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
digit_version(platform.python_version()) < digit_version('3.8'),
|
||||
reason='DVCLiveVisBackend does not support python version < 3.8')
|
||||
class TestDVCLiveVisBackend:
|
||||
|
||||
def test_init(self):
|
||||
DVCLiveVisBackend('temp_dir')
|
||||
VISBACKENDS.build(dict(type='DVCLiveVisBackend', save_dir='temp_dir'))
|
||||
|
||||
def test_experiment(self):
|
||||
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
|
||||
assert dvclive_vis_backend.experiment == dvclive_vis_backend._dvclive
|
||||
shutil.rmtree('temp_dir')
|
||||
|
||||
def test_add_config(self):
|
||||
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
|
||||
dvclive_vis_backend.add_config(cfg)
|
||||
shutil.rmtree('temp_dir')
|
||||
|
||||
def test_add_image(self):
|
||||
img = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
|
||||
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
|
||||
dvclive_vis_backend.add_image('img', img)
|
||||
shutil.rmtree('temp_dir')
|
||||
|
||||
def test_add_scalar(self):
|
||||
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
|
||||
dvclive_vis_backend.add_scalar('mAP', 0.9)
|
||||
# test append mode
|
||||
dvclive_vis_backend.add_scalar('mAP', 0.9)
|
||||
dvclive_vis_backend.add_scalar('mAP', 0.95)
|
||||
shutil.rmtree('temp_dir')
|
||||
|
||||
def test_add_scalars(self):
|
||||
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
|
||||
input_dict = {'map': 0.7, 'acc': 0.9}
|
||||
dvclive_vis_backend.add_scalars(input_dict)
|
||||
# test append mode
|
||||
dvclive_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8})
|
||||
shutil.rmtree('temp_dir')
|
||||
|
||||
def test_close(self):
|
||||
cfg = Config(dict(work_dir='temp_dir'))
|
||||
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
|
||||
dvclive_vis_backend._init_env()
|
||||
dvclive_vis_backend.add_config(cfg)
|
||||
dvclive_vis_backend.close()
|
||||
shutil.rmtree('temp_dir')
|
||||
|
Loading…
x
Reference in New Issue
Block a user