[Feature] Support resume from Ceph (#294)

* Support resume from Ceph

* Move function and refine

* Delete symlink

* Fix unit tests

* Preserve _allow_symlink and symlink
Jiazhen Wang 2022-06-17 10:37:19 +08:00 committed by GitHub
parent d0d7174274
commit 7b55c5bdbf
8 changed files with 50 additions and 92 deletions

mmengine/hooks/checkpoint_hook.py

@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-import warnings
 from pathlib import Path
 from typing import Optional, Sequence, Union
@@ -95,19 +94,6 @@ class CheckpointHook(Hook):
         runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.')
-        # disable the create_symlink option because some file backends do not
-        # allow to create a symlink
-        if 'create_symlink' in self.args:
-            if self.args[
-                    'create_symlink'] and not self.file_client.allow_symlink:
-                self.args['create_symlink'] = False
-                warnings.warn(
-                    'create_symlink is set as True by the user but is changed'
-                    'to be False because creating symbolic link is not '
-                    f'allowed in {self.file_client.name}')
-            else:
-                self.args['create_symlink'] = self.file_client.allow_symlink

     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
@@ -142,7 +128,8 @@ class CheckpointHook(Hook):
         runner.save_checkpoint(
             self.out_dir,
-            filename=ckpt_filename,
+            ckpt_filename,
+            self.file_client_args,
             save_optimizer=self.save_optimizer,
             save_param_scheduler=self.save_param_scheduler,
             by_epoch=self.by_epoch,
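With this change the hook stops managing `latest.pth` symlinks and instead forwards its `file_client_args` to `Runner.save_checkpoint`, so checkpoints can be written straight to Ceph. A minimal config sketch of how that could look, assuming the Petrel backend and a hypothetical bucket path:

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        by_epoch=True,
        # Hypothetical Ceph location; any FileClient-resolvable URI works.
        out_dir='s3://my-bucket/work_dirs',
        # 'petrel' is the Ceph-backed FileClient backend in OpenMMLab.
        file_client_args=dict(backend='petrel')))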

mmengine/runner/__init__.py

@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_loop import BaseLoop
-from .checkpoint import (CheckpointLoader, get_deprecated_model_names,
-                         get_external_models, get_mmcls_models, get_state_dict,
+from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
+                         get_deprecated_model_names, get_external_models,
+                         get_mmcls_models, get_state_dict,
                          get_torchvision_models, load_checkpoint,
                          load_state_dict, save_checkpoint, weights_to_cpu)
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
@@ -12,5 +13,5 @@ __all__ = [
     'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
-    'TestLoop', 'Runner'
+    'TestLoop', 'Runner', 'find_latest_checkpoint'
 ]

mmengine/runner/checkpoint.py

@@ -695,3 +695,26 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
         with io.BytesIO() as f:
             torch.save(checkpoint, f)
             file_client.put(f.getvalue(), filename)
+
+
+def find_latest_checkpoint(path: str):
+    """Find the latest checkpoint from the given path.
+
+    Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py  # noqa: E501
+
+    Args:
+        path(str): The path to find checkpoints.
+
+    Returns:
+        str or None: File path of the latest checkpoint.
+    """
+    save_file = osp.join(path, 'last_checkpoint')
+    try:
+        with open(save_file) as f:
+            last_saved = f.read().strip()
+    except OSError:
+        raise OSError(
+            'last_checkpoint file does not exist, maybe because it has just'
+            ' been deleted by a separate process')
+    return last_saved
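Unlike the old suffix-scanning helper in `mmengine.utils`, this version simply reads back the path recorded in the `last_checkpoint` file, so it also resolves checkpoints that live on remote storage, and it raises OSError if no `last_checkpoint` file exists yet. A usage sketch with a hypothetical work directory:

from mmengine.runner import find_latest_checkpoint

# Reads <path>/last_checkpoint and returns the recorded filepath, which may
# be a local path or a remote URI such as an s3:// address on Ceph.
latest = find_latest_checkpoint('./work_dirs/my_experiment')
print(latest)  # e.g. './work_dirs/my_experiment/epoch_12.pth'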

mmengine/runner/runner.py

@@ -4,7 +4,6 @@ import os.path as osp
 import platform
 import random
 import resource
-import shutil
 import time
 import warnings
 from collections import OrderedDict
@@ -25,6 +24,7 @@ from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            master_only, sync_random_seed)
 from mmengine.evaluator import Evaluator
+from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import (BaseModel, MMDistributedDataParallel,
@@ -36,13 +36,13 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
                                RUNNERS, VISUALIZERS, DefaultScope,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
-from mmengine.utils import (TORCH_VERSION, digit_version,
-                            find_latest_checkpoint, get_git_hash, is_list_of,
-                            set_multi_processing, symlink)
+from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
+                            is_list_of, set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
-                         get_state_dict, save_checkpoint, weights_to_cpu)
+                         find_latest_checkpoint, get_state_dict,
+                         save_checkpoint, weights_to_cpu)
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
@@ -1920,10 +1920,10 @@ class Runner:
     def save_checkpoint(self,
                         out_dir: str,
                         filename: str,
+                        file_client_args: Optional[dict] = None,
                         save_optimizer: bool = True,
                         save_param_scheduler: bool = True,
                         meta: dict = None,
-                        create_symlink: bool = True,
                         by_epoch: bool = True):
         """Save checkpoints.
@@ -1933,15 +1933,16 @@
         Args:
             out_dir (str): The directory that checkpoints are saved.
             filename (str): The checkpoint filename.
+            file_client_args (dict, optional): Arguments to instantiate a
+                FileClient. Default: None.
             save_optimizer (bool): Whether to save the optimizer to
                 the checkpoint. Defaults to True.
             save_param_scheduler (bool): Whether to save the param_scheduler
                 to the checkpoint. Defaults to True.
             meta (dict, optional): The meta information to be saved in the
                 checkpoint. Defaults to None.
-            create_symlink (bool): Whether to create a symlink
-                "latest.pth" to point to the latest checkpoint.
-                Defaults to True.
             by_epoch (bool): Whether the scheduled momentum is updated by
                 epochs. Defaults to True.
         """
         if meta is None:
             meta = {}
@@ -1961,7 +1962,8 @@
         else:
             meta.update(epoch=self.epoch, iter=self.iter + 1)

-        filepath = osp.join(out_dir, filename)
+        file_client = FileClient.infer_client(file_client_args, out_dir)
+        filepath = file_client.join_path(out_dir, filename)

         meta.update(
             cfg=self.cfg.pretty_text,
@@ -2007,14 +2009,10 @@
         self.call_hook('before_save_checkpoint', checkpoint=checkpoint)

         save_checkpoint(checkpoint, filepath)
-        # in some environments, `os.symlink` is not supported, you may need to
-        # set `create_symlink` to False
-        if create_symlink:
-            dst_file = osp.join(out_dir, 'latest.pth')
-            if platform.system() != 'Windows':
-                symlink(filename, dst_file)
-            else:
-                shutil.copy(filepath, dst_file)
+
+        save_file = osp.join(self.work_dir, 'last_checkpoint')
+        with open(save_file, 'w') as f:
+            f.write(filepath)

     @master_only
     def dump_config(self) -> None:
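`save_checkpoint` now resolves `out_dir` through `FileClient.infer_client`, so the checkpoint itself may land on Ceph while the plain-text `last_checkpoint` file in the local `work_dir` records the full filepath for later resumption. A sketch of the round trip, assuming a constructed Runner instance, the Petrel backend, and hypothetical paths:

# Save to a hypothetical Ceph bucket; FileClient picks the backend from
# file_client_args and joins the path with that backend's own rules.
runner.save_checkpoint(
    out_dir='s3://my-bucket/work_dirs',
    filename='epoch_1.pth',
    file_client_args=dict(backend='petrel'))

# work_dir/last_checkpoint now holds 's3://my-bucket/work_dirs/epoch_1.pth';
# resuming reads it back and loads through the same file client.
latest = find_latest_checkpoint(runner.work_dir)
runner.resume(latest)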

mmengine/utils/__init__.py

@@ -2,10 +2,9 @@
 from .hub import load_url
 from .manager import ManagerMeta, ManagerMixin
 from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
-                   find_latest_checkpoint, has_batch_norm, has_method,
-                   import_modules_from_strings, is_list_of,
-                   is_method_overridden, is_seq_of, is_str, is_tuple_of,
-                   iter_cast, list_cast, mmcv_full_available,
+                   has_batch_norm, has_method, import_modules_from_strings,
+                   is_list_of, is_method_overridden, is_seq_of, is_str,
+                   is_tuple_of, iter_cast, list_cast, mmcv_full_available,
                    requires_executable, requires_package, slice_list,
                    to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
                    tuple_cast)
@@ -27,6 +26,6 @@ __all__ = [
     'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
     'is_method_overridden', 'has_method', 'mmcv_full_available',
     'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
-    'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
-    'set_multi_processing', 'has_batch_norm', 'is_abs'
+    'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
+    'is_abs'
 ]

mmengine/utils/misc.py

@@ -1,9 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import collections.abc
 import functools
-import glob
 import itertools
-import os.path as osp
 import pkgutil
 import subprocess
 import warnings
@@ -481,40 +479,6 @@ def tensor2imgs(tensor: torch.Tensor,
     return imgs


-def find_latest_checkpoint(path: str, suffix: str = 'pth'):
-    """Find the latest checkpoint from the given path.
-
-    Refer to https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/patch.py  # noqa: E501
-
-    Args:
-        path(str): The path to find checkpoints.
-        suffix(str): File extension. Defaults to 'pth'.
-
-    Returns:
-        str or None: File path of the latest checkpoint.
-    """
-    if not osp.exists(path):
-        raise FileNotFoundError('{path} does not exist.')
-
-    if osp.exists(osp.join(path, f'latest.{suffix}')):
-        return osp.join(path, f'latest.{suffix}')
-
-    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
-    if len(checkpoints) == 0:
-        raise FileNotFoundError(f'checkpoints can not be found in {path}. '
-                                'Maybe check the suffix again.')
-
-    latest = -1
-    latest_path = None
-    for checkpoint in checkpoints:
-        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
-        if count > latest:
-            latest = count
-            latest_path = checkpoint
-    return latest_path
-
-
 def has_batch_norm(model: nn.Module) -> bool:
     """Detect whether model has a BatchNormalization layer.

tests/test_hooks/test_checkpoint_hook.py

@@ -46,18 +46,6 @@ class TestCheckpointHook:
         assert checkpoint_hook.out_dir == (
             f'test_dir/{osp.basename(work_dir)}')

-        # create_symlink in args and create_symlink is True
-        checkpoint_hook = CheckpointHook(
-            interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
-        checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.args['create_symlink']
-
-        runner.work_dir = 's3://path/of/file'
-        checkpoint_hook = CheckpointHook(
-            interval=1, by_epoch=True, create_symlink=True)
-        checkpoint_hook.before_train(runner)
-        assert not checkpoint_hook.args['create_symlink']
-
     def test_after_train_epoch(self, tmp_path):
         runner = Mock()
         work_dir = str(tmp_path)

tests/test_runner/test_runner.py

@@ -1505,7 +1505,6 @@ class TestRunner(TestCase):
         # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'epoch_3.pth')
         self.assertTrue(osp.exists(path))
-        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))

         ckpt = torch.load(path)
@@ -1672,7 +1671,6 @@
         # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'iter_12.pth')
         self.assertTrue(osp.exists(path))
-        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))

         ckpt = torch.load(path)
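With `latest.pth` no longer produced, a test would instead assert on the `last_checkpoint` bookkeeping file. A hedged sketch of such an assertion (not part of this diff):

# Hypothetical replacement check: verify the file written by
# Runner.save_checkpoint rather than a latest.pth symlink.
save_file = osp.join(self.temp_dir, 'last_checkpoint')
self.assertTrue(osp.exists(save_file))
with open(save_file) as f:
    self.assertTrue(f.read().strip().endswith('iter_12.pth'))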