[Feature] Resume from the latest checkpoint automatically. (#61)

* support auto-resume

* support auto-resume

* support auto-resume

* support auto-resume

Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
pull/132/head
whcao 2022-03-08 11:25:19 +08:00 committed by GitHub
parent 366fd0f095
commit 81e0e3452a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 119 additions and 1 deletions

View File

@ -16,6 +16,7 @@ from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
from mmrazor.core.hooks import DistSamplerSeedHook
from mmrazor.core.optimizer import build_optimizers
from mmrazor.datasets.utils import split_dataset
from mmrazor.utils import find_latest_checkpoint
def set_random_seed(seed, deterministic=False):
@ -190,6 +191,12 @@ def train_model(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:

View File

@ -15,6 +15,7 @@ from mmdet.utils import get_root_logger
from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
from mmrazor.core.hooks import DistSamplerSeedHook
from mmrazor.core.optimizer import build_optimizers
from mmrazor.utils import find_latest_checkpoint
def set_random_seed(seed, deterministic=False):
@ -181,6 +182,12 @@ def train_detector(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:

View File

@ -12,6 +12,7 @@ from mmseg.utils import get_root_logger
from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
from mmrazor.core.optimizer import build_optimizers
from mmrazor.utils import find_latest_checkpoint
def set_random_seed(seed, deterministic=False):
@ -137,6 +138,12 @@ def train_segmentor(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .misc import find_latest_checkpoint
from .setup_env import setup_multi_processes
__all__ = ['setup_multi_processes']
__all__ = ['find_latest_checkpoint', 'setup_multi_processes']

View File

@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings
def find_latest_checkpoint(path, suffix='pth'):
"""Find the latest checkpoint from the working directory.
Args:
path(str): The path to find checkpoints.
suffix(str): File extension. Defaults to pth.
Returns:
latest_path(str | None): File path of the latest checkpoint.
References:
.. [1] https://github.com/microsoft/SoftTeacher
/blob/main/ssod/utils/patch.py
"""
if not osp.exists(path):
warnings.warn('The path of checkpoints does not exist.')
return None
if osp.exists(osp.join(path, f'latest.{suffix}')):
return osp.join(path, f'latest.{suffix}')
checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
if len(checkpoints) == 0:
warnings.warn('There are no checkpoints in the path.')
return None
latest = -1
latest_path = None
for checkpoint in checkpoints:
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
if count > latest:
latest = count
latest_path = checkpoint
return latest_path

View File

@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from mmrazor.utils import find_latest_checkpoint
def test_find_latest_checkpoint():
with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir
latest = find_latest_checkpoint(path)
# There are no checkpoints in the path.
assert latest is None
path = tmpdir + '/none'
latest = find_latest_checkpoint(path)
# The path does not exist.
assert latest is None
with tempfile.TemporaryDirectory() as tmpdir:
with open(tmpdir + '/latest.pth', 'w') as f:
f.write('latest')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'latest.pth')
with tempfile.TemporaryDirectory() as tmpdir:
with open(tmpdir + '/iter_4000.pth', 'w') as f:
f.write('iter_4000')
with open(tmpdir + '/iter_8000.pth', 'w') as f:
f.write('iter_8000')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'iter_8000.pth')
with tempfile.TemporaryDirectory() as tmpdir:
with open(tmpdir + '/epoch_1.pth', 'w') as f:
f.write('epoch_1')
with open(tmpdir + '/epoch_2.pth', 'w') as f:
f.write('epoch_2')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'epoch_2.pth')

View File

@ -26,6 +26,10 @@ def parse_args():
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@ -101,6 +105,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '

View File

@ -34,6 +34,10 @@ def parse_args():
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@ -112,6 +116,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '

View File

@ -37,6 +37,10 @@ def parse_args():
'--load-from', help='the checkpoint file to load weights from')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@ -114,6 +118,7 @@ def main():
cfg.load_from = args.load_from
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '