[Feature] Support resume from Ceph (#294)

* support resume from ceph

* move func and refine

* delete symlink

* fix unittest

* preserve _allow_symlink and symlink
pull/308/head
Jiazhen Wang 2022-06-17 10:37:19 +08:00 committed by GitHub
parent d0d7174274
commit 7b55c5bdbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 50 additions and 92 deletions

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
import warnings
from pathlib import Path from pathlib import Path
from typing import Optional, Sequence, Union from typing import Optional, Sequence, Union
@ -95,19 +94,6 @@ class CheckpointHook(Hook):
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by ' runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.') f'{self.file_client.name}.')
# disable the create_symlink option because some file backends do not
# allow to create a symlink
if 'create_symlink' in self.args:
if self.args[
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
'create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}')
else:
self.args['create_symlink'] = self.file_client.allow_symlink
def after_train_epoch(self, runner) -> None: def after_train_epoch(self, runner) -> None:
"""Save the checkpoint and synchronize buffers after each epoch. """Save the checkpoint and synchronize buffers after each epoch.
@ -142,7 +128,8 @@ class CheckpointHook(Hook):
runner.save_checkpoint( runner.save_checkpoint(
self.out_dir, self.out_dir,
filename=ckpt_filename, ckpt_filename,
self.file_client_args,
save_optimizer=self.save_optimizer, save_optimizer=self.save_optimizer,
save_param_scheduler=self.save_param_scheduler, save_param_scheduler=self.save_param_scheduler,
by_epoch=self.by_epoch, by_epoch=self.by_epoch,

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .checkpoint import (CheckpointLoader, get_deprecated_model_names, from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
get_external_models, get_mmcls_models, get_state_dict, get_deprecated_model_names, get_external_models,
get_mmcls_models, get_state_dict,
get_torchvision_models, load_checkpoint, get_torchvision_models, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu) load_state_dict, save_checkpoint, weights_to_cpu)
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
@ -12,5 +13,5 @@ __all__ = [
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names', 'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict', 'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop', 'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
'TestLoop', 'Runner' 'TestLoop', 'Runner', 'find_latest_checkpoint'
] ]

View File

@ -695,3 +695,26 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
with io.BytesIO() as f: with io.BytesIO() as f:
torch.save(checkpoint, f) torch.save(checkpoint, f)
file_client.put(f.getvalue(), filename) file_client.put(f.getvalue(), filename)
def find_latest_checkpoint(path: str) -> str:
    """Find the latest checkpoint from the given path.

    Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py  # noqa: E501

    The location of the latest checkpoint is read from the
    ``last_checkpoint`` file stored directly under ``path``.

    Args:
        path (str): The path to find checkpoints.

    Returns:
        str: Content of the ``last_checkpoint`` file with surrounding
        whitespace stripped.

    Raises:
        OSError: If the ``last_checkpoint`` file cannot be opened or read.
    """
    save_file = osp.join(path, 'last_checkpoint')
    try:
        with open(save_file) as f:
            last_saved = f.read().strip()
    except OSError as err:
        # Chain the original error so the real reason (missing file,
        # permission problem, ...) is not hidden behind the generic message.
        raise OSError(
            'last_checkpoint file does not exist, maybe because it has just'
            ' been deleted by a separate process') from err
    return last_saved

View File

@ -4,7 +4,6 @@ import os.path as osp
import platform import platform
import random import random
import resource import resource
import shutil
import time import time
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
@ -25,6 +24,7 @@ from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
master_only, sync_random_seed) master_only, sync_random_seed)
from mmengine.evaluator import Evaluator from mmengine.evaluator import Evaluator
from mmengine.fileio import FileClient
from mmengine.hooks import Hook from mmengine.hooks import Hook
from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.model import (BaseModel, MMDistributedDataParallel, from mmengine.model import (BaseModel, MMDistributedDataParallel,
@ -36,13 +36,13 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
RUNNERS, VISUALIZERS, DefaultScope, RUNNERS, VISUALIZERS, DefaultScope,
count_registered_modules) count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, digit_version, from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
find_latest_checkpoint, get_git_hash, is_list_of, is_list_of, set_multi_processing)
set_multi_processing, symlink)
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
get_state_dict, save_checkpoint, weights_to_cpu) find_latest_checkpoint, get_state_dict,
save_checkpoint, weights_to_cpu)
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority from .priority import Priority, get_priority
@ -1920,10 +1920,10 @@ class Runner:
def save_checkpoint(self, def save_checkpoint(self,
out_dir: str, out_dir: str,
filename: str, filename: str,
file_client_args: Optional[dict] = None,
save_optimizer: bool = True, save_optimizer: bool = True,
save_param_scheduler: bool = True, save_param_scheduler: bool = True,
meta: dict = None, meta: dict = None,
create_symlink: bool = True,
by_epoch: bool = True): by_epoch: bool = True):
"""Save checkpoints. """Save checkpoints.
@ -1933,15 +1933,16 @@ class Runner:
Args: Args:
out_dir (str): The directory that checkpoints are saved. out_dir (str): The directory that checkpoints are saved.
filename (str): The checkpoint filename. filename (str): The checkpoint filename.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. Default: None.
save_optimizer (bool): Whether to save the optimizer to save_optimizer (bool): Whether to save the optimizer to
the checkpoint. Defaults to True. the checkpoint. Defaults to True.
save_param_scheduler (bool): Whether to save the param_scheduler save_param_scheduler (bool): Whether to save the param_scheduler
to the checkpoint. Defaults to True. to the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None. checkpoint. Defaults to None.
create_symlink (bool): Whether to create a symlink by_epoch (bool): Whether the scheduled momentum is updated by
"latest.pth" to point to the latest checkpoint. epochs. Defaults to True.
Defaults to True.
""" """
if meta is None: if meta is None:
meta = {} meta = {}
@ -1961,7 +1962,8 @@ class Runner:
else: else:
meta.update(epoch=self.epoch, iter=self.iter + 1) meta.update(epoch=self.epoch, iter=self.iter + 1)
filepath = osp.join(out_dir, filename) file_client = FileClient.infer_client(file_client_args, out_dir)
filepath = file_client.join_path(out_dir, filename)
meta.update( meta.update(
cfg=self.cfg.pretty_text, cfg=self.cfg.pretty_text,
@ -2007,14 +2009,10 @@ class Runner:
self.call_hook('before_save_checkpoint', checkpoint=checkpoint) self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
save_checkpoint(checkpoint, filepath) save_checkpoint(checkpoint, filepath)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False save_file = osp.join(self.work_dir, 'last_checkpoint')
if create_symlink: with open(save_file, 'w') as f:
dst_file = osp.join(out_dir, 'latest.pth') f.write(filepath)
if platform.system() != 'Windows':
symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
@master_only @master_only
def dump_config(self) -> None: def dump_config(self) -> None:

View File

@ -2,10 +2,9 @@
from .hub import load_url from .hub import load_url
from .manager import ManagerMeta, ManagerMixin from .manager import ManagerMeta, ManagerMixin
from .misc import (check_prerequisites, concat_list, deprecated_api_warning, from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
find_latest_checkpoint, has_batch_norm, has_method, has_batch_norm, has_method, import_modules_from_strings,
import_modules_from_strings, is_list_of, is_list_of, is_method_overridden, is_seq_of, is_str,
is_method_overridden, is_seq_of, is_str, is_tuple_of, is_tuple_of, iter_cast, list_cast, mmcv_full_available,
iter_cast, list_cast, mmcv_full_available,
requires_executable, requires_package, slice_list, requires_executable, requires_package, slice_list,
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple, to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
tuple_cast) tuple_cast)
@ -27,6 +26,6 @@ __all__ = [
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method', 'mmcv_full_available', 'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin', 'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
'set_multi_processing', 'has_batch_norm', 'is_abs' 'is_abs'
] ]

