mmcv/tests/test_runner.py
David de la Iglesia Castro d5f190d12d
Add MlflowLoggerHook (#221)
* Add MLflowLoggerHook

* Add MLflowLoggerHook to __all__

* Update name

* Fix tracking.MlflowClient setup

* Fix log_metric

* Fix mlflow_pytorch import

* Handle active_run

* Fix self.mlflow reference

* Simplify using high level API

* Fix set_experiment

* Add only_if_torch_available decorator and test_mlflow_hook

* Add missing import in hooks

* Fix torch available check

* Patch mlflow.pytorch in test

* Parametrize log_model

* Fix log_model parametrize

* Add docstring

* Move wandb patch

* Fix flake8

* Add regression test for non numeric metric

* Only log numbers

* Rename experiment_name -> exp_name

* Remove pytest skip
2020-04-14 23:54:55 +08:00

31 lines
809 B
Python

# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile
import warnings
def test_save_checkpoint():
    """Check that ``Runner.save_checkpoint`` writes loadable checkpoint files.

    Saves a checkpoint from a freshly built runner into a temporary
    directory, then asserts that:

    * both ``epoch_1.pth`` and ``latest.pth`` exist,
    * ``latest.pth`` resolves to the same file as ``epoch_1.pth``, and
    * the checkpoint can be read back with ``torch.load``.

    When torch is not installed, the test emits a warning and returns early
    instead of failing.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        warnings.warn('Skipping test_save_checkpoint in the absence of torch')
        return

    # Imported only after the torch guard; presumably mmcv.runner itself
    # requires torch at import time (TODO confirm).
    import mmcv.runner

    model = nn.Linear(1, 1)
    # Only save_checkpoint is exercised, so a trivial callable stands in for
    # the batch processor.
    runner = mmcv.runner.Runner(model=model, batch_processor=lambda x: x)

    with tempfile.TemporaryDirectory() as root:
        runner.save_checkpoint(root)

        # Everything below must stay inside the `with`: the directory (and
        # the checkpoints in it) is removed as soon as the block exits.
        latest_path = osp.join(root, 'latest.pth')
        epoch1_path = osp.join(root, 'epoch_1.pth')

        assert osp.exists(latest_path)
        assert osp.exists(epoch1_path)
        # latest.pth must refer to the epoch checkpoint just written. This
        # holds whether it is a symlink or the same path.
        assert osp.realpath(latest_path) == osp.realpath(epoch1_path)

        # The saved file must be deserializable by torch.
        torch.load(latest_path)