[Feature] Resume from the latest checkpoint automatically. (#61)
* support auto-resume

Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>

parent 366fd0f095
commit 81e0e3452a

@@ -16,6 +16,7 @@ from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
 from mmrazor.datasets.utils import split_dataset
+from mmrazor.utils import find_latest_checkpoint


 def set_random_seed(seed, deterministic=False):

@@ -190,6 +191,12 @@ def train_model(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')

+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:

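In effect, the added block gives an explicit `resume_from` priority over auto-resume: the work directory is only scanned when nothing was configured. A minimal sketch of that resolution order (the `resolve_resume` helper name is hypothetical, not part of the commit):

```python
from mmrazor.utils import find_latest_checkpoint


def resolve_resume(cfg):
    """Sketch of the precedence the added block implements.

    An explicitly configured cfg.resume_from always wins; auto_resume
    only fills it in from cfg.work_dir when nothing was specified.
    """
    if cfg.resume_from is None and cfg.get('auto_resume'):
        latest = find_latest_checkpoint(cfg.work_dir)
        if latest is not None:
            cfg.resume_from = latest
    return cfg.resume_from
```
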
@@ -15,6 +15,7 @@ from mmdet.utils import get_root_logger
 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.hooks import DistSamplerSeedHook
 from mmrazor.core.optimizer import build_optimizers
+from mmrazor.utils import find_latest_checkpoint


 def set_random_seed(seed, deterministic=False):

@@ -181,6 +182,12 @@ def train_detector(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')

+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:

@@ -12,6 +12,7 @@ from mmseg.utils import get_root_logger

 from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
 from mmrazor.core.optimizer import build_optimizers
+from mmrazor.utils import find_latest_checkpoint


 def set_random_seed(seed, deterministic=False):

@@ -137,6 +138,12 @@ def train_segmentor(model,
         runner.register_hook(
             eval_hook(val_dataloader, **eval_cfg), priority='LOW')

+    resume_from = None
+    if cfg.resume_from is None and cfg.get('auto_resume'):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:

@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .misc import find_latest_checkpoint
 from .setup_env import setup_multi_processes

-__all__ = ['setup_multi_processes']
+__all__ = ['find_latest_checkpoint', 'setup_multi_processes']

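With the re-export added to `__all__`, downstream code can import the helper straight from `mmrazor.utils`, as the train APIs above do; for example (the work-dir path is made up):

```python
from mmrazor.utils import find_latest_checkpoint

ckpt = find_latest_checkpoint('work_dirs/my_exp')  # hypothetical work dir
```
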
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os.path as osp
+import warnings
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+    """Find the latest checkpoint from the working directory.
+
+    Args:
+        path(str): The path to find checkpoints.
+        suffix(str): File extension. Defaults to pth.
+
+    Returns:
+        latest_path(str | None): File path of the latest checkpoint.
+
+    References:
+        .. [1] https://github.com/microsoft/SoftTeacher
+                  /blob/main/ssod/utils/patch.py
+    """
+    if not osp.exists(path):
+        warnings.warn('The path of checkpoints does not exist.')
+        return None
+    if osp.exists(osp.join(path, f'latest.{suffix}')):
+        return osp.join(path, f'latest.{suffix}')
+
+    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+    if len(checkpoints) == 0:
+        warnings.warn('There are no checkpoints in the path.')
+        return None
+    latest = -1
+    latest_path = None
+    for checkpoint in checkpoints:
+        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+        if count > latest:
+            latest = count
+            latest_path = checkpoint
+    return latest_path

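The helper prefers a `latest.{suffix}` file when one exists and otherwise compares the integer between the last underscore and the extension, so it assumes checkpoints are named like `epoch_N.pth` or `iter_N.pth`. A quick illustration (file names invented for the example):

```python
import os
import tempfile

from mmrazor.utils import find_latest_checkpoint

with tempfile.TemporaryDirectory() as work_dir:
    for name in ('epoch_1.pth', 'epoch_2.pth'):
        open(os.path.join(work_dir, name), 'w').close()
    # epoch_2.pth carries the larger numeric suffix, so it is returned.
    assert find_latest_checkpoint(work_dir).endswith('epoch_2.pth')

    # A latest.pth file short-circuits the numeric comparison entirely.
    open(os.path.join(work_dir, 'latest.pth'), 'w').close()
    assert find_latest_checkpoint(work_dir).endswith('latest.pth')
```
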
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+
+from mmrazor.utils import find_latest_checkpoint
+
+
+def test_find_latest_checkpoint():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        # There are no checkpoints in the path.
+        assert latest is None
+
+        path = tmpdir + '/none'
+        latest = find_latest_checkpoint(path)
+        # The path does not exist.
+        assert latest is None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/latest.pth', 'w') as f:
+            f.write('latest')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'latest.pth')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/iter_4000.pth', 'w') as f:
+            f.write('iter_4000')
+        with open(tmpdir + '/iter_8000.pth', 'w') as f:
+            f.write('iter_8000')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'iter_8000.pth')
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/epoch_1.pth', 'w') as f:
+            f.write('epoch_1')
+        with open(tmpdir + '/epoch_2.pth', 'w') as f:
+            f.write('epoch_2')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == osp.join(tmpdir, 'epoch_2.pth')

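One behavior implied by the parser but not covered by these tests: epoch- and iteration-style names are compared on the raw number alone, and a checkpoint without a trailing `_N` (say, `best.pth`) would make the `int(...)` call raise `ValueError`. A sketch of the mixed-name case (invented file names):

```python
import os
import tempfile

from mmrazor.utils import find_latest_checkpoint

with tempfile.TemporaryDirectory() as work_dir:
    for name in ('epoch_2.pth', 'iter_8000.pth'):
        open(os.path.join(work_dir, name), 'w').close()
    # 8000 > 2, so the iteration checkpoint wins even though it may not
    # be the most recent by wall-clock time.
    assert find_latest_checkpoint(work_dir).endswith('iter_8000.pth')
```
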
@@ -26,6 +26,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',

@@ -101,6 +105,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '

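With the flag wired into the config, an interrupted run can be restarted with the same command plus `--auto-resume`, e.g. `python tools/mmcls/train_mmcls.py ${CONFIG} --work-dir ${WORK_DIR} --auto-resume` (script and config paths here are illustrative, not taken from the diff). The same flag is added to the mmdet and mmseg training scripts below.
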
@@ -34,6 +34,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',

@@ -112,6 +116,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '

@@ -37,6 +37,10 @@ def parse_args():
         '--load-from', help='the checkpoint file to load weights from')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',

@@ -114,6 +118,7 @@ def main():
         cfg.load_from = args.load_from
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpus is not None:
         cfg.gpu_ids = range(1)
         warnings.warn('`--gpus` is deprecated because we only support '