diff --git a/mmcv/runner/hooks/logger/pavi.py b/mmcv/runner/hooks/logger/pavi.py index 2d2e12cb8..91f08f907 100644 --- a/mmcv/runner/hooks/logger/pavi.py +++ b/mmcv/runner/hooks/logger/pavi.py @@ -35,7 +35,8 @@ class PaviLoggerHook(LoggerHook): - overwrite_last_training (bool, optional): Whether to upload data to the training with the same name in the same project, rather than creating a new one. Defaults to False. - add_graph (bool): Whether to visual model. Default: False. + add_graph (bool): **Deprecated**. Whether to visual model. + Default: False. add_last_ckpt (bool): Whether to save checkpoint after run. Default: False. interval (int): Logging interval (every k iterations). Default: True. @@ -45,6 +46,12 @@ class PaviLoggerHook(LoggerHook): Default: False. by_epoch (bool): Whether EpochBasedRunner is used. Default: True. img_key (string): Get image data from Dataset. Default: 'img_info'. + add_graph_kwargs (dict, optional): A dict contains the params for + adding graph, the keys are as below: + Default: {'active': False, 'start': 0, 'interval': 1}. + add_ckpt_kwargs (dict, optional): A dict contains the params for + adding checkpoint, the keys are as below: + Default: {'active': False, 'start': 0, 'interval': 1}. """ def __init__(self, @@ -55,11 +62,21 @@ class PaviLoggerHook(LoggerHook): ignore_last: bool = True, reset_flag: bool = False, by_epoch: bool = True, - img_key: str = 'img_info'): + img_key: str = 'img_info', + add_graph_kwargs: Optional[Dict] = None, + add_ckpt_kwargs: Optional[Dict] = None) -> None: super().__init__(interval, ignore_last, reset_flag, by_epoch) self.init_kwargs = init_kwargs - self.add_graph = add_graph + add_graph_kwargs = {} if add_graph_kwargs is None else add_graph_kwargs + self.add_graph = add_graph_kwargs.get('active', False) + self.add_graph_start = add_graph_kwargs.get('start', 0) + self.add_graph_interval = add_graph_kwargs.get('interval', 1) + + add_ckpt_kwargs = {} if add_ckpt_kwargs is None else add_ckpt_kwargs + self.add_ckpt = add_ckpt_kwargs.get('active', False) self.add_last_ckpt = add_last_ckpt + self.add_ckpt_start = add_ckpt_kwargs.get('start', 0) + self.add_ckpt_interval = add_ckpt_kwargs.get('interval', 1) self.img_key = img_key @master_only @@ -110,6 +127,28 @@ class PaviLoggerHook(LoggerHook): else: return self.get_iter(runner) + def _add_ckpt(self, runner, ckpt_path: str, step: int) -> None: + + if osp.islink(ckpt_path): + ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path)) + + if osp.isfile(ckpt_path): + self.writer.add_snapshot_file( + tag=self.run_name, + snapshot_file_path=ckpt_path, + iteration=step) + + def _add_graph(self, runner) -> None: + if is_module_wrapper(runner.model): + _model = runner.model.module + else: + _model = runner.model + device = next(_model.parameters()).device + data = next(iter(runner.data_loader)) + image = data[self.img_key][0:1].to(device) + with torch.no_grad(): + self.writer.add_graph(_model, image) + @master_only def log(self, runner) -> None: tags = self.get_loggable_tags(runner, add_mode=False) @@ -119,31 +158,73 @@ class PaviLoggerHook(LoggerHook): @master_only def after_run(self, runner) -> None: - if self.add_last_ckpt: - ckpt_path = osp.join(runner.work_dir, 'latest.pth') - if osp.islink(ckpt_path): - ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path)) - if osp.isfile(ckpt_path): - # runner.epoch += 1 has been done before `after_run`. - iteration = runner.epoch if self.by_epoch else runner.iter - return self.writer.add_snapshot_file( - tag=self.run_name, - snapshot_file_path=ckpt_path, - iteration=iteration) + if self.add_last_ckpt: + # using runner.epoch/iter is ok since the step has been + 1 + step = runner.epoch if self.by_epoch else runner.iter + + ckpt_path = osp.join(runner.work_dir, 'latest.pth') + self._add_ckpt(runner, ckpt_path, step) # flush the buffer and send a task ending signal to Pavi self.writer.close() @master_only - def before_epoch(self, runner) -> None: - if runner.epoch == 0 and self.add_graph: - if is_module_wrapper(runner.model): - _model = runner.model.module - else: - _model = runner.model - device = next(_model.parameters()).device - data = next(iter(runner.data_loader)) - image = data[self.img_key][0:1].to(device) - with torch.no_grad(): - self.writer.add_graph(_model, image) + def before_train_epoch(self, runner) -> None: + super().before_train_epoch(runner) + + if not self.by_epoch: + return None + + step = self.get_epoch(runner) + if (self.add_graph and step >= self.add_graph_start + and ((step - self.add_graph_start) % self.add_graph_interval + == 0)): # noqa: E129 + self._add_graph(runner) + + @master_only + def before_train_iter(self, runner) -> None: + super().before_train_iter(runner) + + if self.by_epoch: + return None + + step = self.get_iter(runner) + if (self.add_graph and step >= self.add_graph_start + and ((step - self.add_graph_start) % self.add_graph_interval + == 0)): # noqa: E129 + self._add_graph(runner) + + @master_only + def after_train_epoch(self, runner) -> None: + super().after_train_epoch(runner) + # Do not use runner.epoch since it starts from 0. + if not self.by_epoch: + return None + + step = self.get_epoch(runner) + + if (self.add_ckpt and step >= self.add_ckpt_start + and ((step - self.add_ckpt_start) % self.add_ckpt_interval + == 0)): # noqa: E129 + + ckpt_path = osp.join(runner.work_dir, f'epoch_{step}.pth') + + self._add_ckpt(runner, ckpt_path, step) + + @master_only + def after_train_iter(self, runner) -> None: + super().after_train_iter(runner) + + if self.by_epoch: + return None + + step = self.get_iter(runner) + + if (self.add_ckpt and step >= self.add_ckpt_start + and ((step - self.add_ckpt_start) % self.add_ckpt_interval + == 0)): # noqa: E129 + + ckpt_path = osp.join(runner.work_dir, f'iter_{step}.pth') + + self._add_ckpt(runner, ckpt_path, step) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index bdb93a901..9e4647525 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -152,7 +152,7 @@ def test_checkpoint_hook(tmp_path): runner.run([loader], [('train', 1)]) basename = osp.basename(runner.work_dir.rstrip(osp.sep)) assert runner.meta['hook_msgs']['last_ckpt'] == \ - '/'.join([out_dir, basename, 'epoch_4.pth']) + '/'.join([out_dir, basename, 'epoch_4.pth']) mock_put.assert_called() mock_remove.assert_called() mock_isfile.assert_called() @@ -183,7 +183,7 @@ def test_checkpoint_hook(tmp_path): runner.run([loader], [('train', 1)]) basename = osp.basename(runner.work_dir.rstrip(osp.sep)) assert runner.meta['hook_msgs']['last_ckpt'] == \ - '/'.join([out_dir, basename, 'iter_4.pth']) + '/'.join([out_dir, basename, 'iter_4.pth']) mock_put.assert_called() mock_remove.assert_called() mock_isfile.assert_called() @@ -332,7 +332,8 @@ def test_pavi_hook(): loader = DataLoader(torch.ones((5, 2))) runner = _build_demo_runner() runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1))) - hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True) + hook = PaviLoggerHook( + add_graph_kwargs=None, add_last_ckpt=False, add_ckpt_kwargs=None) runner.register_hook(hook) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) @@ -342,15 +343,113 @@ def test_pavi_hook(): 'learning_rate': 0.02, 'momentum': 0.95 }, 1) + + +def test_pavi_hook_epoch_based(): + """Test setting start epoch and interval epoch.""" + sys.modules['pavi'] = MagicMock() + + loader = DataLoader(torch.ones((5, 2))) + runner = _build_demo_runner(max_epochs=6) + runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1))) + hook = PaviLoggerHook( + add_graph_kwargs={ + 'active': False, + 'start': 0, + 'interval': 1 + }, + add_last_ckpt=True, + add_ckpt_kwargs={ + 'active': True, + 'start': 1, + 'interval': 2 + }) + runner.register_hook(hook) + runner.run([loader, loader], [('train', 1), ('val', 1)]) + shutil.rmtree(runner.work_dir) + + assert hasattr(hook, 'writer') + # in Windows environment, the latest checkpoint is copied from epoch_1.pth if platform.system() == 'Windows': - snapshot_file_path = osp.join(runner.work_dir, 'latest.pth') + final_file_path = osp.join(runner.work_dir, 'latest.pth') else: - snapshot_file_path = osp.join(runner.work_dir, 'epoch_1.pth') - hook.writer.add_snapshot_file.assert_called_with( - tag=runner.work_dir.split('/')[-1], - snapshot_file_path=snapshot_file_path, - iteration=1) + final_file_path = osp.join(runner.work_dir, 'epoch_6.pth') + calls = [ + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, 'epoch_1.pth'), + iteration=1), + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, 'epoch_3.pth'), + iteration=3), + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, 'epoch_5.pth'), + iteration=5), + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, final_file_path), + iteration=6), + ] + hook.writer.add_snapshot_file.assert_has_calls(calls, any_order=False) + + +def test_pavi_hook_iter_based(): + """Test setting start epoch and interval epoch.""" + sys.modules['pavi'] = MagicMock() + + loader = DataLoader(torch.ones((5, 2))) + runner = _build_demo_runner( + 'IterBasedRunner', max_iters=15, max_epochs=None) + runner.meta = dict() + hook = PaviLoggerHook( + by_epoch=False, + add_graph_kwargs={ + 'active': False, + 'start': 0, + 'interval': 1 + }, + add_last_ckpt=True, + add_ckpt_kwargs={ + 'active': True, + 'start': 0, + 'interval': 4 + }) + + runner.register_hook(CheckpointHook(interval=4, by_epoch=False)) + runner.register_hook(hook) + + runner.run([loader], [('train', 1)]) + shutil.rmtree(runner.work_dir) + + assert hasattr(hook, 'writer') + + # in Windows environment, the latest checkpoint is copied from epoch_1.pth + if platform.system() == 'Windows': + final_file_path = osp.join(runner.work_dir, 'latest.pth') + else: + final_file_path = osp.join(runner.work_dir, 'iter_15.pth') + calls = [ + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, 'iter_4.pth'), + iteration=4), + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, 'iter_8.pth'), + iteration=8), + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, 'iter_12.pth'), + iteration=12), + call( + tag=runner.work_dir.split('/')[-1], + snapshot_file_path=osp.join(runner.work_dir, final_file_path), + iteration=15), + ] + hook.writer.add_snapshot_file.assert_has_calls(calls, any_order=False) def test_sync_buffers_hook():