mim/tests/test_gridsearch.py
Zaida Zhou 695f2af0e4
[Fix] Fix and refactor unit tests (#128)
* fix and refactor unit tests

* Fix GPG key error in CI

* use container provided by pytorch

* install git in container

* install git in container

* fix ci

* update pip version

* install system dependencies

* add test data

* add circleci

* add test data

* refine ut
2022-06-14 00:02:58 +08:00

69 lines
2.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from click.testing import CliRunner
from mim.commands.gridsearch import cli as gridsearch
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
def setup_module():
    """Start the module from a clean state by removing mmcv-full and mmcls.

    Each package is uninstalled non-interactively (``--yes``) through the
    ``mim uninstall`` click command, in the order mmcv-full then mmcls, and
    every invocation is required to exit with status 0.
    """
    cli_runner = CliRunner()
    for package in ('mmcv-full', 'mmcls'):
        outcome = cli_runner.invoke(uninstall, [package, '--yes'])
        assert outcome.exit_code == 0
@pytest.mark.parametrize('gpus', [
    0,
    pytest.param(
        1,
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_gridsearch(gpus, tmp_path):
    """Exercise ``mim gridsearch`` for mmcls on CPU and, if present, one GPU.

    mmcls is installed first. Then four searches are run against the
    ``tests/data/lenet5_mnist.py`` config:

    * a search over ``optimizer.lr`` and a search over
      ``optimizer.weight_decay`` must both succeed;
    * a non-existent config (``tests/data/xxx.py``) must fail;
    * ``--search-args`` given with no search space must fail.

    Every assertion carries the captured CLI output as its message so that a
    failure in CI shows the command's traceback / usage error.
    """
    runner = CliRunner()
    result = runner.invoke(install, ['mmcls', '--yes'])
    assert result.exit_code == 0, result.output

    def gridsearch_args(config, *search_space):
        # Options common to every invocation; only the config path and the
        # (possibly empty) search-space specification vary between cases.
        return [
            'mmcls', config, f'--gpus={gpus}', f'--work-dir={tmp_path}',
            '--search-args', *search_space
        ]

    valid_config = 'tests/data/lenet5_mnist.py'

    # Well-formed search spaces: each grid search must complete successfully.
    for search_space in ('--optimizer.lr 1e-3 1e-4',
                         '--optimizer.weight_decay 1e-3 1e-4'):
        result = runner.invoke(gridsearch,
                               gridsearch_args(valid_config, search_space))
        assert result.exit_code == 0, result.output

    # Config file that does not exist: the command must fail.
    result = runner.invoke(
        gridsearch,
        gridsearch_args('tests/data/xxx.py', '--optimizer.lr 1e-3 1e-4'))
    assert result.exit_code != 0, result.output

    # '--search-args' passed without any search space: the command must fail.
    result = runner.invoke(gridsearch, gridsearch_args(valid_config))
    assert result.exit_code != 0, result.output
def teardown_module():
    """Leave the environment without mmcv-full or mmcls after the module runs.

    Runs ``mim uninstall <package> --yes`` for mmcv-full and then mmcls via
    click's test runner, requiring exit status 0 from each call.
    """
    runner = CliRunner()
    uninstall_cmds = [[pkg, '--yes'] for pkg in ('mmcv-full', 'mmcls')]
    for cmd_args in uninstall_cmds:
        assert runner.invoke(uninstall, cmd_args).exit_code == 0