42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import glob
|
|
import os.path as osp
|
|
import warnings
|
|
|
|
|
|
def find_latest_checkpoint(path, suffix='pth'):
|
|
"""This function is for finding the latest checkpoint.
|
|
|
|
It will be used when automatically resume, modified from
|
|
https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
|
|
|
|
Args:
|
|
path (str): The path to find checkpoints.
|
|
suffix (str): File extension for the checkpoint. Defaults to pth.
|
|
|
|
Returns:
|
|
latest_path(str | None): File path of the latest checkpoint.
|
|
"""
|
|
if not osp.exists(path):
|
|
warnings.warn("The path of the checkpoints doesn't exist.")
|
|
return None
|
|
if osp.exists(osp.join(path, f'latest.{suffix}')):
|
|
return osp.join(path, f'latest.{suffix}')
|
|
|
|
checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
|
|
if len(checkpoints) == 0:
|
|
warnings.warn('The are no checkpoints in the path')
|
|
return None
|
|
latest = -1
|
|
latest_path = ''
|
|
for checkpoint in checkpoints:
|
|
if len(checkpoint) < len(latest_path):
|
|
continue
|
|
# `count` is iteration number, as checkpoints are saved as
|
|
# 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
|
|
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
|
|
if count > latest:
|
|
latest = count
|
|
latest_path = checkpoint
|
|
return latest_path
|