[Feature] Support resume from Ceph (#294)
* support resume from ceph
* move func and refine
* delete symlink
* fix unittest
* preserve _allow_symlink and symlink
parent d0d7174274
commit 7b55c5bdbf
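The net effect of this commit: checkpoints can be written to a remote backend such as Ceph, and the latest one is tracked through a plain `last_checkpoint` file instead of a `latest.pth` symlink, so training can resume from remote storage. A minimal sketch of the resulting hook configuration, assuming the petrel backend for Ceph; the bucket path and `file_client_args` values are illustrative assumptions, only the parameter names come from this diff:

from mmengine.hooks import CheckpointHook

# Hedged sketch: out_dir and file_client_args values below are assumed
# for illustration, not taken from this commit.
ckpt_hook = CheckpointHook(
    interval=1,
    by_epoch=True,
    out_dir='s3://bucket/experiments/exp1',   # remote checkpoint directory
    file_client_args=dict(backend='petrel'),  # Ceph via the petrel client
)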
mmengine/hooks/checkpoint_hook.py

@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-import warnings
 from pathlib import Path
 from typing import Optional, Sequence, Union
 
@@ -95,19 +94,6 @@ class CheckpointHook(Hook):
         runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.')
 
-        # disable the create_symlink option because some file backends do not
-        # allow to create a symlink
-        if 'create_symlink' in self.args:
-            if self.args[
-                    'create_symlink'] and not self.file_client.allow_symlink:
-                self.args['create_symlink'] = False
-                warnings.warn(
-                    'create_symlink is set as True by the user but is changed'
-                    'to be False because creating symbolic link is not '
-                    f'allowed in {self.file_client.name}')
-        else:
-            self.args['create_symlink'] = self.file_client.allow_symlink
-
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
 
@@ -142,7 +128,8 @@ class CheckpointHook(Hook):
 
         runner.save_checkpoint(
             self.out_dir,
-            filename=ckpt_filename,
+            ckpt_filename,
+            self.file_client_args,
             save_optimizer=self.save_optimizer,
             save_param_scheduler=self.save_param_scheduler,
             by_epoch=self.by_epoch,
mmengine/runner/__init__.py

@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_loop import BaseLoop
-from .checkpoint import (CheckpointLoader, get_deprecated_model_names,
-                         get_external_models, get_mmcls_models, get_state_dict,
+from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
+                         get_deprecated_model_names, get_external_models,
+                         get_mmcls_models, get_state_dict,
                          get_torchvision_models, load_checkpoint,
                          load_state_dict, save_checkpoint, weights_to_cpu)
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop

@@ -12,5 +13,5 @@ __all__ = [
     'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
-    'TestLoop', 'Runner'
+    'TestLoop', 'Runner', 'find_latest_checkpoint'
 ]
mmengine/runner/checkpoint.py

@@ -695,3 +695,26 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
         with io.BytesIO() as f:
             torch.save(checkpoint, f)
             file_client.put(f.getvalue(), filename)
+
+
+def find_latest_checkpoint(path: str):
+    """Find the latest checkpoint from the given path.
+
+    Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py  # noqa: E501
+
+    Args:
+        path(str): The path to find checkpoints.
+
+    Returns:
+        str or None: File path of the latest checkpoint.
+    """
+    save_file = osp.join(path, 'last_checkpoint')
+    try:
+        with open(save_file) as f:
+            last_saved = f.read().strip()
+    except OSError:
+        raise OSError(
+            'last_checkpoint file does not exist, maybe because it has just'
+            ' been deleted by a separate process')
+
+    return last_saved
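Note the design: the new helper reads a `last_checkpoint` marker file rather than globbing the directory for `*.pth` files, following the fvcore checkpointer it cites; `Runner.save_checkpoint` (changed below) writes that marker after every save. A small usage sketch, with an assumed local work_dir:

from mmengine.runner import find_latest_checkpoint

# Assumes ./work_dirs/exp1/last_checkpoint exists (save_checkpoint writes
# it after every save); otherwise this raises OSError as defined above.
latest = find_latest_checkpoint('./work_dirs/exp1')
print(latest)  # e.g. './work_dirs/exp1/epoch_3.pth'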
mmengine/runner/runner.py

@@ -4,7 +4,6 @@ import os.path as osp
 import platform
 import random
 import resource
-import shutil
 import time
 import warnings
 from collections import OrderedDict

@@ -25,6 +24,7 @@ from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            master_only, sync_random_seed)
 from mmengine.evaluator import Evaluator
+from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import (BaseModel, MMDistributedDataParallel,

@@ -36,13 +36,13 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
                                RUNNERS, VISUALIZERS, DefaultScope,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
-from mmengine.utils import (TORCH_VERSION, digit_version,
-                            find_latest_checkpoint, get_git_hash, is_list_of,
-                            set_multi_processing, symlink)
+from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
+                            is_list_of, set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
-                         get_state_dict, save_checkpoint, weights_to_cpu)
+                         find_latest_checkpoint, get_state_dict,
+                         save_checkpoint, weights_to_cpu)
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
 from .priority import Priority, get_priority
@@ -1920,10 +1920,10 @@ class Runner:
     def save_checkpoint(self,
                         out_dir: str,
                         filename: str,
+                        file_client_args: Optional[dict] = None,
                         save_optimizer: bool = True,
                         save_param_scheduler: bool = True,
                         meta: dict = None,
-                        create_symlink: bool = True,
                         by_epoch: bool = True):
         """Save checkpoints.

@@ -1933,15 +1933,16 @@ class Runner:
         Args:
             out_dir (str): The directory that checkpoints are saved.
             filename (str): The checkpoint filename.
+            file_client_args (dict, optional): Arguments to instantiate a
+                FileClient. Default: None.
             save_optimizer (bool): Whether to save the optimizer to
                 the checkpoint. Defaults to True.
             save_param_scheduler (bool): Whether to save the param_scheduler
                 to the checkpoint. Defaults to True.
             meta (dict, optional): The meta information to be saved in the
                 checkpoint. Defaults to None.
-            create_symlink (bool): Whether to create a symlink
-                "latest.pth" to point to the latest checkpoint.
-                Defaults to True.
             by_epoch (bool): Whether the scheduled momentum is updated by
                 epochs. Defaults to True.
         """
         if meta is None:
             meta = {}
@@ -1961,7 +1962,8 @@ class Runner:
         else:
             meta.update(epoch=self.epoch, iter=self.iter + 1)
 
-        filepath = osp.join(out_dir, filename)
+        file_client = FileClient.infer_client(file_client_args, out_dir)
+        filepath = file_client.join_path(out_dir, filename)
 
         meta.update(
             cfg=self.cfg.pretty_text,
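`FileClient.infer_client` resolves the storage backend from `file_client_args` or, when that is None, from the prefix of `out_dir`, and `join_path` then joins paths by that backend's rules, so `filepath` is valid for local disk and Ceph alike. A hedged sketch of that resolution with an assumed local path:

from mmengine.fileio import FileClient

# A local path resolves to the disk backend; an 's3://...' out_dir would
# select the Ceph/petrel backend instead (petrel client required).
file_client = FileClient.infer_client(None, './work_dirs/exp1')
filepath = file_client.join_path('./work_dirs/exp1', 'epoch_3.pth')
print(filepath)  # './work_dirs/exp1/epoch_3.pth'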
@@ -2007,14 +2009,10 @@ class Runner:
 
         self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
         save_checkpoint(checkpoint, filepath)
-        # in some environments, `os.symlink` is not supported, you may need to
-        # set `create_symlink` to False
-        if create_symlink:
-            dst_file = osp.join(out_dir, 'latest.pth')
-            if platform.system() != 'Windows':
-                symlink(filename, dst_file)
-            else:
-                shutil.copy(filepath, dst_file)
+
+        save_file = osp.join(self.work_dir, 'last_checkpoint')
+        with open(save_file, 'w') as f:
+            f.write(filepath)
 
     @master_only
     def dump_config(self) -> None:
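Writing a plain-text `last_checkpoint` file into the local work_dir replaces the old `latest.pth` symlink, which some file backends (and Windows) cannot create; that is why `create_symlink` and the `symlink`/`shutil` machinery disappear throughout this diff. The marker round-trips with the `find_latest_checkpoint` added above; a sketch with illustrative paths:

import os
import os.path as osp

work_dir = './work_dirs/exp1'              # assumed local work_dir
os.makedirs(work_dir, exist_ok=True)
filepath = 's3://bucket/exp1/epoch_3.pth'  # assumed remote checkpoint path

# What save_checkpoint now does after writing the checkpoint:
with open(osp.join(work_dir, 'last_checkpoint'), 'w') as f:
    f.write(filepath)

# What resume does, via find_latest_checkpoint:
with open(osp.join(work_dir, 'last_checkpoint')) as f:
    assert f.read().strip() == filepath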
mmengine/utils/__init__.py

@@ -2,10 +2,9 @@
 from .hub import load_url
 from .manager import ManagerMeta, ManagerMixin
 from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
-                   find_latest_checkpoint, has_batch_norm, has_method,
-                   import_modules_from_strings, is_list_of,
-                   is_method_overridden, is_seq_of, is_str, is_tuple_of,
-                   iter_cast, list_cast, mmcv_full_available,
+                   has_batch_norm, has_method, import_modules_from_strings,
+                   is_list_of, is_method_overridden, is_seq_of, is_str,
+                   is_tuple_of, iter_cast, list_cast, mmcv_full_available,
                    requires_executable, requires_package, slice_list,
                    to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
                    tuple_cast)

@@ -27,6 +26,6 @@ __all__ = [
     'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
     'is_method_overridden', 'has_method', 'mmcv_full_available',
     'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
-    'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
-    'set_multi_processing', 'has_batch_norm', 'is_abs'
+    'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
+    'is_abs'
 ]
mmengine/utils/misc.py

@@ -1,9 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import collections.abc
 import functools
-import glob
 import itertools
-import os.path as osp
 import pkgutil
 import subprocess
 import warnings

@@ -481,40 +479,6 @@ def tensor2imgs(tensor: torch.Tensor,
     return imgs
 
 
-def find_latest_checkpoint(path: str, suffix: str = 'pth'):
-    """Find the latest checkpoint from the given path.
-
-    Refer to https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/patch.py  # noqa: E501
-
-    Args:
-        path(str): The path to find checkpoints.
-        suffix(str): File extension. Defaults to 'pth'.
-
-    Returns:
-        str or None: File path of the latest checkpoint.
-    """
-    if not osp.exists(path):
-        raise FileNotFoundError('{path} does not exist.')
-
-    if osp.exists(osp.join(path, f'latest.{suffix}')):
-        return osp.join(path, f'latest.{suffix}')
-
-    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
-    if len(checkpoints) == 0:
-        raise FileNotFoundError(f'checkpoints can not be found in {path}. '
-                                'Maybe check the suffix again.')
-
-    latest = -1
-    latest_path = None
-    for checkpoint in checkpoints:
-        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
-        if count > latest:
-            latest = count
-            latest_path = checkpoint
-
-    return latest_path
-
-
 def has_batch_norm(model: nn.Module) -> bool:
     """Detect whether model has a BatchNormalization layer.
 
tests/test_hooks/test_checkpoint_hook.py

@@ -46,18 +46,6 @@ class TestCheckpointHook:
         assert checkpoint_hook.out_dir == (
             f'test_dir/{osp.basename(work_dir)}')
 
-        # create_symlink in args and create_symlink is True
-        checkpoint_hook = CheckpointHook(
-            interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
-        checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.args['create_symlink']
-
-        runner.work_dir = 's3://path/of/file'
-        checkpoint_hook = CheckpointHook(
-            interval=1, by_epoch=True, create_symlink=True)
-        checkpoint_hook.before_train(runner)
-        assert not checkpoint_hook.args['create_symlink']
-
     def test_after_train_epoch(self, tmp_path):
         runner = Mock()
         work_dir = str(tmp_path)
tests/test_runner/test_runner.py

@@ -1505,7 +1505,6 @@ class TestRunner(TestCase):
         # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'epoch_3.pth')
         self.assertTrue(osp.exists(path))
-        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))
 
         ckpt = torch.load(path)

@@ -1672,7 +1671,6 @@ class TestRunner(TestCase):
         # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'iter_12.pth')
         self.assertTrue(osp.exists(path))
-        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
 
         ckpt = torch.load(path)