mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix add_graph in pavi (#948)
* [Fix] Fix add_graph in pavi * change data loader to image * Delete =2.4.0 * pavi-add_graph-0419 * pavi-add_graph-0419 * [Fix] pavi device * fix device in pavi-add graph * img_key * img_key * add no_grad * Delete version.py * add version.pypull/1002/head
parent
c142eced17
commit
f46e5f8e5e
|
@ -6,6 +6,7 @@ import os.path as osp
|
|||
import yaml
|
||||
|
||||
import mmcv
|
||||
from ....parallel.utils import is_module_wrapper
|
||||
from ...dist_utils import master_only
|
||||
from ..hook import HOOKS
|
||||
from .base import LoggerHook
|
||||
|
@ -21,12 +22,14 @@ class PaviLoggerHook(LoggerHook):
|
|||
interval=10,
|
||||
ignore_last=True,
|
||||
reset_flag=True,
|
||||
by_epoch=True):
|
||||
by_epoch=True,
|
||||
img_key='img_info'):
|
||||
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
|
||||
by_epoch)
|
||||
self.init_kwargs = init_kwargs
|
||||
self.add_graph = add_graph
|
||||
self.add_last_ckpt = add_last_ckpt
|
||||
self.img_key=img_key
|
||||
|
||||
@master_only
|
||||
def before_run(self, runner):
|
||||
|
@ -66,9 +69,6 @@ class PaviLoggerHook(LoggerHook):
|
|||
self.init_kwargs['session_text'] = session_text
|
||||
self.writer = SummaryWriter(**self.init_kwargs)
|
||||
|
||||
if self.add_graph:
|
||||
self.writer.add_graph(runner.model)
|
||||
|
||||
def get_step(self, runner):
|
||||
"""Get the total training step/epoch."""
|
||||
if self.get_mode(runner) == 'val' and self.by_epoch:
|
||||
|
@ -95,3 +95,16 @@ class PaviLoggerHook(LoggerHook):
|
|||
tag=self.run_name,
|
||||
snapshot_file_path=ckpt_path,
|
||||
iteration=iteration)
|
||||
|
||||
@master_only
|
||||
def before_epoch(self, runner):
|
||||
if runner.epoch == 0 and self.add_graph:
|
||||
if is_module_wrapper(runner.model):
|
||||
_model = runner.model.module
|
||||
else:
|
||||
_model = runner.model
|
||||
device = next(_model.parameters()).device
|
||||
data = next(iter(runner.data_loader))
|
||||
image = data[self.img_key][0:1].to(device)
|
||||
with torch.no_grad():
|
||||
self.writer.add_graph(_model, image)
|
||||
|
|
Loading…
Reference in New Issue