mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Support setting start epoch and interval epoch for PAVI Logger Hook (#2103)
* [Feature] Support setting start epoch and interval epoch for PAVI Logger Hook * [Feature] Update the coding style as the maintainer wish * fix: default integer division or modulo by zero * fix: runner.epoch is less than start and use self.get_epoch instead of runner.epoch * feat: support for iter-based runner and fix the step bug * feat: iter based hook * feat: fix bug and coding style * fix: coding style * fix: coding style * fix: graph may add in evaluation
This commit is contained in:
parent
eb4bbbbd64
commit
78f01001d5
@ -35,7 +35,8 @@ class PaviLoggerHook(LoggerHook):
|
||||
- overwrite_last_training (bool, optional): Whether to upload data
|
||||
to the training with the same name in the same project, rather
|
||||
than creating a new one. Defaults to False.
|
||||
add_graph (bool): Whether to visual model. Default: False.
|
||||
add_graph (bool): **Deprecated**. Whether to visual model.
|
||||
Default: False.
|
||||
add_last_ckpt (bool): Whether to save checkpoint after run.
|
||||
Default: False.
|
||||
interval (int): Logging interval (every k iterations). Default: True.
|
||||
@ -45,6 +46,12 @@ class PaviLoggerHook(LoggerHook):
|
||||
Default: False.
|
||||
by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
|
||||
img_key (string): Get image data from Dataset. Default: 'img_info'.
|
||||
add_graph_kwargs (dict, optional): A dict contains the params for
|
||||
adding graph, the keys are as below:
|
||||
Default: {'active': False, 'start': 0, 'interval': 1}.
|
||||
add_ckpt_kwargs (dict, optional): A dict contains the params for
|
||||
adding checkpoint, the keys are as below:
|
||||
Default: {'active': False, 'start': 0, 'interval': 1}.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -55,11 +62,21 @@ class PaviLoggerHook(LoggerHook):
|
||||
ignore_last: bool = True,
|
||||
reset_flag: bool = False,
|
||||
by_epoch: bool = True,
|
||||
img_key: str = 'img_info'):
|
||||
img_key: str = 'img_info',
|
||||
add_graph_kwargs: Optional[Dict] = None,
|
||||
add_ckpt_kwargs: Optional[Dict] = None) -> None:
|
||||
super().__init__(interval, ignore_last, reset_flag, by_epoch)
|
||||
self.init_kwargs = init_kwargs
|
||||
self.add_graph = add_graph
|
||||
add_graph_kwargs = {} if add_graph_kwargs is None else add_graph_kwargs
|
||||
self.add_graph = add_graph_kwargs.get('active', False)
|
||||
self.add_graph_start = add_graph_kwargs.get('start', 0)
|
||||
self.add_graph_interval = add_graph_kwargs.get('interval', 1)
|
||||
|
||||
add_ckpt_kwargs = {} if add_ckpt_kwargs is None else add_ckpt_kwargs
|
||||
self.add_ckpt = add_ckpt_kwargs.get('active', False)
|
||||
self.add_last_ckpt = add_last_ckpt
|
||||
self.add_ckpt_start = add_ckpt_kwargs.get('start', 0)
|
||||
self.add_ckpt_interval = add_ckpt_kwargs.get('interval', 1)
|
||||
self.img_key = img_key
|
||||
|
||||
@master_only
|
||||
@ -110,6 +127,28 @@ class PaviLoggerHook(LoggerHook):
|
||||
else:
|
||||
return self.get_iter(runner)
|
||||
|
||||
def _add_ckpt(self, runner, ckpt_path: str, step: int) -> None:
|
||||
|
||||
if osp.islink(ckpt_path):
|
||||
ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
|
||||
|
||||
if osp.isfile(ckpt_path):
|
||||
self.writer.add_snapshot_file(
|
||||
tag=self.run_name,
|
||||
snapshot_file_path=ckpt_path,
|
||||
iteration=step)
|
||||
|
||||
def _add_graph(self, runner) -> None:
|
||||
if is_module_wrapper(runner.model):
|
||||
_model = runner.model.module
|
||||
else:
|
||||
_model = runner.model
|
||||
device = next(_model.parameters()).device
|
||||
data = next(iter(runner.data_loader))
|
||||
image = data[self.img_key][0:1].to(device)
|
||||
with torch.no_grad():
|
||||
self.writer.add_graph(_model, image)
|
||||
|
||||
@master_only
|
||||
def log(self, runner) -> None:
|
||||
tags = self.get_loggable_tags(runner, add_mode=False)
|
||||
@ -119,31 +158,73 @@ class PaviLoggerHook(LoggerHook):
|
||||
|
||||
@master_only
|
||||
def after_run(self, runner) -> None:
|
||||
if self.add_last_ckpt:
|
||||
ckpt_path = osp.join(runner.work_dir, 'latest.pth')
|
||||
if osp.islink(ckpt_path):
|
||||
ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
|
||||
|
||||
if osp.isfile(ckpt_path):
|
||||
# runner.epoch += 1 has been done before `after_run`.
|
||||
iteration = runner.epoch if self.by_epoch else runner.iter
|
||||
return self.writer.add_snapshot_file(
|
||||
tag=self.run_name,
|
||||
snapshot_file_path=ckpt_path,
|
||||
iteration=iteration)
|
||||
if self.add_last_ckpt:
|
||||
# using runner.epoch/iter is ok since the step has been + 1
|
||||
step = runner.epoch if self.by_epoch else runner.iter
|
||||
|
||||
ckpt_path = osp.join(runner.work_dir, 'latest.pth')
|
||||
self._add_ckpt(runner, ckpt_path, step)
|
||||
|
||||
# flush the buffer and send a task ending signal to Pavi
|
||||
self.writer.close()
|
||||
|
||||
@master_only
|
||||
def before_epoch(self, runner) -> None:
|
||||
if runner.epoch == 0 and self.add_graph:
|
||||
if is_module_wrapper(runner.model):
|
||||
_model = runner.model.module
|
||||
else:
|
||||
_model = runner.model
|
||||
device = next(_model.parameters()).device
|
||||
data = next(iter(runner.data_loader))
|
||||
image = data[self.img_key][0:1].to(device)
|
||||
with torch.no_grad():
|
||||
self.writer.add_graph(_model, image)
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
super().before_train_epoch(runner)
|
||||
|
||||
if not self.by_epoch:
|
||||
return None
|
||||
|
||||
step = self.get_epoch(runner)
|
||||
if (self.add_graph and step >= self.add_graph_start
|
||||
and ((step - self.add_graph_start) % self.add_graph_interval
|
||||
== 0)): # noqa: E129
|
||||
self._add_graph(runner)
|
||||
|
||||
@master_only
|
||||
def before_train_iter(self, runner) -> None:
|
||||
super().before_train_iter(runner)
|
||||
|
||||
if self.by_epoch:
|
||||
return None
|
||||
|
||||
step = self.get_iter(runner)
|
||||
if (self.add_graph and step >= self.add_graph_start
|
||||
and ((step - self.add_graph_start) % self.add_graph_interval
|
||||
== 0)): # noqa: E129
|
||||
self._add_graph(runner)
|
||||
|
||||
@master_only
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
super().after_train_epoch(runner)
|
||||
# Do not use runner.epoch since it starts from 0.
|
||||
if not self.by_epoch:
|
||||
return None
|
||||
|
||||
step = self.get_epoch(runner)
|
||||
|
||||
if (self.add_ckpt and step >= self.add_ckpt_start
|
||||
and ((step - self.add_ckpt_start) % self.add_ckpt_interval
|
||||
== 0)): # noqa: E129
|
||||
|
||||
ckpt_path = osp.join(runner.work_dir, f'epoch_{step}.pth')
|
||||
|
||||
self._add_ckpt(runner, ckpt_path, step)
|
||||
|
||||
@master_only
|
||||
def after_train_iter(self, runner) -> None:
|
||||
super().after_train_iter(runner)
|
||||
|
||||
if self.by_epoch:
|
||||
return None
|
||||
|
||||
step = self.get_iter(runner)
|
||||
|
||||
if (self.add_ckpt and step >= self.add_ckpt_start
|
||||
and ((step - self.add_ckpt_start) % self.add_ckpt_interval
|
||||
== 0)): # noqa: E129
|
||||
|
||||
ckpt_path = osp.join(runner.work_dir, f'iter_{step}.pth')
|
||||
|
||||
self._add_ckpt(runner, ckpt_path, step)
|
||||
|
@ -152,7 +152,7 @@ def test_checkpoint_hook(tmp_path):
|
||||
runner.run([loader], [('train', 1)])
|
||||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||
assert runner.meta['hook_msgs']['last_ckpt'] == \
|
||||
'/'.join([out_dir, basename, 'epoch_4.pth'])
|
||||
'/'.join([out_dir, basename, 'epoch_4.pth'])
|
||||
mock_put.assert_called()
|
||||
mock_remove.assert_called()
|
||||
mock_isfile.assert_called()
|
||||
@ -183,7 +183,7 @@ def test_checkpoint_hook(tmp_path):
|
||||
runner.run([loader], [('train', 1)])
|
||||
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
|
||||
assert runner.meta['hook_msgs']['last_ckpt'] == \
|
||||
'/'.join([out_dir, basename, 'iter_4.pth'])
|
||||
'/'.join([out_dir, basename, 'iter_4.pth'])
|
||||
mock_put.assert_called()
|
||||
mock_remove.assert_called()
|
||||
mock_isfile.assert_called()
|
||||
@ -332,7 +332,8 @@ def test_pavi_hook():
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
runner = _build_demo_runner()
|
||||
runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
|
||||
hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
|
||||
hook = PaviLoggerHook(
|
||||
add_graph_kwargs=None, add_last_ckpt=False, add_ckpt_kwargs=None)
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)])
|
||||
shutil.rmtree(runner.work_dir)
|
||||
@ -342,15 +343,113 @@ def test_pavi_hook():
|
||||
'learning_rate': 0.02,
|
||||
'momentum': 0.95
|
||||
}, 1)
|
||||
|
||||
|
||||
def test_pavi_hook_epoch_based():
|
||||
"""Test setting start epoch and interval epoch."""
|
||||
sys.modules['pavi'] = MagicMock()
|
||||
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
runner = _build_demo_runner(max_epochs=6)
|
||||
runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
|
||||
hook = PaviLoggerHook(
|
||||
add_graph_kwargs={
|
||||
'active': False,
|
||||
'start': 0,
|
||||
'interval': 1
|
||||
},
|
||||
add_last_ckpt=True,
|
||||
add_ckpt_kwargs={
|
||||
'active': True,
|
||||
'start': 1,
|
||||
'interval': 2
|
||||
})
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)])
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
assert hasattr(hook, 'writer')
|
||||
|
||||
# in Windows environment, the latest checkpoint is copied from epoch_1.pth
|
||||
if platform.system() == 'Windows':
|
||||
snapshot_file_path = osp.join(runner.work_dir, 'latest.pth')
|
||||
final_file_path = osp.join(runner.work_dir, 'latest.pth')
|
||||
else:
|
||||
snapshot_file_path = osp.join(runner.work_dir, 'epoch_1.pth')
|
||||
hook.writer.add_snapshot_file.assert_called_with(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=snapshot_file_path,
|
||||
iteration=1)
|
||||
final_file_path = osp.join(runner.work_dir, 'epoch_6.pth')
|
||||
calls = [
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'epoch_1.pth'),
|
||||
iteration=1),
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'epoch_3.pth'),
|
||||
iteration=3),
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'epoch_5.pth'),
|
||||
iteration=5),
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, final_file_path),
|
||||
iteration=6),
|
||||
]
|
||||
hook.writer.add_snapshot_file.assert_has_calls(calls, any_order=False)
|
||||
|
||||
|
||||
def test_pavi_hook_iter_based():
|
||||
"""Test setting start epoch and interval epoch."""
|
||||
sys.modules['pavi'] = MagicMock()
|
||||
|
||||
loader = DataLoader(torch.ones((5, 2)))
|
||||
runner = _build_demo_runner(
|
||||
'IterBasedRunner', max_iters=15, max_epochs=None)
|
||||
runner.meta = dict()
|
||||
hook = PaviLoggerHook(
|
||||
by_epoch=False,
|
||||
add_graph_kwargs={
|
||||
'active': False,
|
||||
'start': 0,
|
||||
'interval': 1
|
||||
},
|
||||
add_last_ckpt=True,
|
||||
add_ckpt_kwargs={
|
||||
'active': True,
|
||||
'start': 0,
|
||||
'interval': 4
|
||||
})
|
||||
|
||||
runner.register_hook(CheckpointHook(interval=4, by_epoch=False))
|
||||
runner.register_hook(hook)
|
||||
|
||||
runner.run([loader], [('train', 1)])
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
assert hasattr(hook, 'writer')
|
||||
|
||||
# in Windows environment, the latest checkpoint is copied from epoch_1.pth
|
||||
if platform.system() == 'Windows':
|
||||
final_file_path = osp.join(runner.work_dir, 'latest.pth')
|
||||
else:
|
||||
final_file_path = osp.join(runner.work_dir, 'iter_15.pth')
|
||||
calls = [
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'iter_4.pth'),
|
||||
iteration=4),
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'iter_8.pth'),
|
||||
iteration=8),
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'iter_12.pth'),
|
||||
iteration=12),
|
||||
call(
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, final_file_path),
|
||||
iteration=15),
|
||||
]
|
||||
hook.writer.add_snapshot_file.assert_has_calls(calls, any_order=False)
|
||||
|
||||
|
||||
def test_sync_buffers_hook():
|
||||
|
Loading…
x
Reference in New Issue
Block a user