mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix]: fix add graph function is not called bug in visualization hooks (#632)
* fix add graph func is not called bug * move add graph call to NaiveVisualizationHook.before_train * Update mmengine/hooks/naive_visualization_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * adjust param sequence and add docstring * minor refine * Update mmengine/visualization/vis_backend.py * update version info Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
parent
9d5b417f07
commit
b7aa4dd885
@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook):
|
||||
unpad_image = input[:unpad_height, :unpad_width]
|
||||
return unpad_image
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""Call add_graph method of visualizer.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
runner.visualizer.add_graph(runner.model, None)
|
||||
|
||||
def after_test_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
|
@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend):
|
||||
`wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_
|
||||
for details. Defaults to None.
|
||||
New in version 0.3.0.
|
||||
watch_kwargs (optional, dict): Agurments for ``wandb.watch``.
|
||||
New in version 0.4.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend):
|
||||
init_kwargs: Optional[dict] = None,
|
||||
define_metric_cfg: Optional[dict] = None,
|
||||
commit: Optional[bool] = True,
|
||||
log_code_name: Optional[str] = None):
|
||||
log_code_name: Optional[str] = None,
|
||||
watch_kwargs: Optional[dict] = None):
|
||||
super().__init__(save_dir)
|
||||
self._init_kwargs = init_kwargs
|
||||
self._define_metric_cfg = define_metric_cfg
|
||||
self._commit = commit
|
||||
self._log_code_name = log_code_name
|
||||
self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {}
|
||||
|
||||
def _init_env(self):
|
||||
"""Setup env for wandb."""
|
||||
@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend):
|
||||
self._wandb.config.update(dict(config))
|
||||
self._wandb.run.log_code(name=self._log_code_name)
|
||||
|
||||
@force_init_env
|
||||
def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
|
||||
**kwargs) -> None:
|
||||
"""Record the model graph.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to draw.
|
||||
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||
"""
|
||||
self._wandb.watch(model, **self._watch_kwargs)
|
||||
|
||||
@force_init_env
|
||||
def add_image(self,
|
||||
name: str,
|
||||
|
Loading…
x
Reference in New Issue
Block a user