View File

@ -1,9 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import collections.abc import collections.abc
import functools import functools
import glob
import itertools import itertools
import os.path as osp
import pkgutil import pkgutil
import subprocess import subprocess
import warnings import warnings
@ -481,40 +479,6 @@ def tensor2imgs(tensor: torch.Tensor,
return imgs return imgs
def find_latest_checkpoint(path: str, suffix: str = 'pth'):
    """Find the latest checkpoint from the given path.

    Refer to https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/patch.py  # noqa: E501

    If ``latest.{suffix}`` exists under ``path`` it is returned directly.
    Otherwise every ``*.{suffix}`` file is inspected and the one whose name
    ends with the largest number (the text between the last ``_`` and the
    following ``.``) is returned.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension. Defaults to 'pth'.

    Returns:
        str or None: File path of the latest checkpoint. None if no
        checkpoint name carries a non-negative numeric counter.

    Raises:
        FileNotFoundError: If ``path`` does not exist or contains no file
            with the given suffix.
    """
    if not osp.exists(path):
        raise FileNotFoundError(f'{path} does not exist.')

    latest_link = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_link):
        return latest_link

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if not checkpoints:
        raise FileNotFoundError(f'checkpoints can not be found in {path}. '
                                'Maybe check the suffix again.')

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        counter = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(counter)
        except ValueError:
            # Files such as ``final.pth`` carry no numeric counter; skip
            # them instead of aborting the whole search.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    return latest_path
def has_batch_norm(model: nn.Module) -> bool: def has_batch_norm(model: nn.Module) -> bool:
"""Detect whether model has a BatchNormalization layer. """Detect whether model has a BatchNormalization layer.

View File

@ -46,18 +46,6 @@ class TestCheckpointHook:
assert checkpoint_hook.out_dir == ( assert checkpoint_hook.out_dir == (
f'test_dir/{osp.basename(work_dir)}') f'test_dir/{osp.basename(work_dir)}')
# create_symlink in args and create_symlink is True
checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
checkpoint_hook.before_train(runner)
assert checkpoint_hook.args['create_symlink']
runner.work_dir = 's3://path/of/file'
checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, create_symlink=True)
checkpoint_hook.before_train(runner)
assert not checkpoint_hook.args['create_symlink']
def test_after_train_epoch(self, tmp_path): def test_after_train_epoch(self, tmp_path):
runner = Mock() runner = Mock()
work_dir = str(tmp_path) work_dir = str(tmp_path)

View File

@ -1505,7 +1505,6 @@ class TestRunner(TestCase):
# 1.1 test `save_checkpoint` which is called by `CheckpointHook` # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
path = osp.join(self.temp_dir, 'epoch_3.pth') path = osp.join(self.temp_dir, 'epoch_3.pth')
self.assertTrue(osp.exists(path)) self.assertTrue(osp.exists(path))
self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth'))) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))
ckpt = torch.load(path) ckpt = torch.load(path)
@ -1672,7 +1671,6 @@ class TestRunner(TestCase):
# 2.1 test `save_checkpoint` which is called by `CheckpointHook` # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
path = osp.join(self.temp_dir, 'iter_12.pth') path = osp.join(self.temp_dir, 'iter_12.pth')
self.assertTrue(osp.exists(path)) self.assertTrue(osp.exists(path))
self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
ckpt = torch.load(path) ckpt = torch.load(path)