mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] add auto resume (#1172)
* [Feature] add auto resume * Update mmseg/utils/find_latest_checkpoint.py Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * Update mmseg/utils/find_latest_checkpoint.py Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * modify docstring * Update mmseg/utils/find_latest_checkpoint.py Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * add copyright Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
This commit is contained in:
parent
ae51615541
commit
43ad37b478
@ -11,7 +11,7 @@ from mmcv.utils import build_from_cfg
|
|||||||
|
|
||||||
from mmseg.core import DistEvalHook, EvalHook
|
from mmseg.core import DistEvalHook, EvalHook
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import find_latest_checkpoint, get_root_logger
|
||||||
|
|
||||||
|
|
||||||
def init_random_seed(seed=None, device='cuda'):
|
def init_random_seed(seed=None, device='cuda'):
|
||||||
@ -160,6 +160,10 @@ def train_segmentor(model,
|
|||||||
hook = build_from_cfg(hook_cfg, HOOKS)
|
hook = build_from_cfg(hook_cfg, HOOKS)
|
||||||
runner.register_hook(hook, priority=priority)
|
runner.register_hook(hook, priority=priority)
|
||||||
|
|
||||||
|
if cfg.resume_from is None and cfg.get('auto_resume'):
|
||||||
|
resume_from = find_latest_checkpoint(cfg.work_dir)
|
||||||
|
if resume_from is not None:
|
||||||
|
cfg.resume_from = resume_from
|
||||||
if cfg.resume_from:
|
if cfg.resume_from:
|
||||||
runner.resume(cfg.resume_from)
|
runner.resume(cfg.resume_from)
|
||||||
elif cfg.load_from:
|
elif cfg.load_from:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .collect_env import collect_env
|
from .collect_env import collect_env
|
||||||
from .logger import get_root_logger
|
from .logger import get_root_logger
|
||||||
|
from .misc import find_latest_checkpoint
|
||||||
|
|
||||||
__all__ = ['get_root_logger', 'collect_env']
|
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
|
||||||
|
41
mmseg/utils/misc.py
Normal file
41
mmseg/utils/misc.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import glob
|
||||||
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint file under ``path``.

    Used when automatically resuming training; modified from
    https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py

    Args:
        path (str): The directory to search for checkpoints.
        suffix (str): File extension of the checkpoints. Defaults to 'pth'.

    Returns:
        str | None: File path of the latest checkpoint, or ``None`` if the
        directory does not exist or contains no usable checkpoint.
    """
    if not osp.exists(path):
        warnings.warn("The path of the checkpoints doesn't exist.")
        return None
    # A `latest.{suffix}` file/symlink takes precedence over numbered ones.
    if osp.exists(osp.join(path, f'latest.{suffix}')):
        return osp.join(path, f'latest.{suffix}')

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if len(checkpoints) == 0:
        warnings.warn('There are no checkpoints in the path')
        return None
    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # Checkpoints are saved as 'iter_xx.pth' or 'epoch_xx.pth', where
        # xx is the iteration/epoch number; the largest number is latest.
        try:
            count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
        except ValueError:
            # Skip files that don't follow the 'prefix_NUM.suffix' naming
            # convention (e.g. 'best_mIoU.pth') instead of crashing.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    return latest_path
|
40
tests/test_utils/test_misc.py
Normal file
40
tests/test_utils/test_misc.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from mmseg.utils import find_latest_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_latest_checkpoint():
    """Exercise ``find_latest_checkpoint`` on an empty directory, a missing
    path, a ``latest.pth`` file, and iteration/epoch-numbered checkpoints."""
    with tempfile.TemporaryDirectory() as tempdir:
        # No checkpoints in the path.
        path = tempdir
        latest = find_latest_checkpoint(path)
        assert latest is None

        # The path doesn't exist.
        path = osp.join(tempdir, 'none')
        latest = find_latest_checkpoint(path)
        assert latest is None

    # A `latest.pth` file takes precedence when present.
    with tempfile.TemporaryDirectory() as tempdir:
        with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
            f.write('latest')
        path = tempdir
        latest = find_latest_checkpoint(path)
        assert latest == osp.join(tempdir, 'latest.pth')

    # The checkpoint with the highest iteration number is returned.
    with tempfile.TemporaryDirectory() as tempdir:
        # `iteration` instead of `iter` to avoid shadowing the builtin.
        for iteration in range(1600, 160001, 1600):
            with open(osp.join(tempdir, f'iter_{iteration}.pth'), 'w') as f:
                f.write(f'iter_{iteration}.pth')
        latest = find_latest_checkpoint(tempdir)
        assert latest == osp.join(tempdir, 'iter_160000.pth')

    # The checkpoint with the highest epoch number is returned.
    with tempfile.TemporaryDirectory() as tempdir:
        for epoch in range(1, 21):
            with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
                f.write(f'epoch_{epoch}.pth')
        latest = find_latest_checkpoint(tempdir)
        assert latest == osp.join(tempdir, 'epoch_20.pth')
|
@ -75,6 +75,10 @@ def parse_args():
|
|||||||
default='none',
|
default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
parser.add_argument(
|
||||||
|
'--auto-resume',
|
||||||
|
action='store_true',
|
||||||
|
help='resume from the latest checkpoint automatically.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if 'LOCAL_RANK' not in os.environ:
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||||
@ -118,6 +122,7 @@ def main():
|
|||||||
cfg.gpu_ids = args.gpu_ids
|
cfg.gpu_ids = args.gpu_ids
|
||||||
else:
|
else:
|
||||||
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
|
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
|
||||||
|
cfg.auto_resume = args.auto_resume
|
||||||
|
|
||||||
# init distributed env first, since logger depends on the dist info.
|
# init distributed env first, since logger depends on the dist info.
|
||||||
if args.launcher == 'none':
|
if args.launcher == 'none':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user