[Feature]: Add dist semantics in checkpoint hook (#131)

* [Feature]: Add dist semantics in checkpoint hook

* [Fix]: Delete sync buffer in checkpoint hook
pull/140/head
Yuan Liu 2022-03-25 13:46:31 +08:00 committed by GitHub
parent e4859030af
commit 26f24296db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 11 deletions

View File

@ -5,6 +5,7 @@ from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample
from mmengine.dist import master_only
from mmengine.fileio import FileClient
from mmengine.registry import HOOKS
from .hook import Hook
@ -41,8 +42,6 @@ class CheckpointHook(Hook):
Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
sync_buffer (bool): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None.
@ -59,7 +58,6 @@ class CheckpointHook(Hook):
out_dir: Optional[Union[str, Path]] = None,
max_keep_ckpts: int = -1,
save_last: bool = True,
sync_buffer: bool = False,
file_client_args: Optional[dict] = None,
**kwargs) -> None:
self.interval = interval
@ -70,7 +68,6 @@ class CheckpointHook(Hook):
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args
def before_run(self, runner) -> None:
@ -129,12 +126,9 @@ class CheckpointHook(Hook):
self.save_last and self.is_last_train_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner)
# Decorated with master_only, imported from mmengine.dist.
@master_only
def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
@ -204,7 +198,4 @@ class CheckpointHook(Hook):
(self.save_last and self.is_last_iter(runner, mode='train')):
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner)

View File

@ -63,6 +63,7 @@ class TestCheckpointHook:
runner.work_dir = './tmp'
runner.epoch = 9
runner.meta = dict()
runner.model = Mock()
# by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
@ -100,6 +101,7 @@ class TestCheckpointHook:
runner.work_dir = './tmp'
runner.iter = 9
runner.meta = dict()
runner.model = Mock()
# by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)