2020-01-10 13:34:42 +08:00
|
|
|
# Copyright (c) Open-MMLab. All rights reserved.
|
2020-04-22 23:33:54 +08:00
|
|
|
import logging
|
2019-06-15 14:56:27 +03:00
|
|
|
import os.path as osp
|
|
|
|
import tempfile
|
|
|
|
import warnings
|
2020-02-24 22:31:36 +08:00
|
|
|
|
2019-06-15 14:56:27 +03:00
|
|
|
|
|
|
|
def test_save_checkpoint():
    """Check that ``Runner.save_checkpoint`` writes the expected checkpoint files.

    A runner wrapping a tiny ``nn.Linear(1, 1)`` model saves into an empty
    temporary directory. The test then asserts that:

    * both ``latest.pth`` and ``epoch_1.pth`` exist,
    * ``latest.pth`` resolves (via ``os.path.realpath``) to the same file as
      ``epoch_1.pth``, and
    * ``latest.pth`` can be read back with ``torch.load``.

    If torch is not installed, a warning is emitted and the test is skipped
    by raising ``unittest.SkipTest``. pytest reports that as "skipped"
    instead of silently passing.
    """
    import unittest

    try:
        import torch
        from torch import nn
    except ImportError as err:
        warnings.warn('Skipping test_save_checkpoint in the absence of torch')
        # Previously this was a bare ``return``, which made the test report
        # as passed without checking anything.
        raise unittest.SkipTest('torch is not installed') from err

    # Imported only after the torch check -- presumably because mmcv.runner
    # itself requires torch (TODO confirm).
    import mmcv.runner

    model = nn.Linear(1, 1)
    # Identity batch processor: this test never runs a training step, it
    # only needs a constructed runner to call save_checkpoint on.
    runner = mmcv.runner.Runner(
        model=model, batch_processor=lambda x: x, logger=logging.getLogger())

    with tempfile.TemporaryDirectory() as root:
        runner.save_checkpoint(root)

        latest_path = osp.join(root, 'latest.pth')
        # The test expects a freshly constructed runner to name its first
        # checkpoint 'epoch_1.pth'.
        epoch1_path = osp.join(root, 'epoch_1.pth')

        assert osp.exists(latest_path)
        assert osp.exists(epoch1_path)
        # 'latest.pth' must be an alias (e.g. a symlink) of the epoch file,
        # not an unrelated file.
        assert osp.realpath(latest_path) == osp.realpath(epoch1_path)

        # Must run inside the ``with`` block: the directory is deleted on exit.
        torch.load(latest_path)
|