Use tmp dir as work_dir of runner (#236)

* use tmp dir as work_dir of runner

* only run codecov for python 3.7

* remove useless comments
This commit is contained in:
Kai Chen 2020-04-22 23:33:54 +08:00 committed by GitHub
parent ad98674fa3
commit af02ac9f01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 6 deletions

View File

@ -34,7 +34,7 @@ script: coverage run --branch --source=mmcv -m pytest tests/
after_success: after_success:
- coverage report -m - coverage report -m
- codecov - if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then codecov; fi
deploy: deploy:
provider: pypi provider: pypi

View File

@ -6,8 +6,11 @@ CommandLine:
xdoctest tests/test_hooks.py zero xdoctest tests/test_hooks.py zero
""" """
import logging
import os.path as osp import os.path as osp
import shutil
import sys import sys
import tempfile
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call
import pytest import pytest
@ -27,6 +30,7 @@ def test_pavi_hook():
add_graph=False, add_last_ckpt=True) add_graph=False, add_last_ckpt=True)
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1) runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
shutil.rmtree(runner.work_dir)
assert hasattr(hook, 'writer') assert hasattr(hook, 'writer')
hook.writer.add_scalars.assert_called_with('val', { hook.writer.add_scalars.assert_called_with('val', {
@ -34,7 +38,7 @@ def test_pavi_hook():
'momentum': 0.95 'momentum': 0.95
}, 5) }, 5)
hook.writer.add_snapshot_file.assert_called_with( hook.writer.add_snapshot_file.assert_called_with(
tag='data', tag=runner.work_dir.split('/')[-1],
snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'), snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
iteration=5) iteration=5)
@ -69,6 +73,7 @@ def test_momentum_runner_hook():
interval=1, add_graph=False, add_last_ckpt=True) interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1) runner.run([loader], [('train', 1)], 1)
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values # TODO: use a more elegant way to check values
assert hasattr(hook, 'writer') assert hasattr(hook, 'writer')
@ -117,6 +122,7 @@ def test_cosine_runner_hook():
interval=1, add_graph=False, add_last_ckpt=True) interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader], [('train', 1)], 1) runner.run([loader], [('train', 1)], 1)
shutil.rmtree(runner.work_dir)
# TODO: use a more elegant way to check values # TODO: use a more elegant way to check values
assert hasattr(hook, 'writer') assert hasattr(hook, 'writer')
@ -149,6 +155,7 @@ def test_mlflow_hook(log_model):
exp_name='test', log_model=log_model) exp_name='test', log_model=log_model)
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1) runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
shutil.rmtree(runner.work_dir)
hook.mlflow.set_experiment.assert_called_with('test') hook.mlflow.set_experiment.assert_called_with('test')
hook.mlflow.log_metrics.assert_called_with( hook.mlflow.log_metrics.assert_called_with(
@ -171,6 +178,8 @@ def test_wandb_hook():
runner.register_hook(hook) runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1) runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
shutil.rmtree(runner.work_dir)
hook.wandb.init.assert_called_with() hook.wandb.init.assert_called_with()
hook.wandb.log.assert_called_with({ hook.wandb.log.assert_called_with({
'learning_rate': 0.02, 'learning_rate': 0.02,
@ -182,7 +191,6 @@ def test_wandb_hook():
def _build_demo_runner(): def _build_demo_runner():
model = nn.Linear(2, 1) model = nn.Linear(2, 1)
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95) optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
log_config = dict( log_config = dict(
@ -190,11 +198,13 @@ def _build_demo_runner():
dict(type='TextLoggerHook'), dict(type='TextLoggerHook'),
]) ])
tmp_dir = tempfile.mkdtemp()
runner = mmcv.runner.Runner( runner = mmcv.runner.Runner(
model=model, model=model,
work_dir=work_dir, work_dir=tmp_dir,
batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0}, batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
optimizer=optimizer) optimizer=optimizer,
logger=logging.getLogger())
runner.register_logger_hooks(log_config) runner.register_logger_hooks(log_config)
return runner return runner

View File

@ -1,4 +1,5 @@
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp import os.path as osp
import tempfile import tempfile
import warnings import warnings
@ -15,7 +16,8 @@ def test_save_checkpoint():
import mmcv.runner import mmcv.runner
model = nn.Linear(1, 1) model = nn.Linear(1, 1)
runner = mmcv.runner.Runner(model=model, batch_processor=lambda x: x) runner = mmcv.runner.Runner(
model=model, batch_processor=lambda x: x, logger=logging.getLogger())
with tempfile.TemporaryDirectory() as root: with tempfile.TemporaryDirectory() as root:
runner.save_checkpoint(root) runner.save_checkpoint(root)