[Fix] Fix add_graph in pavi (#948)

* [Fix] Fix add_graph in pavi

* change data loader to image

* Delete =2.4.0

* pavi-add_graph-0419

* pavi-add_graph-0419

* [Fix] pavi device

* fix device in pavi-add graph

* img_key

* img_key

* add no_grad

* Delete version.py

* add version.py
pull/1002/head
gengenkai 2021-04-27 19:31:23 +08:00 committed by GitHub
parent c142eced17
commit f46e5f8e5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 4 deletions

View File

@ -6,6 +6,7 @@ import os.path as osp
import yaml
import mmcv
from ....parallel.utils import is_module_wrapper
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@ -21,12 +22,14 @@ class PaviLoggerHook(LoggerHook):
interval=10,
ignore_last=True,
reset_flag=True,
by_epoch=True):
by_epoch=True,
img_key='img_info'):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt
self.img_key=img_key
@master_only
def before_run(self, runner):
@ -66,9 +69,6 @@ class PaviLoggerHook(LoggerHook):
self.init_kwargs['session_text'] = session_text
self.writer = SummaryWriter(**self.init_kwargs)
if self.add_graph:
self.writer.add_graph(runner.model)
def get_step(self, runner):
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
@ -95,3 +95,16 @@ class PaviLoggerHook(LoggerHook):
tag=self.run_name,
snapshot_file_path=ckpt_path,
iteration=iteration)
@master_only
def before_epoch(self, runner):
if runner.epoch == 0 and self.add_graph:
if is_module_wrapper(runner.model):
_model = runner.model.module
else:
_model = runner.model
device = next(_model.parameters()).device
data = next(iter(runner.data_loader))
image = data[self.img_key][0:1].to(device)
with torch.no_grad():
self.writer.add_graph(_model, image)