From 0cd2878b048cacc85306ef02a5cb60a61de7f91b Mon Sep 17 00:00:00 2001 From: EnableAsync <43645467+EnableAsync@users.noreply.github.com> Date: Sat, 24 Jun 2023 00:10:54 +0800 Subject: [PATCH] [Feature] AWS S3 obtainer support (#1888) * feat: add aws s3 obtainer feat: add aws s3 obtainer fix: format fix: format * fix: avoid duplicated code fix: code format * fix: runtime.txt * fix: remove duplicated code --- .../datasets/preparers/obtainers/__init__.py | 3 +- .../preparers/obtainers/aws_s3_obtainer.py | 122 ++++++++++++++++++ requirements/optional.txt | 1 + 3 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 mmocr/datasets/preparers/obtainers/aws_s3_obtainer.py diff --git a/mmocr/datasets/preparers/obtainers/__init__.py b/mmocr/datasets/preparers/obtainers/__init__.py index 55d484d9..a27fc7bc 100644 --- a/mmocr/datasets/preparers/obtainers/__init__.py +++ b/mmocr/datasets/preparers/obtainers/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .aws_s3_obtainer import AWSS3Obtainer from .naive_data_obtainer import NaiveDataObtainer -__all__ = ['NaiveDataObtainer'] +__all__ = ['NaiveDataObtainer', 'AWSS3Obtainer'] diff --git a/mmocr/datasets/preparers/obtainers/aws_s3_obtainer.py b/mmocr/datasets/preparers/obtainers/aws_s3_obtainer.py new file mode 100644 index 00000000..042778e5 --- /dev/null +++ b/mmocr/datasets/preparers/obtainers/aws_s3_obtainer.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import ssl +from typing import Dict, List, Optional + +from mmengine import mkdir_or_exist + +from mmocr.registry import DATA_OBTAINERS +from .naive_data_obtainer import NaiveDataObtainer + +ssl._create_default_https_context = ssl._create_unverified_context + + +@DATA_OBTAINERS.register_module() +class AWSS3Obtainer(NaiveDataObtainer): + """A AWS S3 obtainer. + + download -> extract -> move + + Args: + files (list[dict]): A list of file information. + cache_path (str): The path to cache the downloaded files. + data_root (str): The root path of the dataset. It is usually set auto- + matically and users do not need to set it manually in config file + in most cases. + task (str): The task of the dataset. It is usually set automatically + and users do not need to set it manually in config file + in most cases. + """ + + def __init__(self, files: List[Dict], cache_path: str, data_root: str, + task: str) -> None: + try: + import boto3 + from botocore import UNSIGNED + from botocore.config import Config + except ImportError: + raise ImportError( + 'Please install boto3 to download hiertext dataset.') + self.files = files + self.cache_path = cache_path + self.data_root = data_root + self.task = task + self.s3_client = boto3.client( + 's3', config=Config(signature_version=UNSIGNED)) + self.total_length = 0 + mkdir_or_exist(self.data_root) + mkdir_or_exist(osp.join(self.data_root, f'{task}_imgs')) + mkdir_or_exist(osp.join(self.data_root, 'annotations')) + mkdir_or_exist(self.cache_path) + + def find_bucket_key(self, s3_path: str): + """This is a helper function that given an s3 path such that the path + is of the form: bucket/key It will return the bucket and the key + represented by the s3 path. + + Args: + s3_path (str): The AWS s3 path. + """ + s3_components = s3_path.split('/', 1) + bucket = s3_components[0] + s3_key = '' + if len(s3_components) > 1: + s3_key = s3_components[1] + return bucket, s3_key + + def s3_download(self, s3_bucket: str, s3_object_key: str, dst_path: str): + """Download file from given s3 url with progress bar. + + Args: + s3_bucket (str): The s3 bucket to download the file. + s3_object_key (str): The s3 object key to download the file. + dst_path (str): The destination path to save the file. + """ + meta_data = self.s3_client.head_object( + Bucket=s3_bucket, Key=s3_object_key) + total_length = int(meta_data.get('ContentLength', 0)) + downloaded = 0 + + def progress(chunk): + nonlocal downloaded + downloaded += chunk + percent = min(100. * downloaded / total_length, 100) + file_name = osp.basename(dst_path) + print(f'\rDownloading {file_name}: {percent:.2f}%', end='') + + print(f'Downloading {dst_path}') + self.s3_client.download_file( + s3_bucket, s3_object_key, dst_path, Callback=progress) + + def download(self, url: Optional[str], dst_path: str) -> None: + """Download file from given url with progress bar. + + Args: + url (str): The url to download the file. + dst_path (str): The destination path to save the file. + """ + if url is None and not osp.exists(dst_path): + raise FileNotFoundError( + 'Direct url is not available for this dataset.' + ' Please manually download the required files' + ' following the guides.') + + if url.startswith('magnet'): + raise NotImplementedError('Please use any BitTorrent client to ' + 'download the following magnet link to ' + f'{osp.abspath(dst_path)} and ' + f'try again.\nLink: {url}') + + print('Downloading...') + print(f'URL: {url}') + print(f'Destination: {osp.abspath(dst_path)}') + print('If you stuck here for a long time, please check your network, ' + 'or manually download the file to the destination path and ' + 'run the script again.') + if url.startswith('s3://'): + url = url[5:] + bucket, key = self.find_bucket_key(url) + self.s3_download(bucket, key, osp.abspath(dst_path)) + elif url.startswith('https://') or url.startswith('http://'): + super().download(url, dst_path) + print('') diff --git a/requirements/optional.txt b/requirements/optional.txt index e69de29b..30ddf823 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -0,0 +1 @@ +boto3