diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 08152877..334d544f 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-import warnings
 from pathlib import Path
 from typing import Optional, Sequence, Union

@@ -95,19 +94,6 @@ class CheckpointHook(Hook):
         runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.')

-        # disable the create_symlink option because some file backends do not
-        # allow to create a symlink
-        if 'create_symlink' in self.args:
-            if self.args[
-                    'create_symlink'] and not self.file_client.allow_symlink:
-                self.args['create_symlink'] = False
-                warnings.warn(
-                    'create_symlink is set as True by the user but is changed'
-                    'to be False because creating symbolic link is not '
-                    f'allowed in {self.file_client.name}')
-        else:
-            self.args['create_symlink'] = self.file_client.allow_symlink
-
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.

@@ -142,7 +128,8 @@ class CheckpointHook(Hook):

         runner.save_checkpoint(
             self.out_dir,
-            filename=ckpt_filename,
+            ckpt_filename,
+            self.file_client_args,
             save_optimizer=self.save_optimizer,
             save_param_scheduler=self.save_param_scheduler,
             by_epoch=self.by_epoch,
diff --git a/mmengine/runner/__init__.py b/mmengine/runner/__init__.py
index eae1b920..043c56fd 100644
--- a/mmengine/runner/__init__.py
+++ b/mmengine/runner/__init__.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_loop import BaseLoop
-from .checkpoint import (CheckpointLoader, get_deprecated_model_names,
-                         get_external_models, get_mmcls_models, get_state_dict,
+from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
+                         get_deprecated_model_names, get_external_models,
+                         get_mmcls_models, get_state_dict,
                          get_torchvision_models, load_checkpoint,
                          load_state_dict, save_checkpoint, weights_to_cpu)
 from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
@@ -12,5 +13,5 @@ __all__ = [
     'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
-    'TestLoop', 'Runner'
+    'TestLoop', 'Runner', 'find_latest_checkpoint'
 ]
diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py
index 734087a2..058e20eb 100644
--- a/mmengine/runner/checkpoint.py
+++ b/mmengine/runner/checkpoint.py
@@ -695,3 +695,26 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
     with io.BytesIO() as f:
         torch.save(checkpoint, f)
         file_client.put(f.getvalue(), filename)
+
+
+def find_latest_checkpoint(path: str):
+    """Find the latest checkpoint from the given path.
+
+    Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501
+
+    Args:
+        path (str): The path to find checkpoints.
+
+    Returns:
+        str: File path of the latest checkpoint.
+ """ + save_file = osp.join(path, 'last_checkpoint') + try: + with open(save_file) as f: + last_saved = f.read().strip() + except OSError: + raise OSError( + 'last_checkpoint file does not exist, maybe because it has just' + ' been deleted by a separate process') + + return last_saved diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 89d5e646..09863bdd 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -4,7 +4,6 @@ import os.path as osp import platform import random import resource -import shutil import time import warnings from collections import OrderedDict @@ -25,6 +24,7 @@ from mmengine.device import get_device from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, master_only, sync_random_seed) from mmengine.evaluator import Evaluator +from mmengine.fileio import FileClient from mmengine.hooks import Hook from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.model import (BaseModel, MMDistributedDataParallel, @@ -36,13 +36,13 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, RUNNERS, VISUALIZERS, DefaultScope, count_registered_modules) from mmengine.registry.root import LOG_PROCESSORS -from mmengine.utils import (TORCH_VERSION, digit_version, - find_latest_checkpoint, get_git_hash, is_list_of, - set_multi_processing, symlink) +from mmengine.utils import (TORCH_VERSION, digit_version, get_git_hash, + is_list_of, set_multi_processing) from mmengine.visualization import Visualizer from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, - get_state_dict, save_checkpoint, weights_to_cpu) + find_latest_checkpoint, get_state_dict, + save_checkpoint, weights_to_cpu) from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority @@ -1920,10 +1920,10 @@ class Runner: def save_checkpoint(self, out_dir: str, filename: str, + file_client_args: Optional[dict] = None, save_optimizer: bool = True, save_param_scheduler: bool = True, meta: dict = None, - create_symlink: bool = True, by_epoch: bool = True): """Save checkpoints. @@ -1933,15 +1933,16 @@ class Runner: Args: out_dir (str): The directory that checkpoints are saved. filename (str): The checkpoint filename. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Default: None. save_optimizer (bool): Whether to save the optimizer to the checkpoint. Defaults to True. save_param_scheduler (bool): Whether to save the param_scheduler to the checkpoint. Defaults to True. meta (dict, optional): The meta information to be saved in the checkpoint. Defaults to None. - create_symlink (bool): Whether to create a symlink - "latest.pth" to point to the latest checkpoint. - Defaults to True. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. 
""" if meta is None: meta = {} @@ -1961,7 +1962,8 @@ class Runner: else: meta.update(epoch=self.epoch, iter=self.iter + 1) - filepath = osp.join(out_dir, filename) + file_client = FileClient.infer_client(file_client_args, out_dir) + filepath = file_client.join_path(out_dir, filename) meta.update( cfg=self.cfg.pretty_text, @@ -2007,14 +2009,10 @@ class Runner: self.call_hook('before_save_checkpoint', checkpoint=checkpoint) save_checkpoint(checkpoint, filepath) - # in some environments, `os.symlink` is not supported, you may need to - # set `create_symlink` to False - if create_symlink: - dst_file = osp.join(out_dir, 'latest.pth') - if platform.system() != 'Windows': - symlink(filename, dst_file) - else: - shutil.copy(filepath, dst_file) + + save_file = osp.join(self.work_dir, 'last_checkpoint') + with open(save_file, 'w') as f: + f.write(filepath) @master_only def dump_config(self) -> None: diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 56483df3..0375578b 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -2,10 +2,9 @@ from .hub import load_url from .manager import ManagerMeta, ManagerMixin from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - find_latest_checkpoint, has_batch_norm, has_method, - import_modules_from_strings, is_list_of, - is_method_overridden, is_seq_of, is_str, is_tuple_of, - iter_cast, list_cast, mmcv_full_available, + has_batch_norm, has_method, import_modules_from_strings, + is_list_of, is_method_overridden, is_seq_of, is_str, + is_tuple_of, iter_cast, list_cast, mmcv_full_available, requires_executable, requires_package, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple, tuple_cast) @@ -27,6 +26,6 @@ __all__ = [ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', 'is_method_overridden', 'has_method', 'mmcv_full_available', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', - 'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin', - 'set_multi_processing', 'has_batch_norm', 'is_abs' + 'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm', + 'is_abs' ] diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 5e151561..fb8de1fc 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import collections.abc import functools -import glob import itertools -import os.path as osp import pkgutil import subprocess import warnings @@ -481,40 +479,6 @@ def tensor2imgs(tensor: torch.Tensor, return imgs -def find_latest_checkpoint(path: str, suffix: str = 'pth'): - """Find the latest checkpoint from the given path. - - Refer to https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/patch.py # noqa: E501 - - Args: - path(str): The path to find checkpoints. - suffix(str): File extension. Defaults to 'pth'. - - Returns: - str or None: File path of the latest checkpoint. - """ - if not osp.exists(path): - raise FileNotFoundError('{path} does not exist.') - - if osp.exists(osp.join(path, f'latest.{suffix}')): - return osp.join(path, f'latest.{suffix}') - - checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) - if len(checkpoints) == 0: - raise FileNotFoundError(f'checkpoints can not be found in {path}. 
-                                'Maybe check the suffix again.')
-
-    latest = -1
-    latest_path = None
-    for checkpoint in checkpoints:
-        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
-        if count > latest:
-            latest = count
-            latest_path = checkpoint
-
-    return latest_path
-
-
 def has_batch_norm(model: nn.Module) -> bool:
     """Detect whether model has a BatchNormalization layer.

diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py
index 10f682cd..b9267fc0 100644
--- a/tests/test_hook/test_checkpoint_hook.py
+++ b/tests/test_hook/test_checkpoint_hook.py
@@ -46,18 +46,6 @@ class TestCheckpointHook:
         assert checkpoint_hook.out_dir == (
             f'test_dir/{osp.basename(work_dir)}')

-        # create_symlink in args and create_symlink is True
-        checkpoint_hook = CheckpointHook(
-            interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
-        checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.args['create_symlink']
-
-        runner.work_dir = 's3://path/of/file'
-        checkpoint_hook = CheckpointHook(
-            interval=1, by_epoch=True, create_symlink=True)
-        checkpoint_hook.before_train(runner)
-        assert not checkpoint_hook.args['create_symlink']
-
     def test_after_train_epoch(self, tmp_path):
         runner = Mock()
         work_dir = str(tmp_path)
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 5dd91996..440bd6ca 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -1505,7 +1505,6 @@ class TestRunner(TestCase):
         # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'epoch_3.pth')
         self.assertTrue(osp.exists(path))
-        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))

         ckpt = torch.load(path)
@@ -1672,7 +1671,6 @@ class TestRunner(TestCase):
         # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'iter_12.pth')
         self.assertTrue(osp.exists(path))
-        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))

         ckpt = torch.load(path)
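
Usage sketch (illustrative only, not part of the patch): the round trip this
diff introduces can be exercised end to end with a short standalone script.
Both helpers below re-implement the new behaviour for illustration --
`write_last_checkpoint` is a hypothetical stand-in for the tail of
`Runner.save_checkpoint`, and `find_latest_checkpoint` mirrors the function
added to mmengine/runner/checkpoint.py -- so this is a sketch of the
mechanism, not a call into the mmengine API.

import os.path as osp
import tempfile


def write_last_checkpoint(work_dir: str, filepath: str) -> None:
    # Stand-in for the tail of `Runner.save_checkpoint`: record the path of
    # the checkpoint that was just written in a plain-text marker file.
    with open(osp.join(work_dir, 'last_checkpoint'), 'w') as f:
        f.write(filepath)


def find_latest_checkpoint(path: str) -> str:
    # Mirrors the new `mmengine.runner.checkpoint.find_latest_checkpoint`:
    # the marker file is the single source of truth, so no globbing over
    # `*.pth` filenames or `latest.pth` symlink is needed anymore.
    save_file = osp.join(path, 'last_checkpoint')
    try:
        with open(save_file) as f:
            return f.read().strip()
    except OSError:
        raise OSError('last_checkpoint file does not exist, maybe because '
                      'it has just been deleted by a separate process')


with tempfile.TemporaryDirectory() as work_dir:
    ckpt_path = osp.join(work_dir, 'epoch_3.pth')
    open(ckpt_path, 'w').close()  # stand-in for torch.save(checkpoint, ...)
    write_last_checkpoint(work_dir, ckpt_path)
    assert find_latest_checkpoint(work_dir) == ckpt_path

Recording the latest checkpoint in a plain-text marker file instead of a
`latest.pth` symlink behaves the same on every storage backend, including
Windows and object stores such as s3 where symlinks are unavailable, which is
why the `create_symlink` option and its `shutil.copy` fallback can be dropped.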