diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py
index 7e1096bce..5d5bb9c08 100644
--- a/mmseg/apis/train.py
+++ b/mmseg/apis/train.py
@@ -11,7 +11,7 @@ from mmcv.utils import build_from_cfg
 
 from mmseg.core import DistEvalHook, EvalHook
 from mmseg.datasets import build_dataloader, build_dataset
-from mmseg.utils import get_root_logger
+from mmseg.utils import find_latest_checkpoint, get_root_logger
 
 
 def init_random_seed(seed=None, device='cuda'):
@@ -160,6 +160,10 @@ def train_segmentor(model,
             hook = build_from_cfg(hook_cfg, HOOKS)
             runner.register_hook(hook, priority=priority)
 
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+        if resume_from is not None:
+            cfg.resume_from = resume_from
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py
index 3f1558052..4b34f4c38 100644
--- a/mmseg/utils/__init__.py
+++ b/mmseg/utils/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
 from .logger import get_root_logger
+from .misc import find_latest_checkpoint
 
-__all__ = ['get_root_logger', 'collect_env']
+__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py
new file mode 100644
index 000000000..bd1b6b163
--- /dev/null
+++ b/mmseg/utils/misc.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os.path as osp
+import warnings
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+    """Find the latest checkpoint from the given path.
+
+    It is used when resuming training automatically. Modified from
+    https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
+
+    Args:
+        path (str): The path to search for checkpoints.
+        suffix (str): File extension of the checkpoints. Defaults to 'pth'.
+
+    Returns:
+        latest_path (str | None): File path of the latest checkpoint.
+    """
+    if not osp.exists(path):
+        warnings.warn('The checkpoint path does not exist.')
+        return None
+    if osp.exists(osp.join(path, f'latest.{suffix}')):
+        return osp.join(path, f'latest.{suffix}')
+
+    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+    if len(checkpoints) == 0:
+        warnings.warn('There are no checkpoints in the path.')
+        return None
+    latest = -1
+    latest_path = ''
+    for checkpoint in checkpoints:
+        if len(checkpoint) < len(latest_path):
+            continue
+        # `count` is the iteration or epoch number, as checkpoints are
+        # saved as 'iter_xx.pth' or 'epoch_xx.pth' where xx is that number.
+        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+        if count > latest:
+            latest = count
+            latest_path = checkpoint
+    return latest_path
diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py
new file mode 100644
index 000000000..7ce1fa614
--- /dev/null
+++ b/tests/test_utils/test_misc.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+
+from mmseg.utils import find_latest_checkpoint
+
+
+def test_find_latest_checkpoint():
+    with tempfile.TemporaryDirectory() as tempdir:
+        # no checkpoints in the path
+        path = tempdir
+        latest = find_latest_checkpoint(path)
+        assert latest is None
+
+        # The path doesn't exist
+        path = osp.join(tempdir, 'none')
+        latest = find_latest_checkpoint(path)
+        assert latest is None
+
+    # test when latest.pth exists
+    with tempfile.TemporaryDirectory() as tempdir:
+        with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
+            f.write('latest')
+        path = tempdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tempdir, 'latest.pth')
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        for iter in range(1600, 160001, 1600):
+            with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f:
+                f.write(f'iter_{iter}.pth')
+        latest = find_latest_checkpoint(tempdir)
+        assert latest == osp.join(tempdir, 'iter_160000.pth')
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        for epoch in range(1, 21):
+            with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
+                f.write(f'epoch_{epoch}.pth')
+        latest = find_latest_checkpoint(tempdir)
+        assert latest == osp.join(tempdir, 'epoch_20.pth')
diff --git a/tools/train.py b/tools/train.py
index 2e0b6e91f..81c7d854e 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -75,6 +75,10 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically.')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -118,6 +122,7 @@ def main():
         cfg.gpu_ids = args.gpu_ids
     else:
         cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+    cfg.auto_resume = args.auto_resume
 
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
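
For reference, a minimal usage sketch of the helper introduced above; it is not part of the patch, and the work_dir contents are hypothetical. It assumes the patch is applied so that find_latest_checkpoint is importable from mmseg.utils. train_segmentor performs the same lookup on cfg.work_dir when cfg.auto_resume is set and cfg.resume_from is None, which is what running tools/train.py with --auto-resume enables.

# Minimal sketch: exercise find_latest_checkpoint against a hypothetical
# work_dir layout (assumes the patch above is applied).
import os.path as osp
import tempfile

from mmseg.utils import find_latest_checkpoint

with tempfile.TemporaryDirectory() as work_dir:
    # Simulate a run that saved two iteration checkpoints.
    for it in (4000, 8000):
        open(osp.join(work_dir, f'iter_{it}.pth'), 'w').close()
    # The checkpoint with the largest iteration number is returned.
    assert find_latest_checkpoint(work_dir).endswith('iter_8000.pth')

    # A 'latest.pth' file, if present, takes precedence over numbered ones.
    open(osp.join(work_dir, 'latest.pth'), 'w').close()
    assert find_latest_checkpoint(work_dir) == osp.join(work_dir, 'latest.pth')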