[Fix]: Fix resume bug (#389)
* [Fix]: Fix resume bug * [Fix]: Change last_checkpoint check logic * [Fix]: Fix lint * [Fix]: Change warning to print_logpull/405/head
parent
84cd19aaa2
commit
81c3de54b9
|
@ -8,7 +8,7 @@ import warnings
|
|||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
@ -17,6 +17,7 @@ import mmengine
|
|||
from mmengine.dist import get_dist_info
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.fileio import load as load_file
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.utils import load_url, mkdir_or_exist
|
||||
|
||||
|
@ -697,7 +698,7 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
|
|||
file_client.put(f.getvalue(), filename)
|
||||
|
||||
|
||||
def find_latest_checkpoint(path: str):
|
||||
def find_latest_checkpoint(path: str) -> Optional[str]:
|
||||
"""Find the latest checkpoint from the given path.
|
||||
|
||||
Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501
|
||||
|
@ -709,12 +710,11 @@ def find_latest_checkpoint(path: str):
|
|||
str or None: File path of the latest checkpoint.
|
||||
"""
|
||||
save_file = osp.join(path, 'last_checkpoint')
|
||||
try:
|
||||
last_saved: Optional[str]
|
||||
if os.path.exists(save_file):
|
||||
with open(save_file) as f:
|
||||
last_saved = f.read().strip()
|
||||
except OSError:
|
||||
raise OSError(
|
||||
'last_checkpoint file does not exist, maybe because it has just'
|
||||
' been deleted by a separate process')
|
||||
|
||||
else:
|
||||
print_log('Did not find last_checkpoint to be resumed.')
|
||||
last_saved = None
|
||||
return last_saved
|
||||
|
|
Loading…
Reference in New Issue