mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Use tmp dir as work_dir of runner (#236)
* use tmp dir as work_dir of runner * only run codecov for python 3.7 * remove useless comments
This commit is contained in:
parent
ad98674fa3
commit
af02ac9f01
@ -34,7 +34,7 @@ script: coverage run --branch --source=mmcv -m pytest tests/
|
||||
|
||||
after_success:
|
||||
- coverage report -m
|
||||
- codecov
|
||||
- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then codecov; fi
|
||||
|
||||
deploy:
|
||||
provider: pypi
|
||||
|
@ -6,8 +6,11 @@ CommandLine:
|
||||
xdoctest tests/test_hooks.py zero
|
||||
|
||||
"""
|
||||
import logging
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
import pytest
|
||||
@ -27,6 +30,7 @@ def test_pavi_hook():
|
||||
add_graph=False, add_last_ckpt=True)
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
assert hasattr(hook, 'writer')
|
||||
hook.writer.add_scalars.assert_called_with('val', {
|
||||
@ -34,7 +38,7 @@ def test_pavi_hook():
|
||||
'momentum': 0.95
|
||||
}, 5)
|
||||
hook.writer.add_snapshot_file.assert_called_with(
|
||||
tag='data',
|
||||
tag=runner.work_dir.split('/')[-1],
|
||||
snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
|
||||
iteration=5)
|
||||
|
||||
@ -69,6 +73,7 @@ def test_momentum_runner_hook():
|
||||
interval=1, add_graph=False, add_last_ckpt=True)
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
# TODO: use a more elegant way to check values
|
||||
assert hasattr(hook, 'writer')
|
||||
@ -117,6 +122,7 @@ def test_cosine_runner_hook():
|
||||
interval=1, add_graph=False, add_last_ckpt=True)
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
# TODO: use a more elegant way to check values
|
||||
assert hasattr(hook, 'writer')
|
||||
@ -149,6 +155,7 @@ def test_mlflow_hook(log_model):
|
||||
exp_name='test', log_model=log_model)
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
hook.mlflow.set_experiment.assert_called_with('test')
|
||||
hook.mlflow.log_metrics.assert_called_with(
|
||||
@ -171,6 +178,8 @@ def test_wandb_hook():
|
||||
|
||||
runner.register_hook(hook)
|
||||
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
||||
shutil.rmtree(runner.work_dir)
|
||||
|
||||
hook.wandb.init.assert_called_with()
|
||||
hook.wandb.log.assert_called_with({
|
||||
'learning_rate': 0.02,
|
||||
@ -182,7 +191,6 @@ def test_wandb_hook():
|
||||
|
||||
def _build_demo_runner():
|
||||
model = nn.Linear(2, 1)
|
||||
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
|
||||
|
||||
log_config = dict(
|
||||
@ -190,11 +198,13 @@ def _build_demo_runner():
|
||||
dict(type='TextLoggerHook'),
|
||||
])
|
||||
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
runner = mmcv.runner.Runner(
|
||||
model=model,
|
||||
work_dir=work_dir,
|
||||
work_dir=tmp_dir,
|
||||
batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
|
||||
optimizer=optimizer)
|
||||
optimizer=optimizer,
|
||||
logger=logging.getLogger())
|
||||
|
||||
runner.register_logger_hooks(log_config)
|
||||
return runner
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
import warnings
|
||||
@ -15,7 +16,8 @@ def test_save_checkpoint():
|
||||
import mmcv.runner
|
||||
|
||||
model = nn.Linear(1, 1)
|
||||
runner = mmcv.runner.Runner(model=model, batch_processor=lambda x: x)
|
||||
runner = mmcv.runner.Runner(
|
||||
model=model, batch_processor=lambda x: x, logger=logging.getLogger())
|
||||
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
runner.save_checkpoint(root)
|
||||
|
Loading…
x
Reference in New Issue
Block a user