[Feature]: Add dist semantics in checkpoint hook (#131)
* [Feature]: Add dist semantics in checkpoint hook
* [Fix]: Delete sync buffer in checkpoint hook
pull/140/head
parent
e4859030af
commit
26f24296db
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
@ -41,8 +42,6 @@ class CheckpointHook(Hook):
|
|||
Default: -1, which means unlimited.
|
||||
save_last (bool): Whether to force the last checkpoint to be
|
||||
saved regardless of interval. Default: True.
|
||||
sync_buffer (bool): Whether to synchronize buffers in
|
||||
different gpus. Default: False.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Default: None.
|
||||
|
@ -59,7 +58,6 @@ class CheckpointHook(Hook):
|
|||
out_dir: Optional[Union[str, Path]] = None,
|
||||
max_keep_ckpts: int = -1,
|
||||
save_last: bool = True,
|
||||
sync_buffer: bool = False,
|
||||
file_client_args: Optional[dict] = None,
|
||||
**kwargs) -> None:
|
||||
self.interval = interval
|
||||
|
@ -70,7 +68,6 @@ class CheckpointHook(Hook):
|
|||
self.max_keep_ckpts = max_keep_ckpts
|
||||
self.save_last = save_last
|
||||
self.args = kwargs
|
||||
self.sync_buffer = sync_buffer
|
||||
self.file_client_args = file_client_args
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
|
@ -129,12 +126,9 @@ class CheckpointHook(Hook):
|
|||
self.save_last and self.is_last_train_epoch(runner)):
|
||||
runner.logger.info(
|
||||
f'Saving checkpoint at {runner.epoch + 1} epochs')
|
||||
if self.sync_buffer:
|
||||
pass
|
||||
# TODO
|
||||
self._save_checkpoint(runner)
|
||||
|
||||
# TODO Add master_only decorator
|
||||
@master_only
|
||||
def _save_checkpoint(self, runner) -> None:
|
||||
"""Save the current checkpoint and delete outdated checkpoint.
|
||||
|
||||
|
@ -204,7 +198,4 @@ class CheckpointHook(Hook):
|
|||
(self.save_last and self.is_last_iter(runner, mode='train')):
|
||||
runner.logger.info(
|
||||
f'Saving checkpoint at {runner.iter + 1} iterations')
|
||||
if self.sync_buffer:
|
||||
pass
|
||||
# TODO
|
||||
self._save_checkpoint(runner)
|
||||
|
|
|
@ -63,6 +63,7 @@ class TestCheckpointHook:
|
|||
runner.work_dir = './tmp'
|
||||
runner.epoch = 9
|
||||
runner.meta = dict()
|
||||
runner.model = Mock()
|
||||
|
||||
# by epoch is True
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||
|
@ -100,6 +101,7 @@ class TestCheckpointHook:
|
|||
runner.work_dir = './tmp'
|
||||
runner.iter = 9
|
||||
runner.meta = dict()
|
||||
runner.model = Mock()
|
||||
|
||||
# by epoch is True
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||
|
|
Loading…
Reference in New Issue