mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Use tmp dir as work_dir of runner (#236)
* use tmp dir as work_dir of runner * only run codecov for python 3.7 * remove useless comments
This commit is contained in:
parent
ad98674fa3
commit
af02ac9f01
@ -34,7 +34,7 @@ script: coverage run --branch --source=mmcv -m pytest tests/
|
|||||||
|
|
||||||
after_success:
|
after_success:
|
||||||
- coverage report -m
|
- coverage report -m
|
||||||
- codecov
|
- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then codecov; fi
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
provider: pypi
|
provider: pypi
|
||||||
|
@ -6,8 +6,11 @@ CommandLine:
|
|||||||
xdoctest tests/test_hooks.py zero
|
xdoctest tests/test_hooks.py zero
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
from unittest.mock import MagicMock, call
|
from unittest.mock import MagicMock, call
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -27,6 +30,7 @@ def test_pavi_hook():
|
|||||||
add_graph=False, add_last_ckpt=True)
|
add_graph=False, add_last_ckpt=True)
|
||||||
runner.register_hook(hook)
|
runner.register_hook(hook)
|
||||||
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
||||||
|
shutil.rmtree(runner.work_dir)
|
||||||
|
|
||||||
assert hasattr(hook, 'writer')
|
assert hasattr(hook, 'writer')
|
||||||
hook.writer.add_scalars.assert_called_with('val', {
|
hook.writer.add_scalars.assert_called_with('val', {
|
||||||
@ -34,7 +38,7 @@ def test_pavi_hook():
|
|||||||
'momentum': 0.95
|
'momentum': 0.95
|
||||||
}, 5)
|
}, 5)
|
||||||
hook.writer.add_snapshot_file.assert_called_with(
|
hook.writer.add_snapshot_file.assert_called_with(
|
||||||
tag='data',
|
tag=runner.work_dir.split('/')[-1],
|
||||||
snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
|
snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
|
||||||
iteration=5)
|
iteration=5)
|
||||||
|
|
||||||
@ -69,6 +73,7 @@ def test_momentum_runner_hook():
|
|||||||
interval=1, add_graph=False, add_last_ckpt=True)
|
interval=1, add_graph=False, add_last_ckpt=True)
|
||||||
runner.register_hook(hook)
|
runner.register_hook(hook)
|
||||||
runner.run([loader], [('train', 1)], 1)
|
runner.run([loader], [('train', 1)], 1)
|
||||||
|
shutil.rmtree(runner.work_dir)
|
||||||
|
|
||||||
# TODO: use a more elegant way to check values
|
# TODO: use a more elegant way to check values
|
||||||
assert hasattr(hook, 'writer')
|
assert hasattr(hook, 'writer')
|
||||||
@ -117,6 +122,7 @@ def test_cosine_runner_hook():
|
|||||||
interval=1, add_graph=False, add_last_ckpt=True)
|
interval=1, add_graph=False, add_last_ckpt=True)
|
||||||
runner.register_hook(hook)
|
runner.register_hook(hook)
|
||||||
runner.run([loader], [('train', 1)], 1)
|
runner.run([loader], [('train', 1)], 1)
|
||||||
|
shutil.rmtree(runner.work_dir)
|
||||||
|
|
||||||
# TODO: use a more elegant way to check values
|
# TODO: use a more elegant way to check values
|
||||||
assert hasattr(hook, 'writer')
|
assert hasattr(hook, 'writer')
|
||||||
@ -149,6 +155,7 @@ def test_mlflow_hook(log_model):
|
|||||||
exp_name='test', log_model=log_model)
|
exp_name='test', log_model=log_model)
|
||||||
runner.register_hook(hook)
|
runner.register_hook(hook)
|
||||||
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
||||||
|
shutil.rmtree(runner.work_dir)
|
||||||
|
|
||||||
hook.mlflow.set_experiment.assert_called_with('test')
|
hook.mlflow.set_experiment.assert_called_with('test')
|
||||||
hook.mlflow.log_metrics.assert_called_with(
|
hook.mlflow.log_metrics.assert_called_with(
|
||||||
@ -171,6 +178,8 @@ def test_wandb_hook():
|
|||||||
|
|
||||||
runner.register_hook(hook)
|
runner.register_hook(hook)
|
||||||
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
|
||||||
|
shutil.rmtree(runner.work_dir)
|
||||||
|
|
||||||
hook.wandb.init.assert_called_with()
|
hook.wandb.init.assert_called_with()
|
||||||
hook.wandb.log.assert_called_with({
|
hook.wandb.log.assert_called_with({
|
||||||
'learning_rate': 0.02,
|
'learning_rate': 0.02,
|
||||||
@ -182,7 +191,6 @@ def test_wandb_hook():
|
|||||||
|
|
||||||
def _build_demo_runner():
|
def _build_demo_runner():
|
||||||
model = nn.Linear(2, 1)
|
model = nn.Linear(2, 1)
|
||||||
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
|
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
|
||||||
|
|
||||||
log_config = dict(
|
log_config = dict(
|
||||||
@ -190,11 +198,13 @@ def _build_demo_runner():
|
|||||||
dict(type='TextLoggerHook'),
|
dict(type='TextLoggerHook'),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
tmp_dir = tempfile.mkdtemp()
|
||||||
runner = mmcv.runner.Runner(
|
runner = mmcv.runner.Runner(
|
||||||
model=model,
|
model=model,
|
||||||
work_dir=work_dir,
|
work_dir=tmp_dir,
|
||||||
batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
|
batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
|
||||||
optimizer=optimizer)
|
optimizer=optimizer,
|
||||||
|
logger=logging.getLogger())
|
||||||
|
|
||||||
runner.register_logger_hooks(log_config)
|
runner.register_logger_hooks(log_config)
|
||||||
return runner
|
return runner
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) Open-MMLab. All rights reserved.
|
# Copyright (c) Open-MMLab. All rights reserved.
|
||||||
|
import logging
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
@ -15,7 +16,8 @@ def test_save_checkpoint():
|
|||||||
import mmcv.runner
|
import mmcv.runner
|
||||||
|
|
||||||
model = nn.Linear(1, 1)
|
model = nn.Linear(1, 1)
|
||||||
runner = mmcv.runner.Runner(model=model, batch_processor=lambda x: x)
|
runner = mmcv.runner.Runner(
|
||||||
|
model=model, batch_processor=lambda x: x, logger=logging.getLogger())
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as root:
|
with tempfile.TemporaryDirectory() as root:
|
||||||
runner.save_checkpoint(root)
|
runner.save_checkpoint(root)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user