2021-05-20 00:03:43 +08:00
|
|
|
import os.path as osp
|
2021-05-24 22:23:45 +08:00
|
|
|
import shutil
|
2021-05-20 00:03:43 +08:00
|
|
|
import time
|
|
|
|
|
|
|
|
from click.testing import CliRunner
|
|
|
|
|
|
|
|
from mim.commands.gridsearch import cli as gridsearch
|
|
|
|
from mim.commands.install import cli as install
|
|
|
|
from mim.utils import download_from_file, extract_tar, is_installed
|
|
|
|
|
2021-05-20 17:18:31 +08:00
|
|
|
# Remote assets used by test_gridsearch: a small dataset archive (extracted
# into /tmp by the test) and the config file handed to gridsearch.
dataset_url, cfg_url = (
    'https://download.openmmlab.com/mim/dataset.tar',
    'https://download.openmmlab.com/mim/resnet18_b16x8_custom.py',
)
|
2021-05-20 00:03:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
def setup_module():
    """Make sure mmcls is installed before the tests in this module run.

    When ``is_installed('mmcls')`` reports it missing, mmcls is installed
    through the ``mim install`` command (non-interactively via ``--yes``),
    and module setup fails if that command exits with a non-zero code.
    """
    cli_runner = CliRunner()
    if is_installed('mmcls'):
        # Nothing to do: the package is already available.
        return
    outcome = cli_runner.invoke(install, ['mmcls', '--yes'])
    assert outcome.exit_code == 0
|
|
|
|
|
|
|
|
|
|
|
|
def _download_test_assets():
    """Fetch the toy dataset and config into ``/tmp`` when they are missing.

    The dataset archive is downloaded to ``/tmp/dataset.tar`` and extracted
    into ``/tmp/``; the config is downloaded to ``/tmp/config.py``. Assets
    already present on disk are reused.

    Returns:
        bool: True if at least one asset was downloaded during this call.
    """
    downloaded = False
    if not osp.exists('/tmp/dataset'):
        download_from_file(dataset_url, '/tmp/dataset.tar')
        extract_tar('/tmp/dataset.tar', '/tmp/')
        downloaded = True
    if not osp.exists('/tmp/config.py'):
        download_from_file(cfg_url, '/tmp/config.py')
        downloaded = True
    return downloaded


def test_gridsearch():
    """Exercise ``mim gridsearch`` for mmcls with valid and invalid arguments.

    Two valid searches (over ``optimizer.lr`` and ``optimizer.weight_decay``)
    must exit with code 0 and create one work directory per searched value.
    Two invalid invocations (an unknown config path and ``--search-args``
    without a value) must exit with a non-zero code. The work directories
    are removed afterwards even when an assertion fails.
    """
    runner = CliRunner()

    if _download_test_assets():
        # Wait for the download task to complete. Skipped when both assets
        # were already cached, since nothing was fetched in that case.
        time.sleep(5)

    # Work directories the two valid searches are expected to create in the
    # current directory (``--work-dir=tmp`` plus a per-value suffix).
    expected_work_dirs = [
        'tmp_search_optimizer.lr_0.001',
        'tmp_search_optimizer.lr_0.0001',
        'tmp_search_optimizer.weight_decay_0.001',
        'tmp_search_optimizer.weight_decay_0.0001',
    ]
    # Drop leftovers from an earlier aborted run so the existence checks
    # below only reflect what this run produced.
    for work_dir in expected_work_dirs:
        shutil.rmtree(work_dir, ignore_errors=True)

    # Valid: search over the learning rate.
    args1 = [
        'mmcls', '/tmp/config.py', '--gpus=0', '--work-dir=tmp',
        '--search-args', '--optimizer.lr 1e-3 1e-4'
    ]
    # Valid: search over the weight decay.
    args2 = [
        'mmcls', '/tmp/config.py', '--gpus=0', '--work-dir=tmp',
        '--search-args', '--optimizer.weight_decay 1e-3 1e-4'
    ]
    # Invalid: config path this test never creates (presumably absent).
    args3 = [
        'mmcls', '/tmp/xxx.py', '--gpus=0', '--work-dir=tmp', '--search-args',
        '--optimizer.lr 1e-3 1e-4'
    ]
    # Invalid: ``--search-args`` given without a value.
    args4 = [
        'mmcls', '/tmp/config.py', '--gpus=0', '--work-dir=tmp',
        '--search-args'
    ]

    try:
        result = runner.invoke(gridsearch, args1)
        assert result.exit_code == 0

        result = runner.invoke(gridsearch, args2)
        assert result.exit_code == 0

        result = runner.invoke(gridsearch, args3)
        assert result.exit_code != 0

        result = runner.invoke(gridsearch, args4)
        assert result.exit_code != 0

        # The original cleanup relied on rmtree raising to detect a missing
        # directory; check it explicitly instead.
        for work_dir in expected_work_dirs:
            assert osp.isdir(work_dir), work_dir
    finally:
        # Best-effort cleanup so a failed assertion does not leak work
        # directories into the current directory.
        for work_dir in expected_work_dirs:
            shutil.rmtree(work_dir, ignore_errors=True)
|