[Fix]: Fix resume bug (#389)
* [Fix]: Fix resume bug
* [Fix]: Change last_checkpoint check logic
* [Fix]: Fix lint
* [Fix]: Change warning to print_log
pull/405/head
parent
84cd19aaa2
commit
81c3de54b9
|
@ -8,7 +8,7 @@ import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Callable, Dict
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
@ -17,6 +17,7 @@ import mmengine
|
||||||
from mmengine.dist import get_dist_info
|
from mmengine.dist import get_dist_info
|
||||||
from mmengine.fileio import FileClient
|
from mmengine.fileio import FileClient
|
||||||
from mmengine.fileio import load as load_file
|
from mmengine.fileio import load as load_file
|
||||||
|
from mmengine.logging import print_log
|
||||||
from mmengine.model import is_model_wrapper
|
from mmengine.model import is_model_wrapper
|
||||||
from mmengine.utils import load_url, mkdir_or_exist
|
from mmengine.utils import load_url, mkdir_or_exist
|
||||||
|
|
||||||
|
@ -697,7 +698,7 @@ def save_checkpoint(checkpoint, filename, file_client_args=None):
|
||||||
file_client.put(f.getvalue(), filename)
|
file_client.put(f.getvalue(), filename)
|
||||||
|
|
||||||
|
|
||||||
def find_latest_checkpoint(path: str):
|
def find_latest_checkpoint(path: str) -> Optional[str]:
|
||||||
"""Find the latest checkpoint from the given path.
|
"""Find the latest checkpoint from the given path.
|
||||||
|
|
||||||
Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501
|
Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501
|
||||||
|
@ -709,12 +710,11 @@ def find_latest_checkpoint(path: str):
|
||||||
str or None: File path of the latest checkpoint.
|
str or None: File path of the latest checkpoint.
|
||||||
"""
|
"""
|
||||||
save_file = osp.join(path, 'last_checkpoint')
|
save_file = osp.join(path, 'last_checkpoint')
|
||||||
try:
|
last_saved: Optional[str]
|
||||||
|
if os.path.exists(save_file):
|
||||||
with open(save_file) as f:
|
with open(save_file) as f:
|
||||||
last_saved = f.read().strip()
|
last_saved = f.read().strip()
|
||||||
except OSError:
|
else:
|
||||||
raise OSError(
|
print_log('Did not find last_checkpoint to be resumed.')
|
||||||
'last_checkpoint file does not exist, maybe because it has just'
|
last_saved = None
|
||||||
' been deleted by a separate process')
|
|
||||||
|
|
||||||
return last_saved
|
return last_saved
|
||||||
|
|
Loading…
Reference in New Issue