parent
5762b28847
commit
8bf1ecad38
|
@ -31,5 +31,6 @@ visualization Backend
|
|||
|
||||
BaseVisBackend
|
||||
LocalVisBackend
|
||||
MLflowVisBackend
|
||||
TensorboardVisBackend
|
||||
WandbVisBackend
|
||||
|
|
|
@ -31,5 +31,6 @@ visualization Backend
|
|||
|
||||
BaseVisBackend
|
||||
LocalVisBackend
|
||||
MLflowVisBackend
|
||||
TensorboardVisBackend
|
||||
WandbVisBackend
|
||||
|
|
|
@ -306,6 +306,9 @@ class LoggerHook(Hook):
|
|||
runner (Runner): The runner of the training/testing/validation
|
||||
process.
|
||||
"""
|
||||
# close the visualizer
|
||||
runner.visualizer.close()
|
||||
|
||||
# copy or upload logs to self.out_dir
|
||||
if self.out_dir is None:
|
||||
return
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .vis_backend import (BaseVisBackend, LocalVisBackend,
|
||||
from .vis_backend import (BaseVisBackend, LocalVisBackend, MLflowVisBackend,
|
||||
TensorboardVisBackend, WandbVisBackend)
|
||||
from .visualizer import Visualizer
|
||||
|
||||
__all__ = [
|
||||
'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend',
|
||||
'TensorboardVisBackend'
|
||||
'TensorboardVisBackend', 'MLflowVisBackend'
|
||||
]
|
||||
|
|
|
@ -6,6 +6,7 @@ import os
|
|||
import os.path as osp
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
import cv2
|
||||
|
@ -14,8 +15,10 @@ import torch
|
|||
|
||||
from mmengine.config import Config
|
||||
from mmengine.fileio import dump
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.hooks.logger_hook import SUFFIX_TYPE
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.registry import VISBACKENDS
|
||||
from mmengine.utils import scandir
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
|
||||
|
@ -613,3 +616,203 @@ class TensorboardVisBackend(BaseVisBackend):
|
|||
"""close an opened tensorboard object."""
|
||||
if hasattr(self, '_tensorboard'):
|
||||
self._tensorboard.close()
|
||||
|
||||
|
||||
@VISBACKENDS.register_module()
|
||||
class MLflowVisBackend(BaseVisBackend):
|
||||
"""MLflow visualization backend class.
|
||||
|
||||
It can write images, config, scalars, etc. to a
|
||||
mlflow file.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine.visualization import MLflowVisBackend
|
||||
>>> from mmengine import Config
|
||||
>>> import numpy as np
|
||||
>>> vis_backend = MLflowVisBackend(save_dir='temp_dir')
|
||||
>>> img = np.random.randint(0, 256, size=(10, 10, 3))
|
||||
>>> vis_backend.add_image('img.png', img)
|
||||
>>> vis_backend.add_scalar('mAP', 0.6)
|
||||
>>> vis_backend.add_scalars({'loss': 0.1,'acc':0.8})
|
||||
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
>>> vis_backend.add_config(cfg)
|
||||
|
||||
Args:
|
||||
save_dir (str): The root directory to save the files
|
||||
produced by the backend.
|
||||
exp_name (str, optional): The experiment name. Default to None.
|
||||
run_name (str, optional): The run name. Default to None.
|
||||
tags (dict, optional): The tags to be added to the experiment.
|
||||
Default to None.
|
||||
params (dict, optional): The params to be added to the experiment.
|
||||
Default to None.
|
||||
tracking_uri (str, optional): The tracking uri. Default to None.
|
||||
artifact_suffix (Tuple[str] or str, optional): The artifact suffix.
|
||||
Default to ('.json', '.log', '.py', 'yaml').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
exp_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
tags: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py',
|
||||
'yaml')):
|
||||
super().__init__(save_dir)
|
||||
self._exp_name = exp_name
|
||||
self._run_name = run_name
|
||||
self._tags = tags
|
||||
self._params = params
|
||||
self._tracking_uri = tracking_uri
|
||||
self._artifact_suffix = artifact_suffix
|
||||
|
||||
def _init_env(self):
|
||||
"""Setup env for MLflow."""
|
||||
if not os.path.exists(self._save_dir):
|
||||
os.makedirs(self._save_dir, exist_ok=True) # type: ignore
|
||||
|
||||
try:
|
||||
import mlflow
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install mlflow" to install mlflow'
|
||||
) # type: ignore
|
||||
self._mlflow = mlflow
|
||||
|
||||
# when mlflow is imported, a default logger is created.
|
||||
# at this time, the default logger's stream is None
|
||||
# so the stream is reopened only when the stream is None
|
||||
# or the stream is closed
|
||||
logger = MMLogger.get_current_instance()
|
||||
for handler in logger.handlers:
|
||||
if handler.stream is None or handler.stream.closed:
|
||||
handler.stream = open(handler.baseFilename, 'a')
|
||||
|
||||
if self._tracking_uri is not None:
|
||||
logger.warning(
|
||||
'Please make sure that the mlflow server is running.')
|
||||
self._mlflow.set_tracking_uri(self._tracking_uri)
|
||||
else:
|
||||
if os.name == 'nt':
|
||||
file_url = f'file:\\{os.path.abspath(self._save_dir)}'
|
||||
else:
|
||||
file_url = f'file://{os.path.abspath(self._save_dir)}'
|
||||
self._mlflow.set_tracking_uri(file_url)
|
||||
|
||||
self._exp_name = self._exp_name or 'Default'
|
||||
|
||||
if self._mlflow.get_experiment_by_name(self._exp_name) is None:
|
||||
self._mlflow.create_experiment(self._exp_name)
|
||||
|
||||
self._mlflow.set_experiment(self._exp_name)
|
||||
|
||||
if self._run_name is not None:
|
||||
self._mlflow.set_tag('mlflow.runName', self._run_name)
|
||||
if self._tags is not None:
|
||||
self._mlflow.set_tags(self._tags)
|
||||
if self._params is not None:
|
||||
self._mlflow.log_params(self._params)
|
||||
|
||||
@property # type: ignore
|
||||
@force_init_env
|
||||
def experiment(self):
|
||||
"""Return MLflow object."""
|
||||
return self._mlflow
|
||||
|
||||
@force_init_env
|
||||
def add_config(self, config: Config, **kwargs) -> None:
|
||||
"""Record the config to mlflow.
|
||||
|
||||
Args:
|
||||
config (Config): The Config object
|
||||
"""
|
||||
self.cfg = config
|
||||
self._mlflow.log_params(self._flatten(self.cfg))
|
||||
self._mlflow.log_text(self.cfg.pretty_text, 'config.py')
|
||||
|
||||
@force_init_env
|
||||
def add_image(self,
|
||||
name: str,
|
||||
image: np.ndarray,
|
||||
step: int = 0,
|
||||
**kwargs) -> None:
|
||||
"""Record the image to mlflow.
|
||||
|
||||
Args:
|
||||
name (str): The image identifier.
|
||||
image (np.ndarray): The image to be saved. The format
|
||||
should be RGB.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
"""
|
||||
self._mlflow.log_image(image, name)
|
||||
|
||||
@force_init_env
|
||||
def add_scalar(self,
|
||||
name: str,
|
||||
value: Union[int, float, torch.Tensor, np.ndarray],
|
||||
step: int = 0,
|
||||
**kwargs) -> None:
|
||||
"""Record the scalar data to mlflow.
|
||||
|
||||
Args:
|
||||
name (str): The scalar identifier.
|
||||
value (int, float, torch.Tensor, np.ndarray): Value to save.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
"""
|
||||
self._mlflow.log_metric(name, value, step)
|
||||
|
||||
@force_init_env
|
||||
def add_scalars(self,
|
||||
scalar_dict: dict,
|
||||
step: int = 0,
|
||||
file_path: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
"""Record the scalar's data to mlflow.
|
||||
|
||||
Args:
|
||||
scalar_dict (dict): Key-value pair storing the tag and
|
||||
corresponding values.
|
||||
step (int): Global step value to record. Default to 0.
|
||||
file_path (str, optional): Useless parameter. Just for
|
||||
interface unification. Default to None.
|
||||
"""
|
||||
assert isinstance(scalar_dict, dict)
|
||||
assert 'step' not in scalar_dict, 'Please set it directly ' \
|
||||
'through the step parameter'
|
||||
self._mlflow.log_metrics(scalar_dict, step)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the mlflow."""
|
||||
file_paths = dict()
|
||||
for filename in scandir(self.cfg.work_dir, self._artifact_suffix,
|
||||
True):
|
||||
file_path = osp.join(self.cfg.work_dir, filename)
|
||||
relative_path = os.path.relpath(file_path, self.cfg.work_dir)
|
||||
dir_path = os.path.dirname(relative_path)
|
||||
file_paths[file_path] = dir_path
|
||||
|
||||
for file_path, dir_path in file_paths.items():
|
||||
self._mlflow.log_artifact(file_path, dir_path)
|
||||
|
||||
if hasattr(self, '_mlflow'):
|
||||
self._mlflow.end_run()
|
||||
|
||||
def _flatten(self, d, parent_key='', sep='.') -> dict:
|
||||
"""Flatten the dict."""
|
||||
items = dict()
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, MutableMapping):
|
||||
items.update(self._flatten(v, new_key, sep=sep))
|
||||
elif isinstance(v, list):
|
||||
if any(isinstance(x, dict) for x in v):
|
||||
for i, x in enumerate(v):
|
||||
items.update(
|
||||
self._flatten(x, new_key + sep + str(i), sep=sep))
|
||||
else:
|
||||
items[new_key] = v
|
||||
else:
|
||||
items[new_key] = v
|
||||
return items
|
||||
|
|
|
@ -2,5 +2,6 @@ coverage
|
|||
dadaptation
|
||||
lion-pytorch
|
||||
lmdb
|
||||
mlflow
|
||||
parameterized
|
||||
pytest
|
||||
|
|
|
@ -12,8 +12,8 @@ import torch
|
|||
from mmengine import Config
|
||||
from mmengine.fileio import load
|
||||
from mmengine.registry import VISBACKENDS
|
||||
from mmengine.visualization import (LocalVisBackend, TensorboardVisBackend,
|
||||
WandbVisBackend)
|
||||
from mmengine.visualization import (LocalVisBackend, MLflowVisBackend,
|
||||
TensorboardVisBackend, WandbVisBackend)
|
||||
|
||||
|
||||
class TestLocalVisBackend:
|
||||
|
@ -242,3 +242,46 @@ class TestWandbVisBackend:
|
|||
wandb_vis_backend._init_env()
|
||||
wandb_vis_backend.close()
|
||||
shutil.rmtree('temp_dir')
|
||||
|
||||
|
||||
class TestMLflowVisBackend:
|
||||
|
||||
def test_init(self):
|
||||
MLflowVisBackend('temp_dir')
|
||||
VISBACKENDS.build(dict(type='MLflowVisBackend', save_dir='temp_dir'))
|
||||
|
||||
def test_experiment(self):
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow
|
||||
|
||||
def test_add_config(self):
|
||||
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
mlflow_vis_backend.add_config(cfg)
|
||||
|
||||
def test_add_image(self):
|
||||
image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
mlflow_vis_backend.add_image('img.png', image)
|
||||
|
||||
def test_add_scalar(self):
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
mlflow_vis_backend.add_scalar('map', 0.9)
|
||||
# test append mode
|
||||
mlflow_vis_backend.add_scalar('map', 0.9)
|
||||
mlflow_vis_backend.add_scalar('map', 0.95)
|
||||
|
||||
def test_add_scalars(self):
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
input_dict = {'map': 0.7, 'acc': 0.9}
|
||||
mlflow_vis_backend.add_scalars(input_dict)
|
||||
# test append mode
|
||||
mlflow_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8})
|
||||
|
||||
def test_close(self):
|
||||
cfg = Config(dict(work_dir='temp_dir'))
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
mlflow_vis_backend._init_env()
|
||||
mlflow_vis_backend.add_config(cfg)
|
||||
mlflow_vis_backend.close()
|
||||
shutil.rmtree('temp_dir')
|
||||
|
|
Loading…
Reference in New Issue