[Feature] Add DVCLiveVisBackend (#1336)

This commit is contained in:
Range King 2023-09-08 17:22:23 +08:00 committed by GitHub
parent 45ee96d0c4
commit 273fb2b333
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 328 additions and 10 deletions

View File

@ -36,3 +36,4 @@ visualization Backend
WandbVisBackend
ClearMLVisBackend
NeptuneVisBackend
DVCLiveVisBackend

View File

@ -1,6 +1,6 @@
# Visualize Training Logs
MMEngine integrates experiment management tools such as [TensorBoard](https://www.tensorflow.org/tensorboard), [Weights & Biases (WandB)](https://docs.wandb.ai/), [MLflow](https://mlflow.org/docs/latest/index.html), [ClearML](https://clear.ml/docs/latest/docs) and [Neptune](https://docs.neptune.ai/), making it easy to track and visualize metrics like loss and accuracy.
MMEngine integrates experiment management tools such as [TensorBoard](https://www.tensorflow.org/tensorboard), [Weights & Biases (WandB)](https://docs.wandb.ai/), [MLflow](https://mlflow.org/docs/latest/index.html), [ClearML](https://clear.ml/docs/latest/docs), [Neptune](https://docs.neptune.ai/) and [DVCLive](https://dvc.org/doc/dvclive), making it easy to track and visualize metrics like loss and accuracy.
Below, we'll show you how to configure an experiment management tool in just one line, based on the example from [15 minutes to get started with MMEngine](../get_started/15_minutes.md).
@ -149,3 +149,44 @@ runner.train()
```
More initialization configuration parameters are available at [neptune.init_run API](https://docs.neptune.ai/api/neptune/#init_run).
## DVCLive
Before using DVCLive, you need to install `dvclive` dependency library and refer to [iterative.ai](https://dvc.org/doc/start) for configuration. Common configurations are as follows:
```bash
pip install dvclive
cd ${WORK_DIR}
git init
dvc init
git commit -m "DVC init"
```
Configure the `Runner` in the initialization parameters of the Runner, and set `vis_backends` to [DVCLiveVisBackend](mmengine.visualization.DVCLiveVisBackend).
```python
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir_dvc',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='DVCLiveVisBackend')]),
)
runner.train()
```
```{note}
Recommend not to set `work_dir` as `work_dirs`. Or DVC will give a warning `WARNING:dvclive:Error in cache: bad DVC file name 'work_dirs\xxx.dvc' is git-ignored` if you run experiments in a OpenMMLab's repo.
```
Open the `report.html` file under `work_dir_dvc`, and you will see the visualization as shown in the following image.
![image](https://github.com/open-mmlab/mmengine/assets/58739961/47d85520-9a4a-4143-a449-12ed7347cc63)
You can also configure a VSCode extension of [DVC](https://marketplace.visualstudio.com/items?itemName=Iterative.dvc) to visualize the training process.
More initialization configuration parameters are available at [DVCLive API Reference](https://dvc.org/doc/dvclive/live).

View File

@ -36,3 +36,4 @@ visualization Backend
WandbVisBackend
ClearMLVisBackend
NeptuneVisBackend
DVCLiveVisBackend

View File

@ -1,6 +1,6 @@
# 可视化训练日志
MMEngine 集成了 [TensorBoard](https://www.tensorflow.org/tensorboard?hl=zh-cn)、[Weights & Biases (WandB)](https://docs.wandb.ai/)、[MLflow](https://mlflow.org/docs/latest/index.html) 、[ClearML](https://clear.ml/docs/latest/docs) 和 [Neptune](https://docs.neptune.ai/) 实验管理工具,你可以很方便地跟踪和可视化损失及准确率等指标。
MMEngine 集成了 [TensorBoard](https://www.tensorflow.org/tensorboard?hl=zh-cn)、[Weights & Biases (WandB)](https://docs.wandb.ai/)、[MLflow](https://mlflow.org/docs/latest/index.html) 、[ClearML](https://clear.ml/docs/latest/docs)、[Neptune](https://docs.neptune.ai/) 和 [DVCLive](https://dvc.org/doc/dvclive) 实验管理工具,你可以很方便地跟踪和可视化损失及准确率等指标。
下面基于[15 分钟上手 MMENGINE](../get_started/15_minutes.md)中的例子介绍如何一行配置实验管理工具。
@ -149,3 +149,44 @@ runner.train()
```
更多初始化配置参数可点击 [neptune.init_run API](https://docs.neptune.ai/api/neptune/#init_run) 查询。
## DVCLive
使用 DVCLive 前需先安装依赖库 `dvclive` 并参考 [iterative.ai](https://dvc.org/doc/start) 进行配置。常见的配置方式如下:
```bash
pip install dvclive
cd ${WORK_DIR}
git init
dvc init
git commit -m "DVC init"
```
设置 `Runner` 初始化参数中的 `visualizer`,并将 `vis_backends` 设置为 [DVCLiveVisBackend](mmengine.visualization.DVCLiveVisBackend)。
```python
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir_dvc',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=dict(type='Visualizer', vis_backends=[dict(type='DVCLiveVisBackend')]),
)
runner.train()
```
```{note}
推荐将 `work_dir` 设置为 `work_dirs`。否则,你在 OpenMMLab 仓库中运行试验时DVC 会给出警告 `WARNING:dvclive:Error in cache: bad DVC file name 'work_dirs\xxx.dvc' is git-ignored`
```
打开 `work_dir_dvc` 下面的 `report.html` 文件,即可看到如下图的可视化效果。
![image](https://github.com/open-mmlab/mmengine/assets/58739961/47d85520-9a4a-4143-a449-12ed7347cc63)
你还可以安装 VSCode 扩展 [DVC](https://marketplace.visualstudio.com/items?itemName=Iterative.dvc) 进行可视化。
更多初始化配置参数可点击 [DVCLive API Reference](https://dvc.org/doc/dvclive/live) 查询。

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .vis_backend import (BaseVisBackend, ClearMLVisBackend, LocalVisBackend,
MLflowVisBackend, NeptuneVisBackend,
from .vis_backend import (BaseVisBackend, ClearMLVisBackend, DVCLiveVisBackend,
LocalVisBackend, MLflowVisBackend, NeptuneVisBackend,
TensorboardVisBackend, WandbVisBackend)
from .visualizer import Visualizer
__all__ = [
'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend',
'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend',
'NeptuneVisBackend'
'NeptuneVisBackend', 'DVCLiveVisBackend'
]

View File

@ -4,6 +4,7 @@ import functools
import logging
import os
import os.path as osp
import platform
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
@ -13,12 +14,12 @@ import cv2
import numpy as np
import torch
from mmengine.config import Config
from mmengine.config import Config, ConfigDict
from mmengine.fileio import dump
from mmengine.hooks.logger_hook import SUFFIX_TYPE
from mmengine.logging import MMLogger, print_log
from mmengine.registry import VISBACKENDS
from mmengine.utils import scandir
from mmengine.utils import digit_version, scandir
from mmengine.utils.dl_utils import TORCH_VERSION
@ -1130,3 +1131,181 @@ class NeptuneVisBackend(BaseVisBackend):
"""close an opened object."""
if hasattr(self, '_neptune'):
self._neptune.stop()
@VISBACKENDS.register_module()
class DVCLiveVisBackend(BaseVisBackend):
"""DVCLive visualization backend class.
Examples:
>>> from mmengine.visualization import DVCLiveVisBackend
>>> import numpy as np
>>> dvclive_vis_backend = DVCLiveVisBackend(save_dir='temp_dir')
>>> img=np.random.randint(0, 256, size=(10, 10, 3))
>>> dvclive_vis_backend.add_image('img', img)
>>> dvclive_vis_backend.add_scalar('mAP', 0.6)
>>> dvclive_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8})
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> dvclive_vis_backend.add_config(cfg)
Note:
`New in version 0.8.5.`
Args:
save_dir (str, optional): The root directory to save the files
produced by the visualizer.
artifact_suffix (Tuple[str] or str, optional): The artifact suffix.
Defaults to ('.json', '.py', 'yaml').
init_kwargs (dict, optional): DVCLive initialization parameters.
See `DVCLive <https://dvc.org/doc/dvclive/live>`_ for details.
Defaults to None.
"""
def __init__(self,
save_dir: str,
artifact_suffix: SUFFIX_TYPE = ('.json', '.py', 'yaml'),
init_kwargs: Optional[dict] = None):
super().__init__(save_dir)
self._artifact_suffix = artifact_suffix
self._init_kwargs = init_kwargs
def _init_env(self):
"""Setup env for dvclive."""
if digit_version(platform.python_version()) < digit_version('3.8'):
raise RuntimeError('Please use Python 3.8 or higher version '
'to use DVCLiveVisBackend.')
try:
import pygit2
from dvclive import Live
except ImportError:
raise ImportError(
'Please run "pip install dvclive" to install dvclive')
# if no git info, init dvc without git to avoid SCMError
try:
path = pygit2.discover_repository(os.fspath(os.curdir), True, '')
pygit2.Repository(path).default_signature
except KeyError:
os.system('dvc init -f --no-scm')
if self._init_kwargs is None:
self._init_kwargs = {}
self._init_kwargs.setdefault('dir', self._save_dir)
self._init_kwargs.setdefault('save_dvc_exp', True)
self._init_kwargs.setdefault('cache_images', True)
self._dvclive = Live(**self._init_kwargs)
@property # type: ignore
@force_init_env
def experiment(self):
"""Return dvclive object.
The experiment attribute can get the dvclive backend, If you want to
write other data, such as writing a table, you can directly get the
dvclive backend through experiment.
"""
return self._dvclive
@force_init_env
def add_config(self, config: Config, **kwargs) -> None:
"""Record the config to dvclive.
Args:
config (Config): The Config object
"""
assert isinstance(config, Config)
self.cfg = config
self._dvclive.log_params(self._to_dvc_paramlike(self.cfg))
@force_init_env
def add_image(self,
name: str,
image: np.ndarray,
step: int = 0,
**kwargs) -> None:
"""Record the image to dvclive.
Args:
name (str): The image identifier.
image (np.ndarray): The image to be saved. The format
should be RGB.
step (int): Useless parameter. Dvclive does not
need this parameter. Defaults to 0.
"""
assert image.dtype == np.uint8
save_file_name = f'{name}.png'
self._dvclive.log_image(save_file_name, image)
@force_init_env
def add_scalar(self,
name: str,
value: Union[int, float, torch.Tensor, np.ndarray],
step: int = 0,
**kwargs) -> None:
"""Record the scalar data to dvclive.
Args:
name (str): The scalar identifier.
value (int, float, torch.Tensor, np.ndarray): Value to save.
step (int): Global step value to record. Defaults to 0.
"""
if isinstance(value, torch.Tensor):
value = value.numpy()
self._dvclive.step = step
self._dvclive.log_metric(name, value)
@force_init_env
def add_scalars(self,
scalar_dict: dict,
step: int = 0,
file_path: Optional[str] = None,
**kwargs) -> None:
"""Record the scalar's data to dvclive.
Args:
scalar_dict (dict): Key-value pair storing the tag and
corresponding values.
step (int): Global step value to record. Defaults to 0.
file_path (str, optional): Useless parameter. Just for
interface unification. Defaults to None.
"""
for key, value in scalar_dict.items():
self.add_scalar(key, value, step, **kwargs)
def close(self) -> None:
"""close an opened dvclive object."""
if not hasattr(self, '_dvclive'):
return
file_paths = dict()
for filename in scandir(self._save_dir, self._artifact_suffix, True):
file_path = osp.join(self._save_dir, filename)
relative_path = os.path.relpath(file_path, self._save_dir)
dir_path = os.path.dirname(relative_path)
file_paths[file_path] = dir_path
for file_path, dir_path in file_paths.items():
self._dvclive.log_artifact(file_path, dir_path)
self._dvclive.end()
def _to_dvc_paramlike(self,
value: Union[int, float, dict, list, tuple, Config,
ConfigDict, torch.Tensor, np.ndarray]):
"""Convert the input value to a DVC `ParamLike` recursively.
Or the `log_params` method of dvclive will raise an error.
"""
if isinstance(value, (dict, Config, ConfigDict)):
return {k: self._to_dvc_paramlike(v) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return [self._to_dvc_paramlike(item) for item in value]
elif isinstance(value, (torch.Tensor, np.ndarray)):
return value.tolist()
elif isinstance(value, np.generic):
return value.item()
else:
return value

View File

@ -1,6 +1,7 @@
clearml
coverage
dadaptation
dvclive
lion-pytorch
lmdb
mlflow

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import shutil
import sys
import warnings
@ -12,9 +13,11 @@ import torch
from mmengine import Config
from mmengine.fileio import load
from mmengine.registry import VISBACKENDS
from mmengine.visualization import (ClearMLVisBackend, LocalVisBackend,
MLflowVisBackend, NeptuneVisBackend,
TensorboardVisBackend, WandbVisBackend)
from mmengine.utils import digit_version
from mmengine.visualization import (ClearMLVisBackend, DVCLiveVisBackend,
LocalVisBackend, MLflowVisBackend,
NeptuneVisBackend, TensorboardVisBackend,
WandbVisBackend)
class TestLocalVisBackend:
@ -391,3 +394,54 @@ class TestNeptuneVisBackend:
neptune_vis_backend = NeptuneVisBackend()
neptune_vis_backend._init_env()
neptune_vis_backend.close()
@pytest.mark.skipif(
digit_version(platform.python_version()) < digit_version('3.8'),
reason='DVCLiveVisBackend does not support python version < 3.8')
class TestDVCLiveVisBackend:
def test_init(self):
DVCLiveVisBackend('temp_dir')
VISBACKENDS.build(dict(type='DVCLiveVisBackend', save_dir='temp_dir'))
def test_experiment(self):
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
assert dvclive_vis_backend.experiment == dvclive_vis_backend._dvclive
shutil.rmtree('temp_dir')
def test_add_config(self):
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
dvclive_vis_backend.add_config(cfg)
shutil.rmtree('temp_dir')
def test_add_image(self):
img = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
dvclive_vis_backend.add_image('img', img)
shutil.rmtree('temp_dir')
def test_add_scalar(self):
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
dvclive_vis_backend.add_scalar('mAP', 0.9)
# test append mode
dvclive_vis_backend.add_scalar('mAP', 0.9)
dvclive_vis_backend.add_scalar('mAP', 0.95)
shutil.rmtree('temp_dir')
def test_add_scalars(self):
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
input_dict = {'map': 0.7, 'acc': 0.9}
dvclive_vis_backend.add_scalars(input_dict)
# test append mode
dvclive_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8})
shutil.rmtree('temp_dir')
def test_close(self):
cfg = Config(dict(work_dir='temp_dir'))
dvclive_vis_backend = DVCLiveVisBackend('temp_dir')
dvclive_vis_backend._init_env()
dvclive_vis_backend.add_config(cfg)
dvclive_vis_backend.close()
shutil.rmtree('temp_dir')