mim/tests/test_gridsearch.py

69 lines
2.0 KiB
Python
Raw Normal View History

2022-02-23 22:20:01 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
2021-05-20 00:03:43 +08:00
import pytest
import torch
2021-05-20 00:03:43 +08:00
from click.testing import CliRunner
from mim.commands.gridsearch import cli as gridsearch
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
2021-05-20 00:03:43 +08:00
def setup_module():
    """Remove mmcv-full and mmcls before any test in this module runs."""
    cli = CliRunner()
    for package in ('mmcv-full', 'mmcls'):
        outcome = cli.invoke(uninstall, [package, '--yes'])
        assert outcome.exit_code == 0
2021-05-20 00:03:43 +08:00
@pytest.mark.parametrize('gpus', [
    0,
    pytest.param(
        1,
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_gridsearch(gpus, tmp_path):
    """Exercise ``mim gridsearch`` for mmcls with ``--gpus=0`` and ``--gpus=1``.

    The ``--gpus=1`` case is skipped when CUDA is unavailable. Two search
    specs over the lenet5 MNIST config must succeed. The ``xxx.py`` config
    and a bare ``--search-args`` with no value must both exit non-zero.
    """
    runner = CliRunner()
    result = runner.invoke(install, ['mmcls', '--yes'])
    assert result.exit_code == 0, result.output

    def build_args(config, *search_args):
        # Every run shares the package, GPU count and work dir; only the
        # config path and the search-space string vary between cases.
        return [
            'mmcls', config, f'--gpus={gpus}', f'--work-dir={tmp_path}',
            '--search-args', *search_args
        ]

    valid_cases = [
        build_args('tests/data/lenet5_mnist.py', '--optimizer.lr 1e-3 1e-4'),
        build_args('tests/data/lenet5_mnist.py',
                   '--optimizer.weight_decay 1e-3 1e-4'),
    ]
    invalid_cases = [
        # Presumably a config path that does not exist in tests/data.
        build_args('tests/data/xxx.py', '--optimizer.lr 1e-3 1e-4'),
        # ``--search-args`` is passed as the last token with no value.
        build_args('tests/data/lenet5_mnist.py'),
    ]

    # Cases run in the same order as before: both valid ones, then invalid.
    for args in valid_cases:
        result = runner.invoke(gridsearch, args)
        # Surface the CLI output so a failing run is diagnosable in CI logs.
        assert result.exit_code == 0, result.output
    for args in invalid_cases:
        result = runner.invoke(gridsearch, args)
        assert result.exit_code != 0, result.output
2021-05-20 00:03:43 +08:00
def teardown_module():
    """Uninstall mmcv-full and mmcls once every test in this module is done."""
    cli_runner = CliRunner()
    for pkg_name in ['mmcv-full', 'mmcls']:
        res = cli_runner.invoke(uninstall, [pkg_name, '--yes'])
        assert res.exit_code == 0