[Feature] add auto resume (#1172)

* [Feature] add auto resume

* Update mmseg/utils/find_latest_checkpoint.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/utils/find_latest_checkpoint.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* modify docstring

* Update mmseg/utils/find_latest_checkpoint.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* add copyright

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
pull/1801/head
Rockey 2022-01-11 12:27:24 +08:00 committed by GitHub
parent ae51615541
commit 43ad37b478
5 changed files with 93 additions and 2 deletions

View File

@ -11,7 +11,7 @@ from mmcv.utils import build_from_cfg
from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger
from mmseg.utils import find_latest_checkpoint, get_root_logger
def init_random_seed(seed=None, device='cuda'):
@ -160,6 +160,10 @@ def train_segmentor(model,
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import find_latest_checkpoint
__all__ = ['get_root_logger', 'collect_env']
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']

View File

@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint inside a directory.

    It is used when resuming automatically, modified from
    https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py

    ``latest.{suffix}`` is returned directly when it exists. Otherwise the
    file whose name ends in the largest integer counter (for example
    ``iter_xx.pth`` or ``epoch_xx.pth``) is chosen. Files matching the
    suffix whose name carries no integer counter are ignored.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension for the checkpoint. Defaults to pth.

    Returns:
        latest_path(str | None): File path of the latest checkpoint, or
            None when the path does not exist or holds no usable
            checkpoint.
    """
    if not osp.exists(path):
        warnings.warn("The path of the checkpoints doesn't exist.")
        return None

    latest_file = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_file):
        return latest_file

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # The counter is the text between the last '_' and the first '.'
        # that follows it, e.g. '100' in 'iter_100.pth'.
        counter = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(counter)
        except ValueError:
            # Not a numbered checkpoint (e.g. a hand-named weights file);
            # skip it instead of aborting the whole search.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint

    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path')
    return latest_path

View File

@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from mmseg.utils import find_latest_checkpoint
def test_find_latest_checkpoint():
    """Exercise ``find_latest_checkpoint`` on empty and populated dirs."""

    def _write(directory, filename, content):
        # Create a small placeholder file and hand back its full path.
        target = osp.join(directory, filename)
        with open(target, 'w') as f:
            f.write(content)
        return target

    with tempfile.TemporaryDirectory() as workdir:
        # An existing directory without any checkpoint gives None.
        assert find_latest_checkpoint(workdir) is None

        # A directory that does not exist gives None as well.
        missing = osp.join(workdir, 'none')
        assert find_latest_checkpoint(missing) is None

    # `latest.pth` is returned when it is present.
    with tempfile.TemporaryDirectory() as workdir:
        expected = _write(workdir, 'latest.pth', 'latest')
        assert find_latest_checkpoint(workdir) == expected

    # Iteration-based checkpoints: the largest iteration is picked.
    with tempfile.TemporaryDirectory() as workdir:
        for step in range(1600, 160001, 1600):
            _write(workdir, f'iter_{step}.pth', f'iter_{step}.pth')
        found = find_latest_checkpoint(workdir)
        assert found == osp.join(workdir, 'iter_160000.pth')

    # Epoch-based checkpoints: the largest epoch is picked.
    with tempfile.TemporaryDirectory() as workdir:
        for epoch_idx in range(1, 21):
            _write(workdir, f'epoch_{epoch_idx}.pth',
                   f'epoch_{epoch_idx}.pth')
        found = find_latest_checkpoint(workdir)
        assert found == osp.join(workdir, 'epoch_20.pth')

View File

@ -75,6 +75,10 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically.')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
@ -118,6 +122,7 @@ def main():
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
cfg.auto_resume = args.auto_resume
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':