2022-02-23 22:20:01 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2021-05-20 00:03:43 +08:00
|
|
|
|
2022-06-14 00:02:58 +08:00
|
|
|
import pytest
|
|
|
|
import torch
|
2021-05-20 00:03:43 +08:00
|
|
|
from click.testing import CliRunner
|
|
|
|
|
|
|
|
from mim.commands.gridsearch import cli as gridsearch
|
|
|
|
from mim.commands.install import cli as install
|
2022-06-14 00:02:58 +08:00
|
|
|
from mim.commands.uninstall import cli as uninstall
|
2021-05-20 00:03:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
def setup_module():
    """Uninstall the OpenMMLab packages before this module's tests run.

    Every package is removed through the ``uninstall`` CLI command with
    ``--yes``, and each invocation must exit with status 0.
    """
    cli_runner = CliRunner()
    for package in ('mmcv-full', 'mmcv', 'mmengine', 'mmcls'):
        outcome = cli_runner.invoke(uninstall, [package, '--yes'])
        assert outcome.exit_code == 0
|
2021-05-20 00:03:43 +08:00
|
|
|
|
|
|
|
|
2022-06-14 00:02:58 +08:00
|
|
|
@pytest.mark.parametrize('gpus', [
    0,
    pytest.param(
        1,
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_gridsearch(gpus, tmp_path):
    """Exercise the ``gridsearch`` CLI with valid and invalid arguments.

    Args:
        gpus: Value forwarded as ``--gpus``; the ``1`` case is skipped when
            CUDA is unavailable.
        tmp_path: pytest temporary directory, forwarded as ``--work-dir``.
    """
    cli_runner = CliRunner()

    # The packages the search depends on must install cleanly first.
    for requirement in ('mmcls>=1.0.0rc0', 'mmengine', 'mmcv>=2.0.0rc0'):
        installed = cli_runner.invoke(install, [requirement, '--yes'])
        assert installed.exit_code == 0

    def build_args(config, *search_values):
        # Common command line; only the config and search values vary.
        return [
            'mmcls', config, f'--gpus={gpus}', f'--work-dir={tmp_path}',
            '--search-args', *search_values
        ]

    lenet_config = 'tests/data/lenet5_mnist_2.0.py'
    lr_search = '--optim_wrapper.optimizer.lr 1e-3 1e-4'

    # (arguments, whether a zero exit code is expected), in run order.
    scenarios = [
        (build_args(lenet_config, lr_search), True),
        (build_args(lenet_config,
                    '--optim_wrapper.optimizer.weight_decay 1e-3 1e-4'),
         True),
        # Presumably a nonexistent config file; the command must fail.
        (build_args('tests/data/xxx.py', lr_search), False),
        # `--search-args` given without any values; the command must fail.
        (build_args(lenet_config), False),
        (build_args(lenet_config,
                    '--train_dataloader.dataset.pipeline.0.scale 16 32'),
         True),
    ]

    for cli_args, expect_success in scenarios:
        outcome = cli_runner.invoke(gridsearch, cli_args)
        if expect_success:
            assert outcome.exit_code == 0
        else:
            assert outcome.exit_code != 0
|
|
|
|
|
2022-06-14 00:02:58 +08:00
|
|
|
|
|
|
|
def teardown_module():
    """Uninstall the OpenMMLab packages after this module's tests finish.

    Removes the same packages as ``setup_module``. ``mmcv`` is included
    because ``test_gridsearch`` installs ``mmcv>=2.0.0rc0``; uninstalling
    ``mmcv-full`` twice would leave that install behind.

    Every invocation must exit with status 0.
    """
    runner = CliRunner()
    for package in ('mmcv-full', 'mmcv', 'mmengine', 'mmcls'):
        result = runner.invoke(uninstall, [package, '--yes'])
        assert result.exit_code == 0
|