[Feature] Support setting start epoch and interval epoch for PAVI Logger Hook (#2103)

* [Feature] Support setting start epoch and interval epoch for PAVI Logger Hook

* [Feature] Update the coding style as the maintainer wishes

* fix: default integer division or modulo by zero

* fix: runner.epoch is less than start and use self.get_epoch instead of runner.epoch

* feat: support for iter-based runner and fix the step bug

* feat: iter based hook

* feat: fix bug and coding style

* fix: coding style

* fix: coding style

* fix: graph may be added during evaluation
This commit is contained in:
Peter Ye 2022-07-29 11:40:20 +08:00 committed by GitHub
parent eb4bbbbd64
commit 78f01001d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 214 additions and 34 deletions

View File

@ -35,7 +35,8 @@ class PaviLoggerHook(LoggerHook):
- overwrite_last_training (bool, optional): Whether to upload data
to the training with the same name in the same project, rather
than creating a new one. Defaults to False.
add_graph (bool): Whether to visual model. Default: False.
add_graph (bool): **Deprecated**. Whether to visualize the model.
Use ``add_graph_kwargs`` instead. Default: False.
add_last_ckpt (bool): Whether to save checkpoint after run.
Default: False.
interval (int): Logging interval (every k iterations). Default: True.
@ -45,6 +46,12 @@ class PaviLoggerHook(LoggerHook):
Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
img_key (string): Get image data from Dataset. Default: 'img_info'.
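add_graph_kwargs (dict, optional): A dict containing the parameters
for adding the graph. Supported keys: ``active`` (whether to add the
graph), ``start`` (the first epoch or iteration at which it is added)
and ``interval`` (the number of epochs or iterations between
additions).
Default: {'active': False, 'start': 0, 'interval': 1}.
add_ckpt_kwargs (dict, optional): A dict containing the parameters
for adding checkpoints. Supported keys: ``active`` (whether to add
checkpoints), ``start`` (the first epoch or iteration at which one is
added) and ``interval`` (the number of epochs or iterations between
additions).
Default: {'active': False, 'start': 0, 'interval': 1}.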
"""
def __init__(self,
@ -55,11 +62,21 @@ class PaviLoggerHook(LoggerHook):
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True,
img_key: str = 'img_info'):
img_key: str = 'img_info',
add_graph_kwargs: Optional[Dict] = None,
add_ckpt_kwargs: Optional[Dict] = None) -> None:
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
add_graph_kwargs = {} if add_graph_kwargs is None else add_graph_kwargs
self.add_graph = add_graph_kwargs.get('active', False)
self.add_graph_start = add_graph_kwargs.get('start', 0)
self.add_graph_interval = add_graph_kwargs.get('interval', 1)
add_ckpt_kwargs = {} if add_ckpt_kwargs is None else add_ckpt_kwargs
self.add_ckpt = add_ckpt_kwargs.get('active', False)
self.add_last_ckpt = add_last_ckpt
self.add_ckpt_start = add_ckpt_kwargs.get('start', 0)
self.add_ckpt_interval = add_ckpt_kwargs.get('interval', 1)
self.img_key = img_key
@master_only
@ -110,6 +127,28 @@ class PaviLoggerHook(LoggerHook):
else:
return self.get_iter(runner)
def _add_ckpt(self, runner, ckpt_path: str, step: int) -> None:
    """Upload a checkpoint file to Pavi as a snapshot.

    Args:
        runner: The runner; its ``work_dir`` is used to resolve the
            target of a symlinked checkpoint.
        ckpt_path (str): Path of the checkpoint file, or of a symlink
            pointing to it.
        step (int): Value reported as the snapshot's ``iteration``.
    """
    # A symlink is replaced by its target, which is joined onto the
    # work dir before the existence check.
    target = ckpt_path
    if osp.islink(target):
        target = osp.join(runner.work_dir, os.readlink(target))

    # Nothing is uploaded when the checkpoint is not a regular file.
    if not osp.isfile(target):
        return

    self.writer.add_snapshot_file(
        tag=self.run_name, snapshot_file_path=target, iteration=step)
def _add_graph(self, runner) -> None:
    """Send the model graph to the Pavi writer.

    The graph is built from the runner's model and one sample taken from
    the first batch of ``runner.data_loader``.

    Args:
        runner: The runner providing ``model`` and ``data_loader``.
    """
    # Hand the underlying module to the writer, not the wrapper.
    wrapped = is_module_wrapper(runner.model)
    net = runner.model.module if wrapped else runner.model

    # The sample is moved to the device of the model's first parameter.
    target_device = next(net.parameters()).device
    first_batch = next(iter(runner.data_loader))
    # ``[0:1]`` keeps a single leading entry (presumably one image of the
    # batch -- confirm against the dataset's ``img_key`` layout).
    sample = first_batch[self.img_key][0:1].to(target_device)

    with torch.no_grad():
        self.writer.add_graph(net, sample)
@master_only
def log(self, runner) -> None:
tags = self.get_loggable_tags(runner, add_mode=False)
@ -119,31 +158,73 @@ class PaviLoggerHook(LoggerHook):
@master_only
def after_run(self, runner) -> None:
    """Optionally upload the latest checkpoint, then close the writer.

    When ``add_last_ckpt`` is set, ``latest.pth`` in the runner's work
    dir is uploaded through ``_add_ckpt``. The writer is closed on every
    path, so the upload can never skip the close.

    Args:
        runner: The runner that has just finished.
    """
    if self.add_last_ckpt:
        # using runner.epoch/iter is ok since the step has been + 1
        step = runner.epoch if self.by_epoch else runner.iter
        ckpt_path = osp.join(runner.work_dir, 'latest.pth')
        self._add_ckpt(runner, ckpt_path, step)

    # flush the buffer and send a task ending signal to Pavi
    self.writer.close()
@master_only
def before_train_epoch(self, runner) -> None:
    """Add the model graph at the scheduled training epochs.

    Does nothing beyond the parent hook when ``by_epoch`` is False. The
    graph is added when ``add_graph`` is enabled, the current epoch has
    reached ``add_graph_start``, and the distance from
    ``add_graph_start`` is a multiple of ``add_graph_interval``.

    This is the only place the graph is added in epoch-based runs, so it
    is tied to training epochs and always honours the configured start
    and interval.

    Args:
        runner: The runner about to start a training epoch.
    """
    super().before_train_epoch(runner)
    if not self.by_epoch:
        return None

    step = self.get_epoch(runner)
    due = (self.add_graph and step >= self.add_graph_start
           and (step - self.add_graph_start) % self.add_graph_interval == 0)
    if due:
        self._add_graph(runner)
@master_only
def before_train_iter(self, runner) -> None:
    """Add the model graph at the scheduled training iterations.

    Only acts in iteration-based mode (``by_epoch`` is False). The graph
    is added when ``add_graph`` is enabled, the current iteration has
    reached ``add_graph_start``, and the distance from
    ``add_graph_start`` is a multiple of ``add_graph_interval``.

    Args:
        runner: The runner about to start a training iteration.
    """
    super().before_train_iter(runner)
    if self.by_epoch:
        return None

    cur_iter = self.get_iter(runner)
    if not self.add_graph:
        return None
    if cur_iter < self.add_graph_start:
        return None
    if (cur_iter - self.add_graph_start) % self.add_graph_interval != 0:
        return None
    self._add_graph(runner)
@master_only
def after_train_epoch(self, runner) -> None:
    """Upload the epoch checkpoint at the scheduled training epochs.

    Only acts in epoch-based mode (``by_epoch`` is True). The checkpoint
    ``epoch_{step}.pth`` in the runner's work dir is passed to
    ``_add_ckpt`` when ``add_ckpt`` is enabled, the epoch has reached
    ``add_ckpt_start``, and the distance from ``add_ckpt_start`` is a
    multiple of ``add_ckpt_interval``.

    Args:
        runner: The runner that has just finished a training epoch.
    """
    super().after_train_epoch(runner)
    if not self.by_epoch:
        return None

    # The step comes from ``get_epoch`` rather than ``runner.epoch``,
    # because ``runner.epoch`` starts from 0.
    cur_epoch = self.get_epoch(runner)
    if not self.add_ckpt:
        return None
    if cur_epoch < self.add_ckpt_start:
        return None
    if (cur_epoch - self.add_ckpt_start) % self.add_ckpt_interval != 0:
        return None

    ckpt_file = osp.join(runner.work_dir, f'epoch_{cur_epoch}.pth')
    self._add_ckpt(runner, ckpt_file, cur_epoch)
@master_only
def after_train_iter(self, runner) -> None:
    """Upload the iteration checkpoint at the scheduled iterations.

    Only acts in iteration-based mode (``by_epoch`` is False). The
    checkpoint ``iter_{step}.pth`` in the runner's work dir is passed to
    ``_add_ckpt`` when ``add_ckpt`` is enabled, the iteration has reached
    ``add_ckpt_start``, and the distance from ``add_ckpt_start`` is a
    multiple of ``add_ckpt_interval``.

    Args:
        runner: The runner that has just finished a training iteration.
    """
    super().after_train_iter(runner)
    if self.by_epoch:
        return None

    cur_iter = self.get_iter(runner)
    if not self.add_ckpt:
        return None
    if cur_iter < self.add_ckpt_start:
        return None
    if (cur_iter - self.add_ckpt_start) % self.add_ckpt_interval != 0:
        return None

    ckpt_file = osp.join(runner.work_dir, f'iter_{cur_iter}.pth')
    self._add_ckpt(runner, ckpt_file, cur_iter)

View File

@ -152,7 +152,7 @@ def test_checkpoint_hook(tmp_path):
runner.run([loader], [('train', 1)])
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
assert runner.meta['hook_msgs']['last_ckpt'] == \
'/'.join([out_dir, basename, 'epoch_4.pth'])
'/'.join([out_dir, basename, 'epoch_4.pth'])
mock_put.assert_called()
mock_remove.assert_called()
mock_isfile.assert_called()
@ -183,7 +183,7 @@ def test_checkpoint_hook(tmp_path):
runner.run([loader], [('train', 1)])
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
assert runner.meta['hook_msgs']['last_ckpt'] == \
'/'.join([out_dir, basename, 'iter_4.pth'])
'/'.join([out_dir, basename, 'iter_4.pth'])
mock_put.assert_called()
mock_remove.assert_called()
mock_isfile.assert_called()
@ -332,7 +332,8 @@ def test_pavi_hook():
loader = DataLoader(torch.ones((5, 2)))
runner = _build_demo_runner()
runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
hook = PaviLoggerHook(
add_graph_kwargs=None, add_last_ckpt=False, add_ckpt_kwargs=None)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)
@ -342,15 +343,113 @@ def test_pavi_hook():
'learning_rate': 0.02,
'momentum': 0.95
}, 1)
def test_pavi_hook_epoch_based():
    """Test the start epoch and epoch interval of checkpoint uploads."""
    sys.modules['pavi'] = MagicMock()

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner(max_epochs=6)
    runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
    hook = PaviLoggerHook(
        add_graph_kwargs={
            'active': False,
            'start': 0,
            'interval': 1
        },
        add_last_ckpt=True,
        add_ckpt_kwargs={
            'active': True,
            'start': 1,
            'interval': 2
        })
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')

    # The final upload is the one for ``latest.pth``. On Windows it is
    # expected under that name (presumably a plain copy there -- confirm);
    # elsewhere it is expected as the last epoch checkpoint, because the
    # hook uploads a symlink's target rather than the link.
    if platform.system() == 'Windows':
        final_file_path = osp.join(runner.work_dir, 'latest.pth')
    else:
        final_file_path = osp.join(runner.work_dir, 'epoch_6.pth')

    tag = runner.work_dir.split('/')[-1]
    # start=1, interval=2 over 6 epochs -> uploads after epochs 1, 3, 5.
    calls = [
        call(
            tag=tag,
            snapshot_file_path=osp.join(runner.work_dir,
                                        f'epoch_{epoch}.pth'),
            iteration=epoch) for epoch in (1, 3, 5)
    ]
    # ``final_file_path`` already includes the work dir.
    calls.append(
        call(tag=tag, snapshot_file_path=final_file_path, iteration=6))
    hook.writer.add_snapshot_file.assert_has_calls(calls, any_order=False)
def test_pavi_hook_iter_based():
    """Test the start iteration and iteration interval of uploads."""
    sys.modules['pavi'] = MagicMock()

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner(
        'IterBasedRunner', max_iters=15, max_epochs=None)
    runner.meta = dict()
    hook = PaviLoggerHook(
        by_epoch=False,
        add_graph_kwargs={
            'active': False,
            'start': 0,
            'interval': 1
        },
        add_last_ckpt=True,
        add_ckpt_kwargs={
            'active': True,
            'start': 0,
            'interval': 4
        })
    # Checkpoints are written every 4 iterations, matching the hook's
    # upload interval.
    runner.register_hook(CheckpointHook(interval=4, by_epoch=False))
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')

    # The final upload is the one for ``latest.pth``. On Windows it is
    # expected under that name (presumably a plain copy there -- confirm);
    # elsewhere it is expected as the last iteration checkpoint, because
    # the hook uploads a symlink's target rather than the link.
    if platform.system() == 'Windows':
        final_file_path = osp.join(runner.work_dir, 'latest.pth')
    else:
        final_file_path = osp.join(runner.work_dir, 'iter_15.pth')

    tag = runner.work_dir.split('/')[-1]
    # start=0, interval=4 over 15 iterations -> uploads at 4, 8, 12.
    calls = [
        call(
            tag=tag,
            snapshot_file_path=osp.join(runner.work_dir,
                                        f'iter_{step}.pth'),
            iteration=step) for step in (4, 8, 12)
    ]
    # ``final_file_path`` already includes the work dir.
    calls.append(
        call(tag=tag, snapshot_file_path=final_file_path, iteration=15))
    hook.writer.add_snapshot_file.assert_has_calls(calls, any_order=False)
def test_sync_buffers_hook():