[Feature] add auto resume (#1172)
* [Feature] add auto resume
* Update mmseg/utils/find_latest_checkpoint.py
* Update mmseg/utils/find_latest_checkpoint.py
* modify docstring
* Update mmseg/utils/find_latest_checkpoint.py
* add copyright

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

parent ae51615541
commit 43ad37b478
mmseg/apis/train.py
@@ -11,7 +11,7 @@ from mmcv.utils import build_from_cfg
 from mmseg.core import DistEvalHook, EvalHook
 from mmseg.datasets import build_dataloader, build_dataset
-from mmseg.utils import get_root_logger
+from mmseg.utils import find_latest_checkpoint, get_root_logger
 
 
 def init_random_seed(seed=None, device='cuda'):
@@ -160,6 +160,10 @@ def train_segmentor(model,
             hook = build_from_cfg(hook_cfg, HOOKS)
             runner.register_hook(hook, priority=priority)
 
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+        if resume_from is not None:
+            cfg.resume_from = resume_from
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
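Note that an explicitly configured resume_from still takes precedence: the auto-resume lookup only runs when resume_from is unset and cfg.auto_resume is truthy, and load_from is only consulted when nothing is being resumed.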
mmseg/utils/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
 from .logger import get_root_logger
+from .misc import find_latest_checkpoint
 
-__all__ = ['get_root_logger', 'collect_env']
+__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
mmseg/utils/misc.py (new file)
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os.path as osp
+import warnings
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+    """This function is for finding the latest checkpoint.
+
+    It will be used when automatically resuming training, modified from
+    https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
+
+    Args:
+        path (str): The path to find checkpoints.
+        suffix (str): File extension for the checkpoint. Defaults to 'pth'.
+
+    Returns:
+        latest_path (str | None): File path of the latest checkpoint.
+    """
+    if not osp.exists(path):
+        warnings.warn("The path of the checkpoints doesn't exist.")
+        return None
+    if osp.exists(osp.join(path, f'latest.{suffix}')):
+        return osp.join(path, f'latest.{suffix}')
+
+    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+    if len(checkpoints) == 0:
+        warnings.warn('There are no checkpoints in the path.')
+        return None
+    latest = -1
+    latest_path = ''
+    for checkpoint in checkpoints:
+        if len(checkpoint) < len(latest_path):
+            continue
+        # `count` is the iteration or epoch number, as checkpoints are
+        # saved as 'iter_xx.pth' or 'epoch_xx.pth', where xx is that number.
+        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+        if count > latest:
+            latest = count
+            latest_path = checkpoint
+    return latest_path
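For reference, a minimal sketch of how the new helper resolves a work directory (the directory contents below are hypothetical; an existing latest.pth always wins over numbered checkpoints):

# Sketch only: exercise find_latest_checkpoint on a throwaway directory.
import os.path as osp
import tempfile

from mmseg.utils import find_latest_checkpoint

with tempfile.TemporaryDirectory() as work_dir:
    # pretend a run saved a checkpoint every 4000 iterations
    for it in (4000, 8000, 12000):
        open(osp.join(work_dir, f'iter_{it}.pth'), 'w').close()
    print(find_latest_checkpoint(work_dir))  # -> <work_dir>/iter_12000.pth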
New test file for find_latest_checkpoint
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+
+from mmseg.utils import find_latest_checkpoint
+
+
+def test_find_latest_checkpoint():
+    with tempfile.TemporaryDirectory() as tempdir:
+        # no checkpoints in the path
+        path = tempdir
+        latest = find_latest_checkpoint(path)
+        assert latest is None
+
+        # The path doesn't exist
+        path = osp.join(tempdir, 'none')
+        latest = find_latest_checkpoint(path)
+        assert latest is None
+
+    # test when latest.pth exists
+    with tempfile.TemporaryDirectory() as tempdir:
+        with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
+            f.write('latest')
+        path = tempdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tempdir, 'latest.pth')
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        for iter in range(1600, 160001, 1600):
+            with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f:
+                f.write(f'iter_{iter}.pth')
+        latest = find_latest_checkpoint(tempdir)
+        assert latest == osp.join(tempdir, 'iter_160000.pth')
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        for epoch in range(1, 21):
+            with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
+                f.write(f'epoch_{epoch}.pth')
+        latest = find_latest_checkpoint(tempdir)
+        assert latest == osp.join(tempdir, 'epoch_20.pth')
tools/train.py
@@ -75,6 +75,10 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically.')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -118,6 +122,7 @@ def main():
         cfg.gpu_ids = args.gpu_ids
     else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+    cfg.auto_resume = args.auto_resume
 
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
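With these changes in place, an interrupted run can be restarted by re-issuing the original training command with --auto-resume appended, for example: python tools/train.py <config> --work-dir <work_dir> --auto-resume (placeholders are illustrative). Training then resumes from latest.pth if it exists, otherwise from the newest iter_*.pth or epoch_*.pth found in the work directory.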