[Fix]: fix add graph function is not called bug in visualization hooks (#632)

* fix add graph func is not called bug

* move add graph call to NaiveVisualizationHook.before_train

* Update mmengine/hooks/naive_visualization_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* adjust param sequence and add docstring

* minor refine

* Update mmengine/visualization/vis_backend.py

* update version info

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
shenmishajing 2022-11-21 11:52:48 +08:00 committed by GitHub
parent 9d5b417f07
commit b7aa4dd885
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 1 deletions

View File

@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook):
unpad_image = input[:unpad_height, :unpad_width]
return unpad_image
def before_train(self, runner) -> None:
"""Call add_graph method of visualizer.
Args:
runner (Runner): The runner of the training process.
"""
runner.visualizer.add_graph(runner.model, None)
def after_test_iter(self,
runner,
batch_idx: int,

View File

@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend):
`wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_
for details. Defaults to None.
New in version 0.3.0.
watch_kwargs (optional, dict): Agurments for ``wandb.watch``.
New in version 0.4.0.
"""
def __init__(self,
@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend):
init_kwargs: Optional[dict] = None,
define_metric_cfg: Optional[dict] = None,
commit: Optional[bool] = True,
log_code_name: Optional[str] = None):
log_code_name: Optional[str] = None,
watch_kwargs: Optional[dict] = None):
super().__init__(save_dir)
self._init_kwargs = init_kwargs
self._define_metric_cfg = define_metric_cfg
self._commit = commit
self._log_code_name = log_code_name
self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {}
def _init_env(self):
"""Setup env for wandb."""
@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend):
self._wandb.config.update(dict(config))
self._wandb.run.log_code(name=self._log_code_name)
@force_init_env
def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
**kwargs) -> None:
"""Record the model graph.
Args:
model (torch.nn.Module): Model to draw.
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self._wandb.watch(model, **self._watch_kwargs)
@force_init_env
def add_image(self,
name: str,