mirror of https://github.com/open-mmlab/mim.git
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
import tempfile
|
|
|
|
import pytest
|
|
from click.testing import CliRunner
|
|
|
|
from mim.commands.download import download
|
|
from mim.commands.install import cli as install
|
|
from mim.commands.uninstall import cli as uninstall
|
|
|
|
|
|
def setup_module():
    """Uninstall ``mmcls`` through the CLI and require a clean exit.

    pytest treats a module-level function with this name as setup that runs
    once before the tests of the module. The ``uninstall`` command is invoked
    with ``mmcls --yes`` and its exit code is asserted to be 0.
    """
    # CLI form: mim uninstall mmcls --yes
    outcome = CliRunner().invoke(uninstall, ['mmcls', '--yes'])
    assert outcome.exit_code == 0
|
|
|
|
|
|
def test_download():
    """Exercise ``download`` for ``mmcls`` before and after installing it.

    Steps, in order:

    1. Install ``mmcv-full`` through the CLI and require exit code 0.
    2. Check the three calls to ``download`` that are expected to raise.
    3. Install ``mmcls`` through the CLI and require exit code 0.
    4. Check the checkpoint file names returned by ``download``, both with
       the default destination and with a temporary directory.
    """
    package = 'mmcls'
    versioned_package = 'mmcls==0.11.0'
    config = 'resnet18_b16x8_cifar10'
    expected_checkpoints = ['resnet18_b16x8_cifar10_20200823-f906fa4e.pth']

    mmcv_install = CliRunner().invoke(install, ['mmcv-full', '--yes'])
    assert mmcv_install.exit_code == 0

    # A version specifier is not allowed.
    with pytest.raises(ValueError):
        download(versioned_package, [config])

    # mmcls is not installed at this point.
    with pytest.raises(RuntimeError):
        download(package, [config])

    # Invalid config name (trailing "0" dropped).
    # NOTE(review): the package argument here also carries a version
    # specifier, so the ValueError may come from the version check rather
    # than from the config name -- confirm against ``download``.
    with pytest.raises(ValueError):
        download(versioned_package, ['resnet18_b16x8_cifar1'])

    # mim install mmcls --yes
    mmcls_install_args = [
        'mmcls',
        '--yes',
        '-f',
        'https://github.com/open-mmlab/mmclassification.git',
    ]
    mmcls_install = CliRunner().invoke(install, mmcls_install_args)
    assert mmcls_install.exit_code == 0

    # mim download mmcls --config resnet18_b16x8_cifar10
    assert download(package, [config]) == expected_checkpoints
    # Same call repeated; its return value is not checked.
    # NOTE(review): presumably covers the case where the checkpoint has
    # already been downloaded -- confirm.
    download(package, [config])

    # mim download mmcls --config resnet18_b16x8_cifar10 --dest temp_root
    with tempfile.TemporaryDirectory() as temp_root:
        downloaded = download(package, [config], temp_root)
        assert downloaded == expected_checkpoints
|