[Feature] Support resume from Ceph (#294)
* support resume from ceph * move func and refine * delete symlink * fix unittest * perserve _allow_symlink and symlinkpull/308/head
parent
d0d7174274
commit
7b55c5bdbf
|
@ -1,6 +1,5 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
|
@ -95,19 +94,6 @@ class CheckpointHook(Hook):
|
||||||
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
|
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
|
||||||
f'{self.file_client.name}.')
|
f'{self.file_client.name}.')
|
||||||
|
|
||||||
# disable the create_symlink option because some file backends do not
|
|
||||||
# allow to create a symlink
|
|
||||||
if 'create_symlink' in self.args:
|
|
||||||
if self.args[
|
|
||||||
'create_symlink'] and not self.file_client.allow_symlink:
|
|
||||||
self.args['create_symlink'] = False
|
|
||||||
warnings.warn(
|
|
||||||
'create_symlink is set as True by the user but is changed'
|
|
||||||
'to be False because creating symbolic link is not '
|
|
||||||
f'allowed in {self.file_client.name}')
|
|
||||||
else:
|
|
||||||
self.args['create_symlink'] = self.file_client.allow_symlink
|
|
||||||
|
|
||||||
def after_train_epoch(self, runner) -> None:
|
def after_train_epoch(self, runner) -> None:
|
||||||
"""Save the checkpoint and synchronize buffers after each epoch.
|
"""Save the checkpoint and synchronize buffers after each epoch.
|
||||||
|
|
||||||
|
@ -142,7 +128,8 @@ class CheckpointHook(Hook):
|
||||||
|
|
||||||
runner.save_checkpoint(
|
runner.save_checkpoint(
|
||||||
self.out_dir,
|
self.out_dir,
|
||||||
filename=ckpt_filename,
|
ckpt_filename,
|
||||||
|
self.file_client_args,
|
||||||
save_optimizer=self.save_optimizer,
|
save_optimizer=self.save_optimizer,
|
||||||
save_param_scheduler=self.save_param_scheduler,
|
save_param_scheduler=self.save_param_scheduler,
|
||||||
by_epoch=self.by_epoch,
|
by_epoch=self.by_epoch,
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .checkpoint import (CheckpointLoader, get_deprecated_model_names,
|
from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
|
||||||
get_external_models, get_mmcls_models, get_state_dict,
|
get_deprecated_model_names, get_external_models,
|
||||||
|
get_mmcls_models, get_state_dict,
|
||||||
get_torchvision_models, load_checkpoint,
|
get_torchvision_models, load_checkpoint,
|
||||||
load_state_dict, save_checkpoint, weights_to_cpu)
|
load_state_dict, save_checkpoint, weights_to_cpu)
|
||||||
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
||||||
|
@ -12,5 +13,5 @@ __all__ = [
|
||||||
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
|
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
|
||||||
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
|
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
|
||||||
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
|
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
|
||||||
'TestLoop', 'Runner'
|
'TestLoop', 'Runner', 'find_latest_checkpoint'
|
||||||
]
|
]
|
||||||
|
|
|
@ -695,3 +695,26 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
|
||||||
with io.BytesIO() as f:
|
with io.BytesIO() as f:
|
||||||
torch.save(checkpoint, f)
|
torch.save(checkpoint, f)
|
||||||
file_client.put(f.getvalue(), filename)
|
file_client.put(f.getvalue(), filename)
|
||||||
|
|
||||||
|
|
||||||
|
def find_latest_checkpoint(path: str):
|
||||||
|
"""Find the latest checkpoint from the given path.
|
||||||
|
|
||||||
|
Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path(str): The path to find checkpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or None: File path of the latest checkpoint.
|
||||||
|
"""
|
||||||
|
save_file = osp.join(path, 'last_checkpoint')
|
||||||
|
try:
|
||||||
|
with open(save_file) as f:
|
||||||
|
last_saved = f.read().strip()
|
||||||
|
except OSError:
|
||||||
|
raise OSError(
|
||||||
|
'last_checkpoint file does not exist, maybe because it has just'
|
||||||
|
' been deleted by a separate process')
|
||||||
|
|
||||||
|
return last_saved
|
||||||
|
|
|
@ -4,7 +4,6 @@ import os.path as osp
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
import resource
|
import resource
|
||||||
import shutil
|
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -25,6 +24,7 @@ from mmengine.device import get_device
|
||||||
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
||||||
master_only, sync_random_seed)
|
master_only, sync_random_seed)
|
||||||
from mmengine.evaluator import Evaluator
|
from mmengine.evaluator import Evaluator
|
||||||
|
from mmengine.fileio import FileClient
|
||||||
from mmengine.hooks import Hook
|
from mmengine.hooks import Hook
|
||||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||||
from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
||||||
|
@ -36,13 +36,13 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
||||||
RUNNERS, VISUALIZERS, DefaultScope,
|
RUNNERS, VISUALIZERS, DefaultScope,
|
||||||
count_registered_modules)
|
count_registered_modules)
|
||||||
from mmengine.registry.root import LOG_PROCESSORS
|
from mmengine.registry.root import LOG_PROCESSORS
|
||||||
from mmengine.utils import (TORCH_VERSION, digit_version,
|
from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash,
|
||||||
find_latest_checkpoint, get_git_hash, is_list_of,
|
is_list_of, set_multi_processing)
|
||||||
set_multi_processing, symlink)
|
|
||||||
from mmengine.visualization import Visualizer
|
from mmengine.visualization import Visualizer
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
||||||
get_state_dict, save_checkpoint, weights_to_cpu)
|
find_latest_checkpoint, get_state_dict,
|
||||||
|
save_checkpoint, weights_to_cpu)
|
||||||
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
||||||
from .priority import Priority, get_priority
|
from .priority import Priority, get_priority
|
||||||
|
|
||||||
|
@ -1920,10 +1920,10 @@ class Runner:
|
||||||
def save_checkpoint(self,
|
def save_checkpoint(self,
|
||||||
out_dir: str,
|
out_dir: str,
|
||||||
filename: str,
|
filename: str,
|
||||||
|
file_client_args: Optional[dict] = None,
|
||||||
save_optimizer: bool = True,
|
save_optimizer: bool = True,
|
||||||
save_param_scheduler: bool = True,
|
save_param_scheduler: bool = True,
|
||||||
meta: dict = None,
|
meta: dict = None,
|
||||||
create_symlink: bool = True,
|
|
||||||
by_epoch: bool = True):
|
by_epoch: bool = True):
|
||||||
"""Save checkpoints.
|
"""Save checkpoints.
|
||||||
|
|
||||||
|
@ -1933,15 +1933,16 @@ class Runner:
|
||||||
Args:
|
Args:
|
||||||
out_dir (str): The directory that checkpoints are saved.
|
out_dir (str): The directory that checkpoints are saved.
|
||||||
filename (str): The checkpoint filename.
|
filename (str): The checkpoint filename.
|
||||||
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
|
FileClient. Default: None.
|
||||||
save_optimizer (bool): Whether to save the optimizer to
|
save_optimizer (bool): Whether to save the optimizer to
|
||||||
the checkpoint. Defaults to True.
|
the checkpoint. Defaults to True.
|
||||||
save_param_scheduler (bool): Whether to save the param_scheduler
|
save_param_scheduler (bool): Whether to save the param_scheduler
|
||||||
to the checkpoint. Defaults to True.
|
to the checkpoint. Defaults to True.
|
||||||
meta (dict, optional): The meta information to be saved in the
|
meta (dict, optional): The meta information to be saved in the
|
||||||
checkpoint. Defaults to None.
|
checkpoint. Defaults to None.
|
||||||
create_symlink (bool): Whether to create a symlink
|
by_epoch (bool): Whether the scheduled momentum is updated by
|
||||||
"latest.pth" to point to the latest checkpoint.
|
epochs. Defaults to True.
|
||||||
Defaults to True.
|
|
||||||
"""
|
"""
|
||||||
if meta is None:
|
if meta is None:
|
||||||
meta = {}
|
meta = {}
|
||||||
|
@ -1961,7 +1962,8 @@ class Runner:
|
||||||
else:
|
else:
|
||||||
meta.update(epoch=self.epoch, iter=self.iter + 1)
|
meta.update(epoch=self.epoch, iter=self.iter + 1)
|
||||||
|
|
||||||
filepath = osp.join(out_dir, filename)
|
file_client = FileClient.infer_client(file_client_args, out_dir)
|
||||||
|
filepath = file_client.join_path(out_dir, filename)
|
||||||
|
|
||||||
meta.update(
|
meta.update(
|
||||||
cfg=self.cfg.pretty_text,
|
cfg=self.cfg.pretty_text,
|
||||||
|
@ -2007,14 +2009,10 @@ class Runner:
|
||||||
|
|
||||||
self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
|
self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
|
||||||
save_checkpoint(checkpoint, filepath)
|
save_checkpoint(checkpoint, filepath)
|
||||||
# in some environments, `os.symlink` is not supported, you may need to
|
|
||||||
# set `create_symlink` to False
|
save_file = osp.join(self.work_dir, 'last_checkpoint')
|
||||||
if create_symlink:
|
with open(save_file, 'w') as f:
|
||||||
dst_file = osp.join(out_dir, 'latest.pth')
|
f.write(filepath)
|
||||||
if platform.system() != 'Windows':
|
|
||||||
symlink(filename, dst_file)
|
|
||||||
else:
|
|
||||||
shutil.copy(filepath, dst_file)
|
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def dump_config(self) -> None:
|
def dump_config(self) -> None:
|
||||||
|
|
|
@ -2,10 +2,9 @@
|
||||||
from .hub import load_url
|
from .hub import load_url
|
||||||
from .manager import ManagerMeta, ManagerMixin
|
from .manager import ManagerMeta, ManagerMixin
|
||||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||||
find_latest_checkpoint, has_batch_norm, has_method,
|
has_batch_norm, has_method, import_modules_from_strings,
|
||||||
import_modules_from_strings, is_list_of,
|
is_list_of, is_method_overridden, is_seq_of, is_str,
|
||||||
is_method_overridden, is_seq_of, is_str, is_tuple_of,
|
is_tuple_of, iter_cast, list_cast, mmcv_full_available,
|
||||||
iter_cast, list_cast, mmcv_full_available,
|
|
||||||
requires_executable, requires_package, slice_list,
|
requires_executable, requires_package, slice_list,
|
||||||
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
|
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
|
||||||
tuple_cast)
|
tuple_cast)
|
||||||
|
@ -27,6 +26,6 @@ __all__ = [
|
||||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||||
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
'is_method_overridden', 'has_method', 'mmcv_full_available',
|
||||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||||
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
|
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
|
||||||
'set_multi_processing', 'has_batch_norm', 'is_abs'
|
'is_abs'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import functools
|
import functools
|
||||||
import glob
|
|
||||||
import itertools
|
import itertools
|
||||||
import os.path as osp
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -481,40 +479,6 @@ def tensor2imgs(tensor: torch.Tensor,
|
||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
|
|
||||||
def find_latest_checkpoint(path: str, suffix: str = 'pth'):
|
|
||||||
"""Find the latest checkpoint from the given path.
|
|
||||||
|
|
||||||
Refer to https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/patch.py # noqa: E501
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path(str): The path to find checkpoints.
|
|
||||||
suffix(str): File extension. Defaults to 'pth'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str or None: File path of the latest checkpoint.
|
|
||||||
"""
|
|
||||||
if not osp.exists(path):
|
|
||||||
raise FileNotFoundError('{path} does not exist.')
|
|
||||||
|
|
||||||
if osp.exists(osp.join(path, f'latest.{suffix}')):
|
|
||||||
return osp.join(path, f'latest.{suffix}')
|
|
||||||
|
|
||||||
checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
|
|
||||||
if len(checkpoints) == 0:
|
|
||||||
raise FileNotFoundError(f'checkpoints can not be found in {path}. '
|
|
||||||
'Maybe check the suffix again.')
|
|
||||||
|
|
||||||
latest = -1
|
|
||||||
latest_path = None
|
|
||||||
for checkpoint in checkpoints:
|
|
||||||
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
|
|
||||||
if count > latest:
|
|
||||||
latest = count
|
|
||||||
latest_path = checkpoint
|
|
||||||
|
|
||||||
return latest_path
|
|
||||||
|
|
||||||
|
|
||||||
def has_batch_norm(model: nn.Module) -> bool:
|
def has_batch_norm(model: nn.Module) -> bool:
|
||||||
"""Detect whether model has a BatchNormalization layer.
|
"""Detect whether model has a BatchNormalization layer.
|
||||||
|
|
||||||
|
|
|
@ -46,18 +46,6 @@ class TestCheckpointHook:
|
||||||
assert checkpoint_hook.out_dir == (
|
assert checkpoint_hook.out_dir == (
|
||||||
f'test_dir/{osp.basename(work_dir)}')
|
f'test_dir/{osp.basename(work_dir)}')
|
||||||
|
|
||||||
# create_symlink in args and create_symlink is True
|
|
||||||
checkpoint_hook = CheckpointHook(
|
|
||||||
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
|
|
||||||
checkpoint_hook.before_train(runner)
|
|
||||||
assert checkpoint_hook.args['create_symlink']
|
|
||||||
|
|
||||||
runner.work_dir = 's3://path/of/file'
|
|
||||||
checkpoint_hook = CheckpointHook(
|
|
||||||
interval=1, by_epoch=True, create_symlink=True)
|
|
||||||
checkpoint_hook.before_train(runner)
|
|
||||||
assert not checkpoint_hook.args['create_symlink']
|
|
||||||
|
|
||||||
def test_after_train_epoch(self, tmp_path):
|
def test_after_train_epoch(self, tmp_path):
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
work_dir = str(tmp_path)
|
work_dir = str(tmp_path)
|
||||||
|
|
|
@ -1505,7 +1505,6 @@ class TestRunner(TestCase):
|
||||||
# 1.1 test `save_checkpoint` which is called by `CheckpointHook`
|
# 1.1 test `save_checkpoint` which is called by `CheckpointHook`
|
||||||
path = osp.join(self.temp_dir, 'epoch_3.pth')
|
path = osp.join(self.temp_dir, 'epoch_3.pth')
|
||||||
self.assertTrue(osp.exists(path))
|
self.assertTrue(osp.exists(path))
|
||||||
self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
|
|
||||||
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))
|
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))
|
||||||
|
|
||||||
ckpt = torch.load(path)
|
ckpt = torch.load(path)
|
||||||
|
@ -1672,7 +1671,6 @@ class TestRunner(TestCase):
|
||||||
# 2.1 test `save_checkpoint` which is called by `CheckpointHook`
|
# 2.1 test `save_checkpoint` which is called by `CheckpointHook`
|
||||||
path = osp.join(self.temp_dir, 'iter_12.pth')
|
path = osp.join(self.temp_dir, 'iter_12.pth')
|
||||||
self.assertTrue(osp.exists(path))
|
self.assertTrue(osp.exists(path))
|
||||||
self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
|
|
||||||
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
|
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
|
||||||
|
|
||||||
ckpt = torch.load(path)
|
ckpt = torch.load(path)
|
||||||
|
|
Loading…
Reference in New Issue