mmsegmentation/tests/test_utils/test_misc.py

41 lines
1.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from mmseg.utils import find_latest_checkpoint
def _write_file(path, content):
    """Create (or overwrite) the text file at ``path`` with ``content``."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)


def test_find_latest_checkpoint():
    """Check that ``find_latest_checkpoint`` returns the newest checkpoint.

    Scenarios covered:

    * an empty directory and a non-existent directory both yield ``None``;
    * a ``latest.pth`` file is returned when it is present;
    * among ``iter_*.pth`` / ``epoch_*.pth`` files, the one with the
      largest number is returned.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # No checkpoints in the directory.
        assert find_latest_checkpoint(tempdir) is None

        # The directory does not exist at all.
        missing_dir = osp.join(tempdir, 'none')
        assert find_latest_checkpoint(missing_dir) is None

    with tempfile.TemporaryDirectory() as tempdir:
        # A ``latest.pth`` file is present.
        _write_file(osp.join(tempdir, 'latest.pth'), 'latest')
        assert find_latest_checkpoint(tempdir) == osp.join(
            tempdir, 'latest.pth')

    with tempfile.TemporaryDirectory() as tempdir:
        # Many iteration checkpoints. Compared as plain strings,
        # 'iter_9600.pth' would rank above 'iter_160000.pth', so this
        # expects the largest iteration *number* to be picked.
        for iteration in range(1600, 160001, 1600):
            name = f'iter_{iteration}.pth'
            _write_file(osp.join(tempdir, name), name)
        assert find_latest_checkpoint(tempdir) == osp.join(
            tempdir, 'iter_160000.pth')

    with tempfile.TemporaryDirectory() as tempdir:
        # Epoch checkpoints. 'epoch_9.pth' sorts after 'epoch_20.pth' as a
        # string, so again the numeric maximum is expected.
        for epoch in range(1, 21):
            name = f'epoch_{epoch}.pth'
            _write_file(osp.join(tempdir, name), name)
        assert find_latest_checkpoint(tempdir) == osp.join(
            tempdir, 'epoch_20.pth')