mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix]: fix add graph function is not called bug in visualization hooks (#632)
* fix add graph func is not called bug * move add graph call to NaiveVisualizationHook.before_train * Update mmengine/hooks/naive_visualization_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * adjust param sequence and add docstring * minor refine * Update mmengine/visualization/vis_backend.py * update version info Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
parent
9d5b417f07
commit
b7aa4dd885
@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook):
|
|||||||
unpad_image = input[:unpad_height, :unpad_width]
|
unpad_image = input[:unpad_height, :unpad_width]
|
||||||
return unpad_image
|
return unpad_image
|
||||||
|
|
||||||
|
def before_train(self, runner) -> None:
|
||||||
|
"""Call add_graph method of visualizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
"""
|
||||||
|
runner.visualizer.add_graph(runner.model, None)
|
||||||
|
|
||||||
def after_test_iter(self,
|
def after_test_iter(self,
|
||||||
runner,
|
runner,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
|
@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend):
|
|||||||
`wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_
|
`wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_
|
||||||
for details. Defaults to None.
|
for details. Defaults to None.
|
||||||
New in version 0.3.0.
|
New in version 0.3.0.
|
||||||
|
watch_kwargs (optional, dict): Agurments for ``wandb.watch``.
|
||||||
|
New in version 0.4.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend):
|
|||||||
init_kwargs: Optional[dict] = None,
|
init_kwargs: Optional[dict] = None,
|
||||||
define_metric_cfg: Optional[dict] = None,
|
define_metric_cfg: Optional[dict] = None,
|
||||||
commit: Optional[bool] = True,
|
commit: Optional[bool] = True,
|
||||||
log_code_name: Optional[str] = None):
|
log_code_name: Optional[str] = None,
|
||||||
|
watch_kwargs: Optional[dict] = None):
|
||||||
super().__init__(save_dir)
|
super().__init__(save_dir)
|
||||||
self._init_kwargs = init_kwargs
|
self._init_kwargs = init_kwargs
|
||||||
self._define_metric_cfg = define_metric_cfg
|
self._define_metric_cfg = define_metric_cfg
|
||||||
self._commit = commit
|
self._commit = commit
|
||||||
self._log_code_name = log_code_name
|
self._log_code_name = log_code_name
|
||||||
|
self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {}
|
||||||
|
|
||||||
def _init_env(self):
|
def _init_env(self):
|
||||||
"""Setup env for wandb."""
|
"""Setup env for wandb."""
|
||||||
@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend):
|
|||||||
self._wandb.config.update(dict(config))
|
self._wandb.config.update(dict(config))
|
||||||
self._wandb.run.log_code(name=self._log_code_name)
|
self._wandb.run.log_code(name=self._log_code_name)
|
||||||
|
|
||||||
|
@force_init_env
|
||||||
|
def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""Record the model graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): Model to draw.
|
||||||
|
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||||
|
"""
|
||||||
|
self._wandb.watch(model, **self._watch_kwargs)
|
||||||
|
|
||||||
@force_init_env
|
@force_init_env
|
||||||
def add_image(self,
|
def add_image(self,
|
||||||
name: str,
|
name: str,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user