diff --git a/mmrazor/apis/mmcls/train.py b/mmrazor/apis/mmcls/train.py
index b0aab724..ec8526d8 100644
--- a/mmrazor/apis/mmcls/train.py
+++ b/mmrazor/apis/mmcls/train.py
@@ -16,6 +16,7 @@ from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
 from mmrazor.datasets.utils import split_dataset
+from mmrazor.utils import find_latest_checkpoint
 
 
 def set_random_seed(seed, deterministic=False):
@@ -190,6 +191,12 @@ def train_model(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
diff --git a/mmrazor/apis/mmdet/train.py b/mmrazor/apis/mmdet/train.py
index d3d5cdb4..4f606f47 100644
--- a/mmrazor/apis/mmdet/train.py
+++ b/mmrazor/apis/mmdet/train.py
@@ -15,6 +15,7 @@ from mmdet.utils import get_root_logger
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
+from mmrazor.utils import find_latest_checkpoint
 
 
 def set_random_seed(seed, deterministic=False):
@@ -181,6 +182,12 @@ def train_detector(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
diff --git a/mmrazor/apis/mmseg/train.py b/mmrazor/apis/mmseg/train.py
index 9933cb17..fb96d704 100644
--- a/mmrazor/apis/mmseg/train.py
+++ b/mmrazor/apis/mmseg/train.py
@@ -12,6 +12,7 @@ from mmseg.utils import get_root_logger
 
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.optimizer import build_optimizers
+from mmrazor.utils import find_latest_checkpoint
 
 
 def set_random_seed(seed, deterministic=False):
@@ -137,6 +138,12 @@ def train_segmentor(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')
 
+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py
index c8ab31c1..4d874811 100644
--- a/mmrazor/utils/__init__.py
+++ b/mmrazor/utils/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .misc import find_latest_checkpoint
 from .setup_env import setup_multi_processes
 
-__all__ = ['setup_multi_processes']
+__all__ = ['find_latest_checkpoint', 'setup_multi_processes']
diff --git a/mmrazor/utils/misc.py b/mmrazor/utils/misc.py
new file mode 100644
index 00000000..1c12b21e
--- /dev/null
+++ b/mmrazor/utils/misc.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os.path as osp
+import warnings
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+    """Find the latest checkpoint from the working directory.
+
+    Args:
+        path(str): The path to find checkpoints.
+        suffix(str): File extension. Defaults to pth.
+
+    Returns:
+        latest_path(str | None): File path of the latest checkpoint.
+
+    References:
+        .. [1] https://github.com/microsoft/SoftTeacher
+                  /blob/main/ssod/utils/patch.py
+    """
+    if not osp.exists(path):
+        warnings.warn('The path of checkpoints does not exist.')
+        return None
+    if osp.exists(osp.join(path, f'latest.{suffix}')):
+        return osp.join(path, f'latest.{suffix}')
+
+    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+    if len(checkpoints) == 0:
+        warnings.warn('There are no checkpoints in the path.')
+        return None
+    latest = -1
+    latest_path = None
+    for checkpoint in checkpoints:
+        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+        if count > latest:
+            latest = count
+            latest_path = checkpoint
+    return latest_path
diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py
new file mode 100644
index 00000000..312ecd9e
--- /dev/null
+++ b/tests/test_utils/test_misc.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+
+from mmrazor.utils import find_latest_checkpoint
+
+
+def test_find_latest_checkpoint():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        # There are no checkpoints in the path.
+        assert latest is None
+
+        path = tmpdir + '/none'
+        latest = find_latest_checkpoint(path)
+        # The path does not exist.
+        assert latest is None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/latest.pth', 'w') as f:
+            f.write('latest')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'latest.pth')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/iter_4000.pth', 'w') as f:
+            f.write('iter_4000')
+        with open(tmpdir + '/iter_8000.pth', 'w') as f:
+            f.write('iter_8000')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'iter_8000.pth')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/epoch_1.pth', 'w') as f:
+            f.write('epoch_1')
+        with open(tmpdir + '/epoch_2.pth', 'w') as f:
+            f.write('epoch_2')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'epoch_2.pth')
diff --git a/tools/mmcls/train_mmcls.py b/tools/mmcls/train_mmcls.py
index 0377d6c3..d593061c 100644
--- a/tools/mmcls/train_mmcls.py
+++ b/tools/mmcls/train_mmcls.py
@@ -26,6 +26,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -101,6 +105,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '
diff --git a/tools/mmdet/train_mmdet.py b/tools/mmdet/train_mmdet.py
index 9273ccd5..1d6d74d5 100644
--- a/tools/mmdet/train_mmdet.py
+++ b/tools/mmdet/train_mmdet.py
@@ -34,6 +34,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -112,6 +116,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '
diff --git a/tools/mmseg/train_mmseg.py b/tools/mmseg/train_mmseg.py
index 47048415..559f8078 100644
--- a/tools/mmseg/train_mmseg.py
+++ b/tools/mmseg/train_mmseg.py
@@ -37,6 +37,10 @@ def parse_args():
         '--load-from', help='the checkpoint file to load weights from')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -114,6 +118,7 @@ def main():
         cfg.load_from = args.load_from
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '
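
A minimal sketch of how the pieces above fit together (the work_dir
'work_dirs/my_exp' and the config path below are hypothetical, not part of
the patch). When `--auto-resume` is passed and `--resume-from` is not, each
train entry point resolves the checkpoint via the new helper:

    # Sketch of the resolution logic added to train_model/train_detector/
    # train_segmentor; 'work_dirs/my_exp' is an assumed work_dir.
    from mmrazor.utils import find_latest_checkpoint

    # Returns 'latest.pth' if present; otherwise the '*.pth' file with the
    # largest trailing number (e.g. epoch_2.pth over epoch_1.pth), or None
    # when the directory is missing or contains no checkpoints.
    resume_from = find_latest_checkpoint('work_dirs/my_exp')
    if resume_from is not None:
        print(f'auto-resuming from {resume_from}')

The equivalent command line would be, for example,
`python tools/mmcls/train_mmcls.py ${CONFIG} --work-dir work_dirs/my_exp --auto-resume`.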