[Fix] Fix the unit test of the download command (#95)

* [Fix] Fix the unit test of download command

* rename the name of config
pull/99/head^2
Zaida Zhou 2022-01-13 16:54:58 +08:00 committed by GitHub
parent d85de669b0
commit be720eeebb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 19 deletions

View File

@ -150,7 +150,7 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
```bash
> mim search mmcls
> mim search mmcls==0.11.0 --remote
> mim search mmcls --config resnet18_b16x8_cifar10
> mim search mmcls --config resnet18_8xb16_cifar10
> mim search mmcls --model resnet
> mim search mmcls --dataset cifar-10
> mim search mmcls --valid-field
@ -188,8 +188,8 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
- command
```bash
> mim download mmcls --config resnet18_b16x8_cifar10
> mim download mmcls --config resnet18_b16x8_cifar10 --dest .
> mim download mmcls --config resnet18_8xb16_cifar10
> mim download mmcls --config resnet18_8xb16_cifar10 --dest .
```
- api
@ -197,8 +197,8 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
```python
from mim import download
download('mmcls', ['resnet18_b16x8_cifar10'])
download('mmcls', ['resnet18_b16x8_cifar10'], dest_dir='.')
download('mmcls', ['resnet18_8xb16_cifar10'])
download('mmcls', ['resnet18_8xb16_cifar10'], dest_dir='.')
```
</details>
@ -234,13 +234,13 @@ Please refer to [installation.md](docs/en/installation.md) for installation.
```python
from mim import train
train(repo='mmcls', config='resnet18_b16x8_cifar10.py', gpus=0,
train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=0,
other_args='--work-dir tmp')
train(repo='mmcls', config='resnet18_b16x8_cifar10.py', gpus=1,
train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=1,
other_args='--work-dir tmp')
train(repo='mmcls', config='resnet18_b16x8_cifar10.py', gpus=4,
train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=4,
launcher='pytorch', other_args='--work-dir tmp')
train(repo='mmcls', config='resnet18_b16x8_cifar10.py', gpus=8,
train(repo='mmcls', config='resnet18_8xb16_cifar10.py', gpus=8,
launcher='slurm', gpus_per_node=8, partition='partition_name',
other_args='--work-dir tmp')
```

View File

@ -27,7 +27,7 @@ from mim.utils import (
'configs',
cls=OptionEatAll,
required=True,
help='Config ids to download, such as resnet18_b16x8_cifar10')
help='Config ids to download, such as resnet18_8xb16_cifar10')
@click.option(
'--dest', 'dest_root', type=str, help='Destination of saving checkpoints.')
def cli(package: str,
@ -37,8 +37,8 @@ def cli(package: str,
\b
Example:
> mim download mmcls --config resnet18_b16x8_cifar10
> mim download mmcls --config resnet18_b16x8_cifar10 --dest .
> mim download mmcls --config resnet18_8xb16_cifar10
> mim download mmcls --config resnet18_8xb16_cifar10 --dest .
"""
download(package, configs, dest_root)

View File

@ -22,11 +22,11 @@ def test_download():
with pytest.raises(ValueError):
# verion is not allowed
download('mmcls==0.11.0', ['resnet18_b16x8_cifar10'])
download('mmcls==0.11.0', ['resnet18_8xb16_cifar10'])
with pytest.raises(RuntimeError):
# mmcls is not installed
download('mmcls', ['resnet18_b16x8_cifar10'])
download('mmcls', ['resnet18_8xb16_cifar10'])
with pytest.raises(ValueError):
# invalid config
@ -40,12 +40,12 @@ def test_download():
])
assert result.exit_code == 0
# mim download mmcls --config resnet18_b16x8_cifar10
checkpoints = download('mmcls', ['resnet18_b16x8_cifar10'])
# mim download mmcls --config resnet18_8xb16_cifar10
checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'])
assert checkpoints == ['resnet18_b16x8_cifar10_20210528-bd6371c8.pth']
checkpoints = download('mmcls', ['resnet18_b16x8_cifar10'])
checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'])
# mim download mmcls --config resnet18_b16x8_cifar10 --dest temp_root
# mim download mmcls --config resnet18_8xb16_cifar10 --dest temp_root
with tempfile.TemporaryDirectory() as temp_root:
checkpoints = download('mmcls', ['resnet18_b16x8_cifar10'], temp_root)
checkpoints = download('mmcls', ['resnet18_8xb16_cifar10'], temp_root)
assert checkpoints == ['resnet18_b16x8_cifar10_20210528-bd6371c8.pth']