From f46e5f8e5e4fb9c6ca00f824ecd49f57d5fbe238 Mon Sep 17 00:00:00 2001 From: gengenkai <30782254+gengenkai@users.noreply.github.com> Date: Tue, 27 Apr 2021 19:31:23 +0800 Subject: [PATCH] [Fix] Fix add_graph in pavi (#948) * [Fix] Fix add_graph in pavi * change data loader to image * Delete =2.4.0 * pavi-add_graph-0419 * pavi-add_graph-0419 * [Fix] pavi device * fix device in pavi-add graph * img_key * img_key * add no_grad * Delete version.py * add version.py --- mmcv/runner/hooks/logger/pavi.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/mmcv/runner/hooks/logger/pavi.py b/mmcv/runner/hooks/logger/pavi.py index e7c470e13..18c22243d 100644 --- a/mmcv/runner/hooks/logger/pavi.py +++ b/mmcv/runner/hooks/logger/pavi.py @@ -6,6 +6,7 @@ import os.path as osp import yaml import mmcv +from ....parallel.utils import is_module_wrapper from ...dist_utils import master_only from ..hook import HOOKS from .base import LoggerHook @@ -21,12 +22,14 @@ class PaviLoggerHook(LoggerHook): interval=10, ignore_last=True, reset_flag=True, - by_epoch=True): + by_epoch=True, + img_key='img_info'): super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag, by_epoch) self.init_kwargs = init_kwargs self.add_graph = add_graph self.add_last_ckpt = add_last_ckpt + self.img_key=img_key @master_only def before_run(self, runner): @@ -66,9 +69,6 @@ class PaviLoggerHook(LoggerHook): self.init_kwargs['session_text'] = session_text self.writer = SummaryWriter(**self.init_kwargs) - if self.add_graph: - self.writer.add_graph(runner.model) - def get_step(self, runner): """Get the total training step/epoch.""" if self.get_mode(runner) == 'val' and self.by_epoch: @@ -95,3 +95,16 @@ class PaviLoggerHook(LoggerHook): tag=self.run_name, snapshot_file_path=ckpt_path, iteration=iteration) + + @master_only + def before_epoch(self, runner): + if runner.epoch == 0 and self.add_graph: + if is_module_wrapper(runner.model): + _model = runner.model.module + else: + _model = runner.model + device = next(_model.parameters()).device + data = next(iter(runner.data_loader)) + image = data[self.img_key][0:1].to(device) + with torch.no_grad(): + self.writer.add_graph(_model, image)