mirror of https://github.com/open-mmlab/mim.git
[Feature] Support donwloading datasets from opendatalab (#212)
parent
706cdc58b2
commit
bc5aec2abe
|
@ -1,9 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Optional
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
from mim.click import (
|
||||
OptionEatAll,
|
||||
|
@ -14,11 +18,14 @@ from mim.click import (
|
|||
from mim.commands.search import get_model_info
|
||||
from mim.utils import (
|
||||
DEFAULT_CACHE_DIR,
|
||||
call_command,
|
||||
color_echo,
|
||||
download_from_file,
|
||||
echo_success,
|
||||
get_installed_path,
|
||||
highlighted_error,
|
||||
is_installed,
|
||||
module_full_name,
|
||||
split_package_version,
|
||||
)
|
||||
|
||||
|
@ -33,8 +40,13 @@ from mim.utils import (
|
|||
'--config',
|
||||
'configs',
|
||||
cls=OptionEatAll,
|
||||
required=True,
|
||||
help='Config ids to download, such as resnet18_8xb16_cifar10')
|
||||
help='Config ids to download, such as resnet18_8xb16_cifar10',
|
||||
default=None)
|
||||
@click.option(
|
||||
'--dataset',
|
||||
'dataset',
|
||||
help='dataset name to download, such as coco2017',
|
||||
default=None)
|
||||
@click.option(
|
||||
'--ignore-ssl',
|
||||
'check_certificate',
|
||||
|
@ -44,7 +56,8 @@ from mim.utils import (
|
|||
@click.option(
|
||||
'--dest', 'dest_root', type=str, help='Destination of saving checkpoints.')
|
||||
def cli(package: str,
|
||||
configs: List[str],
|
||||
configs: Optional[List[str]],
|
||||
dataset: Optional[str],
|
||||
dest_root: Optional[str] = None,
|
||||
check_certificate: bool = True) -> None:
|
||||
"""Download checkpoints from url and parse configs from package.
|
||||
|
@ -54,31 +67,55 @@ def cli(package: str,
|
|||
> mim download mmcls --config resnet18_8xb16_cifar10
|
||||
> mim download mmcls --config resnet18_8xb16_cifar10 --dest .
|
||||
"""
|
||||
download(package, configs, dest_root, check_certificate)
|
||||
download(package, configs, dest_root, check_certificate, dataset)
|
||||
|
||||
|
||||
def download(package: str,
|
||||
configs: List[str],
|
||||
configs: Optional[List[str]] = None,
|
||||
dest_root: Optional[str] = None,
|
||||
check_certificate: bool = True) -> List[str]:
|
||||
check_certificate: bool = True,
|
||||
dataset: Optional[str] = None) -> Union[List[str], None]:
|
||||
"""Download checkpoints from url and parse configs from package.
|
||||
|
||||
Args:
|
||||
package (str): Name of package.
|
||||
configs (List[str]): List of config ids.
|
||||
dest_root (Optional[str]): Destination directory to save checkpoint and
|
||||
configs (List[str], optional): List of config ids.
|
||||
dest_root (str, optional): Destination directory to save checkpoint and
|
||||
config. Default: None.
|
||||
check_certificate (bool): Whether to check the ssl certificate.
|
||||
Default: True.
|
||||
dataset (str, optional): The name of dataset.
|
||||
"""
|
||||
full_name = module_full_name(package)
|
||||
if full_name == '':
|
||||
msg = f"Can't determine a unique package given abbreviation {package}"
|
||||
raise ValueError(highlighted_error(msg))
|
||||
package = full_name
|
||||
|
||||
if dest_root is None:
|
||||
dest_root = DEFAULT_CACHE_DIR
|
||||
|
||||
dest_root = osp.abspath(dest_root)
|
||||
|
||||
if configs is not None and dataset is not None:
|
||||
raise ValueError(
|
||||
'Cannot download config and dataset at the same time!')
|
||||
if configs is None and dataset is None:
|
||||
raise ValueError('Please specify config or dataset to download!')
|
||||
|
||||
if configs is not None:
|
||||
return _download_configs(package, configs, dest_root,
|
||||
check_certificate)
|
||||
else:
|
||||
return _download_dataset(package, dataset, dest_root) # type: ignore
|
||||
|
||||
|
||||
def _download_configs(package: str,
|
||||
configs: List[str],
|
||||
dest_root: str,
|
||||
check_certificate: bool = True) -> List[str]:
|
||||
# Create the destination directory if it does not exist.
|
||||
if not osp.exists(dest_root):
|
||||
os.makedirs(dest_root)
|
||||
os.makedirs(dest_root, exist_ok=True)
|
||||
|
||||
package, version = split_package_version(package)
|
||||
if version:
|
||||
|
@ -152,3 +189,74 @@ def download(package: str,
|
|||
highlighted_error(f'{config_path} is not found.'))
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
def _download_dataset(package: str, dataset: str, dest_root: str) -> None:
|
||||
if platform.system() != 'Linux':
|
||||
raise RuntimeError('downloading dataset is only supported in Linux!')
|
||||
|
||||
if not is_installed(package):
|
||||
raise RuntimeError(
|
||||
f'Please install {package} by `pip install {package}`')
|
||||
|
||||
installed_path = get_installed_path(package)
|
||||
mim_path = osp.join(installed_path, '.mim')
|
||||
dataset_index_path = osp.join(mim_path, 'dataset-index.yml')
|
||||
|
||||
if not osp.exists(dataset_index_path):
|
||||
raise FileNotFoundError(
|
||||
f'Cannot find {dataset_index_path}, '
|
||||
f'please update {package} to the latest version! If you have '
|
||||
f'already updated it and still get this error, please report an '
|
||||
f'issue to {package}')
|
||||
with open(dataset_index_path) as f:
|
||||
datasets_meta = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
if dataset not in datasets_meta:
|
||||
raise KeyError(f'Cannot find {dataset} in {dataset_index_path}. '
|
||||
'here are the available datasets: '
|
||||
'{}'.format('\n'.join(datasets_meta.keys())))
|
||||
dataset_meta = datasets_meta[dataset]
|
||||
|
||||
# OpenMMLab repo will define the `dataset-index.yml` like this:
|
||||
# voc2007:
|
||||
# dataset: PASCAL_VOC2007
|
||||
# download_root: data
|
||||
# data_root: data
|
||||
# script: tools/dataset_converters/scripts/preprocess_voc2007.sh
|
||||
|
||||
# In this case, the top level key "voc2007" means the "Dataset Name" passed
|
||||
# to `mim download --dataset {Dataset Name}`
|
||||
# The nested field "dataset" means the argument passed to `odl get`
|
||||
# If the value of "dataset" is the same as the "Dataset Name", downstream
|
||||
# repos can skip defining "dataset" and "Dataset Name" will be passed
|
||||
# to `odl get`
|
||||
src_name = dataset_meta.get('dataset', dataset)
|
||||
|
||||
# `odl get` will download raw dataset to `download_root`, and the script
|
||||
# will process the raws data and put the prepared data to the `data_root`
|
||||
download_root = dataset_meta['download_root']
|
||||
os.makedirs(download_root, exist_ok=True)
|
||||
|
||||
color_echo(f'Start downloading {dataset} to {download_root}...', 'blue')
|
||||
subprocess.check_call(['odl', 'get', src_name, '-d', download_root],
|
||||
stdin=sys.stdin,
|
||||
stdout=sys.stdout)
|
||||
|
||||
if not osp.exists(download_root):
|
||||
return
|
||||
|
||||
script_path = dataset_meta.get('script')
|
||||
if script_path is None:
|
||||
return
|
||||
|
||||
script_path = osp.join(mim_path, script_path)
|
||||
color_echo('Preprocess data ...', 'blue')
|
||||
if dest_root == osp.abspath(DEFAULT_CACHE_DIR):
|
||||
data_root = dataset_meta['data_root']
|
||||
else:
|
||||
data_root = dest_root
|
||||
os.makedirs(data_root, exist_ok=True)
|
||||
call_command(['chmod', '+x', script_path])
|
||||
call_command([script_path, download_root, data_root])
|
||||
echo_success('Finished!')
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
Click
|
||||
colorama
|
||||
model-index
|
||||
opendatalab
|
||||
pandas
|
||||
pip>=19.3
|
||||
requests
|
||||
|
|
Loading…
Reference in New Issue