mim/tests/test_gridsearch.py
yancong 23e63429ae
use mmengine.Config instead of mmcv.Config (#155)
* use mmengine.Config instead of mmcv.Config

* install mmengine in CI

* fix mmengine install issue
2022-09-01 11:33:31 +08:00

79 lines
2.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from click.testing import CliRunner
from mim.commands.gridsearch import cli as gridsearch
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
def setup_module():
    """Module-level setup: uninstall mmcv-full, mmengine and mmcls.

    The packages are uninstalled one at a time, in that order, through the
    ``mim uninstall`` CLI. Each call must exit with status 0; the first
    failing assert aborts the remaining uninstalls.
    """
    cli_runner = CliRunner()
    for pkg in ('mmcv-full', 'mmengine', 'mmcls'):
        outcome = cli_runner.invoke(uninstall, [pkg, '--yes'])
        assert outcome.exit_code == 0
@pytest.mark.parametrize('gpus', [
    0,
    pytest.param(
        1,
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_gridsearch(gpus, tmp_path):
    """Exercise ``mim gridsearch`` on mmcls with valid and invalid args.

    mmcls, mmengine and mmcv-full are installed first. Two grid searches
    over optimizer fields (``lr`` and ``weight_decay``) on the LeNet-5 MNIST
    config must succeed. A bad config path and an empty ``--search-args``
    must both make the command exit non-zero.

    Args:
        gpus: Value passed as ``--gpus``; the ``1`` case is skipped when
            CUDA is unavailable.
        tmp_path: pytest temporary directory used as ``--work-dir``.
    """
    runner = CliRunner()

    # Since `mminstall.txt` is not included in the distribution of
    # mmcls<=0.23.1, we need to install mmcv-full manually.
    for package in ('mmcls', 'mmengine', 'mmcv-full'):
        result = runner.invoke(install, [package, '--yes'])
        # Attach the CLI output so a failure is diagnosable from the
        # pytest report instead of showing only a bare `1 != 0`.
        assert result.exit_code == 0, result.output

    def build_args(config, *search_args):
        """Return gridsearch argv for ``config``.

        The argv always ends with ``--search-args`` followed by
        ``search_args``, which may be empty.
        """
        return [
            'mmcls', config, f'--gpus={gpus}', f'--work-dir={tmp_path}',
            '--search-args', *search_args
        ]

    lenet_config = 'tests/data/lenet5_mnist.py'

    # Searching over a valid optimizer field must succeed.
    for search_spec in ('--optimizer.lr 1e-3 1e-4',
                        '--optimizer.weight_decay 1e-3 1e-4'):
        result = runner.invoke(gridsearch,
                               build_args(lenet_config, search_spec))
        assert result.exit_code == 0, result.output

    # An invalid config path (`xxx.py`, presumably absent from tests/data)
    # must be rejected.
    result = runner.invoke(
        gridsearch,
        build_args('tests/data/xxx.py', '--optimizer.lr 1e-3 1e-4'))
    assert result.exit_code != 0, result.output

    # `--search-args` with no search specification must be rejected.
    result = runner.invoke(gridsearch, build_args(lenet_config))
    assert result.exit_code != 0, result.output
def teardown_module():
    """Module-level teardown: uninstall mmcv-full, mmengine and mmcls.

    Each ``mim uninstall`` invocation must exit with status 0. The order is
    mmcv-full, then mmengine, then mmcls, and the first failure stops the
    rest.
    """
    runner = CliRunner()
    uninstall_argvs = [[name, '--yes']
                       for name in ['mmcv-full', 'mmengine', 'mmcls']]
    for argv in uninstall_argvs:
        # Invoke outside the assert so the uninstall still runs under -O.
        res = runner.invoke(uninstall, argv)
        assert res.exit_code == 0