mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Loading objects from different backends and dumping objects to different backends (#1330)
* [Feature] Choose storage backend by the prefix of filepath
* refactor FileClient and add unittest
* support loading from different backends
* polish docstring
* fix unittest
* rename attribute str_like_obj to is_str_like_obj
* add infer_client method
* add check_exist method
* rename var client to file_client
* polish docstring
* add join_paths method
* remove join_paths and add _format_path
* enhance unittest
* refactor unittest
* singleton pattern
* fix test_clientio.py
* deprecate CephBackend
* enhance docstring
* refactor unittest for petrel
* refactor unittest for disk backend
* update io.md
* add concat_paths method
* improve docstring
* improve docstring
* add isdir and copyfile for file backend
* delete copyfile and add get_local_path
* remove isdir method of petrel
* fix typo
* add comment and polish docstring
* polish docstring
* rename _path_mapping to _map_path
* polish docstring and fix typo
* refactor get_local_path
* add list_dir_or_file for FileClient
* add list_dir_or_file for PetrelBackend
* fix windows ci
* Add return docstring
* polish docstring
* fix typo
* fix typo
* deprecate the conversion from Path to str
* add docs for loading checkpoints with FileClient
* refactor map_path
* add _ensure_methods to ensure methods have been implemented
* fix list_dir_or_file
* rename _ensure_method_implemented to has_method

pull/1430/head
parent 2d73eafec2
commit 5b5b47fc87
@ -2,11 +2,17 @@

This module provides two universal APIs to load and dump files of different formats.

```{note}
Since v1.3.16, the IO modules support loading data from, and dumping data to, different backends. More details are in PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330).
```

### Load and dump data

`mmcv` provides a universal API for loading and dumping data; the currently
supported formats are json, yaml and pickle.

#### Load from disk or dump to disk

```python
import mmcv
@ -29,6 +35,20 @@ with open('test.yaml', 'w') as f:
    data = mmcv.dump(data, f, file_format='yaml')
```

#### Load from other backends or dump to other backends

```python
import mmcv

# load data from a file
data = mmcv.load('s3://bucket-name/test.json')
data = mmcv.load('s3://bucket-name/test.yaml')
data = mmcv.load('s3://bucket-name/test.pkl')

# dump data to a file with a filename (infer format from file extension)
mmcv.dump(data, 's3://bucket-name/out.pkl')
```

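The backend is inferred from the path prefix (for example `s3://`). This PR also adds a `file_client_args` argument to `load` and `dump` for selecting the backend explicitly; a minimal sketch, where the bucket paths are placeholders:

```python
import mmcv

# select the backend explicitly instead of relying on prefix inference
data = mmcv.load('s3://bucket-name/test.json',
                 file_client_args={'backend': 'petrel'})
mmcv.dump(data, 's3://bucket-name/out.pkl',
          file_client_args={'prefix': 's3'})
```
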
It is also very convenient to extend the API to support more file formats.
All you need to do is write a file handler that inherits from `BaseFileHandler`
and register it with one or several file formats.

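As a minimal sketch (a hypothetical `TxtHandler` registered for the `txt` format via the `mmcv.register_handler` decorator), a plain-text handler could look like this:

```python
import mmcv


@mmcv.register_handler('txt')
class TxtHandler(mmcv.BaseFileHandler):

    def load_from_fileobj(self, file):
        # parse the text content of the opened file object
        return file.read()

    def dump_to_fileobj(self, obj, file):
        # serialize the object back to text and write it out
        file.write(str(obj))

    def dump_to_str(self, obj):
        return str(obj)
```

With the handler registered, `mmcv.load` and `mmcv.dump` will handle `.txt` files automatically.
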
@ -92,7 +112,9 @@ d
e
```

Then use `list_from_file` to load the list from a.txt.
#### Load from disk

Use `list_from_file` to load the list from a.txt.

```python
>>> mmcv.list_from_file('a.txt')
@ -113,7 +135,7 @@ For example `b.txt` is a text file with 3 lines.
3 panda
```

Then use `dict_from_file` to load the dict from `b.txt` .
Then use `dict_from_file` to load the dict from `b.txt`.

```python
>>> mmcv.dict_from_file('b.txt')
@ -121,3 +143,105 @@ Then use `dict_from_file` to load the dict from `b.txt` .
>>> mmcv.dict_from_file('b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```

#### Load from other backends

Use `list_from_file` to load the list from `s3://bucket-name/a.txt`.

```python
>>> mmcv.list_from_file('s3://bucket-name/a.txt')
['a', 'b', 'c', 'd', 'e']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2)
['c', 'd', 'e']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2)
['a', 'b']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/')
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```

Use `dict_from_file` to load the dict from `s3://bucket-name/b.txt`.

```python
>>> mmcv.dict_from_file('s3://bucket-name/b.txt')
{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```

### Load and dump checkpoints

#### Load checkpoints from disk or save to disk

We can read checkpoints from disk or save them to disk in the following way.

```python
import torch

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = '/path/of/your/checkpoint2.pth'
# read from filepath1
checkpoint = torch.load(filepath1)
# save to filepath2
torch.save(checkpoint, filepath2)
```

MMCV provides many backends. `HardDiskBackend` is one of them and we can use it to read or save checkpoints.

```python
import io
from mmcv.fileio.file_client import HardDiskBackend

disk_backend = HardDiskBackend()
with io.BytesIO(disk_backend.get(filepath1)) as buffer:
    checkpoint = torch.load(buffer)
with io.BytesIO() as buffer:
    torch.save(checkpoint, buffer)
    disk_backend.put(buffer.getvalue(), filepath2)
```

If we want to implement an interface that automatically selects the corresponding
backend based on the file path, we can use `FileClient`.
For example, suppose we want to implement two methods for reading and saving checkpoints,
which need to support different types of file paths, whether disk paths, network paths or other paths.

```python
from mmcv.fileio.file_client import FileClient

def load_checkpoint(path):
    file_client = FileClient.infer_client(uri=path)
    with io.BytesIO(file_client.get(path)) as buffer:
        checkpoint = torch.load(buffer)
    return checkpoint

def save_checkpoint(checkpoint, path):
    with io.BytesIO() as buffer:
        torch.save(checkpoint, buffer)
        file_client.put(buffer.getvalue(), path)

file_client = FileClient.infer_client(uri=filepath1)
checkpoint = load_checkpoint(filepath1)
save_checkpoint(checkpoint, filepath2)
```

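This PR also allows registering custom backends and binding them to path prefixes via `FileClient.register_backend`; a rough sketch based on the unit tests, where `ExampleBackend` and the `example_prefix` prefix are made up for illustration:

```python
from mmcv import BaseStorageBackend, FileClient


class ExampleBackend(BaseStorageBackend):
    """Hypothetical backend used only for illustration."""

    def get(self, filepath):
        # return the file content as bytes
        return b'...'

    def get_text(self, filepath, encoding='utf-8'):
        # return the file content as a string
        return '...'


# register under a name and an optional path prefix
FileClient.register_backend('example', ExampleBackend, prefixes='example_prefix')

# either form now resolves to ExampleBackend
client = FileClient('example')
client = FileClient(prefix='example_prefix')
```
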
#### Load checkpoints from the Internet

```{note}
Currently, it only supports reading checkpoints from the Internet, and does not support saving checkpoints to the Internet.
```

```python
import io
import torch
from mmcv.fileio.file_client import HTTPBackend, FileClient

filepath = 'http://path/of/your/checkpoint.pth'
checkpoint = torch.utils.model_zoo.load_url(filepath)

http_backend = HTTPBackend()
with io.BytesIO(http_backend.get(filepath)) as buffer:
    checkpoint = torch.load(buffer)

file_client = FileClient.infer_client(uri=filepath)
with io.BytesIO(file_client.get(filepath)) as buffer:
    checkpoint = torch.load(buffer)
```

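The PR additionally adds a `get_local_path` context manager to the file clients, which helps when an API only accepts local file paths. A sketch based on the unit tests; the s3 path is a placeholder:

```python
import torch
from mmcv.fileio.file_client import FileClient

filepath = 's3://bucket-name/checkpoint.pth'
file_client = FileClient.infer_client(uri=filepath)
# for remote backends the file is fetched to a temporary local path that is
# removed when the with block exits; for the disk backend the original path
# is returned as is
with file_client.get_local_path(filepath) as local_path:
    checkpoint = torch.load(local_path)
```
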
@ -2,10 +2,16 @@

The file I/O module provides two universal APIs to load and save files of different formats.

```{note}
In v1.3.16 and later, the IO modules support loading data from different backends and saving data to different backends. For more details, see PR [#1330](https://github.com/open-mmlab/mmcv/pull/1330).
```

### Load and save data

`mmcv` provides a universal API for loading and saving data; the currently supported formats are json, yaml and pickle.

#### Load data from disk or save data to disk

```python
import mmcv

@ -28,6 +34,20 @@ with open('test.yaml', 'w') as f:
    data = mmcv.dump(data, f, file_format='yaml')
```

#### Load from other backends or save to other backends

```python
import mmcv

# load data from an s3 file
data = mmcv.load('s3://bucket-name/test.json')
data = mmcv.load('s3://bucket-name/test.yaml')
data = mmcv.load('s3://bucket-name/test.pkl')

# save data to an s3 file (the file format is inferred from the file extension)
mmcv.dump(data, 's3://bucket-name/out.pkl')
```

We provide an easy way to extend support for more file formats. All we need to do is create a
file handler class that inherits from `BaseFileHandler` and register it with `mmcv`. The handler class needs to override at least three methods.

@ -49,7 +69,7 @@ class TxtHandler1(mmcv.BaseFileHandler):
        return str(obj)
```

Take `PickleHandler` for example.
Take `PickleHandler` as an example

```python
import pickle

@ -87,8 +107,9 @@ c
d
e
```
#### Load from disk

Use `list_from_file` to load the list from `a.txt`.
Use `list_from_file` to load the list from `a.txt`

```python
>>> mmcv.list_from_file('a.txt')

@ -101,7 +122,7 @@ e
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```

Similarly, `b.txt` is also a text file with 3 lines.
Similarly, `b.txt` is also a text file with 3 lines

```
1 cat

@ -109,7 +130,7 @@ e
3 panda
```

Use `dict_from_file` to load the dict from `b.txt`.
Use `dict_from_file` to load the dict from `b.txt`

```python
>>> mmcv.dict_from_file('b.txt')

@ -117,3 +138,103 @@ e
>>> mmcv.dict_from_file('b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```

#### Load from other backends

Use `list_from_file` to load the list from `s3://bucket-name/a.txt`

```python
>>> mmcv.list_from_file('s3://bucket-name/a.txt')
['a', 'b', 'c', 'd', 'e']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', offset=2)
['c', 'd', 'e']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', max_num=2)
['a', 'b']
>>> mmcv.list_from_file('s3://bucket-name/a.txt', prefix='/mnt/')
['/mnt/a', '/mnt/b', '/mnt/c', '/mnt/d', '/mnt/e']
```

Use `dict_from_file` to load the dict from `s3://bucket-name/b.txt`

```python
>>> mmcv.dict_from_file('s3://bucket-name/b.txt')
{'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
>>> mmcv.dict_from_file('s3://bucket-name/b.txt', key_type=int)
{1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
```

### Load and save checkpoints

#### Load checkpoints from disk or save checkpoints to disk

We can load checkpoints from disk or save them to disk in the following way.

```python
import torch

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = '/path/of/your/checkpoint2.pth'
# load the checkpoint from filepath1
checkpoint = torch.load(filepath1)
# save the checkpoint to filepath2
torch.save(checkpoint, filepath2)
```

MMCV provides many backends. `HardDiskBackend` is one of them and we can use it to load or save checkpoints.

```python
import io
from mmcv.fileio.file_client import HardDiskBackend

disk_backend = HardDiskBackend()
with io.BytesIO(disk_backend.get(filepath1)) as buffer:
    checkpoint = torch.load(buffer)
with io.BytesIO() as buffer:
    torch.save(checkpoint, buffer)
    disk_backend.put(buffer.getvalue(), filepath2)
```

If we want an interface that automatically selects the corresponding backend based on the file path, we can use `FileClient`.
For example, suppose we want to implement two methods, one for loading checkpoints and one for saving them, and they need to support different types of file paths, whether disk paths, network paths or other paths.

```python
from mmcv.fileio.file_client import FileClient

def load_checkpoint(path):
    file_client = FileClient.infer_client(uri=path)
    with io.BytesIO(file_client.get(path)) as buffer:
        checkpoint = torch.load(buffer)
    return checkpoint

def save_checkpoint(checkpoint, path):
    with io.BytesIO() as buffer:
        torch.save(checkpoint, buffer)
        file_client.put(buffer.getvalue(), path)

file_client = FileClient.infer_client(uri=filepath1)
checkpoint = load_checkpoint(filepath1)
save_checkpoint(checkpoint, filepath2)
```

#### Load checkpoints from the Internet

```{note}
Currently only loading checkpoints from the Internet is supported; saving checkpoints to the Internet is not supported yet.
```

```python
import io
import torch
from mmcv.fileio.file_client import HTTPBackend, FileClient

filepath = 'http://path/of/your/checkpoint.pth'
checkpoint = torch.utils.model_zoo.load_url(filepath)

http_backend = HTTPBackend()
with io.BytesIO(http_backend.get(filepath)) as buffer:
    checkpoint = torch.load(buffer)

file_client = FileClient.infer_client(uri=filepath)
with io.BytesIO(file_client.get(filepath)) as buffer:
    checkpoint = torch.load(buffer)
```

File diff suppressed because it is too large
@ -3,6 +3,12 @@ from abc import ABCMeta, abstractmethod


class BaseFileHandler(metaclass=ABCMeta):
    # `str_like` is a flag to indicate whether the type of file object is
    # str-like object or bytes-like object. Pickle only processes bytes-like
    # objects but json only processes str-like objects. If it is a str-like
    # object, `StringIO` will be used to process the buffer.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
@ -6,6 +6,8 @@ from .base import BaseFileHandler


class PickleHandler(BaseFileHandler):

    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)
@ -1,7 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils import is_list_of, is_str
|
||||
from .file_client import FileClient
|
||||
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
|
||||
|
||||
file_handlers = {
|
||||
|
@ -13,11 +15,15 @@ file_handlers = {
|
|||
}
|
||||
|
||||
|
||||
def load(file, file_format=None, **kwargs):
|
||||
def load(file, file_format=None, file_client_args=None, **kwargs):
|
||||
"""Load data from json/yaml/pickle files.
|
||||
|
||||
This method provides a unified api for loading data from serialized files.
|
||||
|
||||
Note:
|
||||
In v1.3.16 and later, ``load`` supports loading data from serialized
files that can be stored in different backends.
|
||||
|
||||
Args:
|
||||
file (str or :obj:`Path` or file-like object): Filename or a file-like
|
||||
object.
|
||||
|
@ -25,6 +31,14 @@ def load(file, file_format=None, **kwargs):
|
|||
inferred from the file extension, otherwise use the specified one.
|
||||
Currently supported formats include "json", "yaml/yml" and
|
||||
"pickle/pkl".
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Default: None.
|
||||
|
||||
Examples:
|
||||
>>> load('/path/of/your/file') # file is stored on disk
>>> load('https://path/of/your/file') # file is stored on the Internet
>>> load('s3://path/of/your/file') # file is stored in petrel
|
||||
|
||||
Returns:
|
||||
The content from the file.
|
||||
|
@ -38,7 +52,13 @@ def load(file, file_format=None, **kwargs):
|
|||
|
||||
handler = file_handlers[file_format]
|
||||
if is_str(file):
|
||||
obj = handler.load_from_path(file, **kwargs)
|
||||
file_client = FileClient.infer_client(file_client_args, file)
|
||||
if handler.str_like:
|
||||
with StringIO(file_client.get_text(file)) as f:
|
||||
obj = handler.load_from_fileobj(f, **kwargs)
|
||||
else:
|
||||
with BytesIO(file_client.get(file)) as f:
|
||||
obj = handler.load_from_fileobj(f, **kwargs)
|
||||
elif hasattr(file, 'read'):
|
||||
obj = handler.load_from_fileobj(file, **kwargs)
|
||||
else:
|
||||
|
@ -46,18 +66,29 @@ def load(file, file_format=None, **kwargs):
|
|||
return obj
|
||||
|
||||
|
||||
def dump(obj, file=None, file_format=None, **kwargs):
|
||||
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
|
||||
"""Dump data to json/yaml/pickle strings or files.
|
||||
|
||||
This method provides a unified api for dumping data as strings or to files,
|
||||
and also supports custom arguments for each file format.
|
||||
|
||||
Note:
|
||||
In v1.3.16 and later, ``dump`` supports dumping data as strings or to
files which are saved to different backends.
|
||||
|
||||
Args:
|
||||
obj (any): The python object to be dumped.
|
||||
file (str or :obj:`Path` or file-like object, optional): If not
|
||||
specified, then the object is dump to a str, otherwise to a file
|
||||
specified, then the object is dumped to a str, otherwise to a file
|
||||
specified by the filename or file-like object.
|
||||
file_format (str, optional): Same as :func:`load`.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Default: None.
|
||||
|
||||
Examples:
|
||||
>>> dump('hello world', '/path/of/your/file') # disk
|
||||
>>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
|
||||
|
||||
Returns:
|
||||
bool: True for success, False otherwise.
|
||||
|
@ -77,7 +108,15 @@ def dump(obj, file=None, file_format=None, **kwargs):
|
|||
if file is None:
|
||||
return handler.dump_to_str(obj, **kwargs)
|
||||
elif is_str(file):
|
||||
handler.dump_to_path(obj, file, **kwargs)
|
||||
file_client = FileClient.infer_client(file_client_args, file)
|
||||
if handler.str_like:
|
||||
with StringIO() as f:
|
||||
handler.dump_to_fileobj(obj, f, **kwargs)
|
||||
file_client.put_text(f.getvalue(), file)
|
||||
else:
|
||||
with BytesIO() as f:
|
||||
handler.dump_to_fileobj(obj, f, **kwargs)
|
||||
file_client.put(f.getvalue(), file)
|
||||
elif hasattr(file, 'write'):
|
||||
handler.dump_to_fileobj(obj, file, **kwargs)
|
||||
else:
|
||||
|
|
|
@ -1,7 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
|
||||
|
||||
from io import StringIO
|
||||
|
||||
from .file_client import FileClient
|
||||
|
||||
|
||||
def list_from_file(filename,
|
||||
prefix='',
|
||||
offset=0,
|
||||
max_num=0,
|
||||
encoding='utf-8',
|
||||
file_client_args=None):
|
||||
"""Load a text file and parse the content as a list of strings.
|
||||
|
||||
Note:
|
||||
In v1.3.16 and later, ``list_from_file`` supports loading a text file
which can be stored in different backends and parsing the content as
a list of strings.
|
||||
|
||||
Args:
|
||||
filename (str): Filename.
|
||||
prefix (str): The prefix to be inserted to the beginning of each item.
|
||||
|
@ -9,13 +25,23 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
|
|||
max_num (int): The maximum number of lines to be read,
|
||||
zeros and negatives mean no limitation.
|
||||
encoding (str): Encoding used to open the file. Default utf-8.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Default: None.
|
||||
|
||||
Examples:
|
||||
>>> list_from_file('/path/of/your/file') # disk
|
||||
['hello', 'world']
|
||||
>>> list_from_file('s3://path/of/your/file') # ceph or petrel
|
||||
['hello', 'world']
|
||||
|
||||
Returns:
|
||||
list[str]: A list of strings.
|
||||
"""
|
||||
cnt = 0
|
||||
item_list = []
|
||||
with open(filename, 'r', encoding=encoding) as f:
|
||||
file_client = FileClient.infer_client(file_client_args, filename)
|
||||
with StringIO(file_client.get_text(filename, encoding)) as f:
|
||||
for _ in range(offset):
|
||||
f.readline()
|
||||
for line in f:
|
||||
|
@ -26,23 +52,42 @@ def list_from_file(filename, prefix='', offset=0, max_num=0, encoding='utf-8'):
|
|||
return item_list
|
||||
|
||||
|
||||
def dict_from_file(filename, key_type=str):
|
||||
def dict_from_file(filename,
|
||||
key_type=str,
|
||||
encoding='utf-8',
|
||||
file_client_args=None):
|
||||
"""Load a text file and parse the content as a dict.
|
||||
|
||||
Each line of the text file will be two or more columns split by
|
||||
whitespaces or tabs. The first column will be parsed as dict keys, and
|
||||
the following columns will be parsed as dict values.
|
||||
|
||||
Note:
|
||||
In v1.3.16 and later, ``dict_from_file`` supports loading a text file
which can be stored in different backends and parsing the content as
a dict.
|
||||
|
||||
Args:
|
||||
filename(str): Filename.
|
||||
key_type(type): Type of the dict keys. str is used by default and
|
||||
type conversion will be performed if specified.
|
||||
encoding (str): Encoding used to open the file. Default utf-8.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Default: None.
|
||||
|
||||
Examples:
|
||||
>>> dict_from_file('/path/of/your/file') # disk
|
||||
{'key1': 'value1', 'key2': 'value2'}
|
||||
>>> dict_from_file('s3://path/of/your/file') # ceph or petrel
|
||||
{'key1': 'value1', 'key2': 'value2'}
|
||||
|
||||
Returns:
|
||||
dict: The parsed contents.
|
||||
"""
|
||||
mapping = {}
|
||||
with open(filename, 'r') as f:
|
||||
file_client = FileClient.infer_client(file_client_args, filename)
|
||||
with StringIO(file_client.get_text(filename, encoding)) as f:
|
||||
for line in f:
|
||||
items = line.rstrip('\n').split()
|
||||
assert len(items) >= 2
|
||||
|
|
|
@ -309,9 +309,9 @@ def adjust_sharpness(img, factor=1., kernel=None):
|
|||
kernel (np.ndarray, optional): Filter kernel to be applied on the img
|
||||
to obtain the degenerated img. Defaults to None.
|
||||
|
||||
Notes::
|
||||
Note:
|
||||
No value sanity check is enforced on the kernel set by users. So with
|
||||
an inappropriate kernel, the `adjust_sharpness` may fail to perform
|
||||
an inappropriate kernel, the ``adjust_sharpness`` may fail to perform
|
||||
the function its name indicates but end up performing whatever
|
||||
transform determined by the kernel.
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .config import Config, ConfigDict, DictAction
|
||||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
import_modules_from_strings, is_list_of,
|
||||
has_method, import_modules_from_strings, is_list_of,
|
||||
is_method_overridden, is_seq_of, is_str, is_tuple_of,
|
||||
iter_cast, list_cast, requires_executable, requires_package,
|
||||
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
|
||||
|
@ -33,7 +33,7 @@ except ImportError:
|
|||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
|
||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
|
||||
'is_method_overridden'
|
||||
'is_method_overridden', 'has_method'
|
||||
]
|
||||
else:
|
||||
from .env import collect_env
|
||||
|
@ -65,5 +65,5 @@ else:
|
|||
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
||||
'assert_params_all_zeros', 'check_python_script',
|
||||
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
|
||||
'_get_cuda_home'
|
||||
'_get_cuda_home', 'has_method'
|
||||
]
|
||||
|
|
|
@ -362,3 +362,16 @@ def is_method_overridden(method, base_class, derived_class):
    base_method = getattr(base_class, method)
    derived_method = getattr(derived_class, method)
    return derived_method != base_method


def has_method(obj: object, method: str) -> bool:
    """Check whether the object has a method.

    Args:
        method (str): The method name to check.
        obj (object): The object to check.

    Returns:
        bool: True if the object has the method else False.
    """
    return hasattr(obj, method) and callable(getattr(obj, method))
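A quick illustration of how the new helper behaves (the `Cat` class is hypothetical, used only for the example):

```python
from mmcv.utils import has_method


class Cat:

    name = 'cat'

    def meow(self):
        return 'meow'


assert has_method(Cat(), 'meow')      # callable attribute -> True
assert not has_method(Cat(), 'name')  # plain attribute -> False
```
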
@ -1,4 +1,9 @@
|
|||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
@ -6,6 +11,7 @@ import pytest
|
|||
|
||||
import mmcv
|
||||
from mmcv import BaseStorageBackend, FileClient
|
||||
from mmcv.utils import has_method
|
||||
|
||||
sys.modules['ceph'] = MagicMock()
|
||||
sys.modules['petrel_client'] = MagicMock()
|
||||
|
@ -13,6 +19,51 @@ sys.modules['petrel_client.client'] = MagicMock()
|
|||
sys.modules['mc'] = MagicMock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def build_temporary_directory():
|
||||
"""Build a temporary directory containing many files to test
|
||||
``FileClient.list_dir_or_file``.
|
||||
|
||||
. \n
|
||||
| -- dir1 \n
|
||||
| -- | -- text3.txt \n
|
||||
| -- dir2 \n
|
||||
| -- | -- dir3 \n
|
||||
| -- | -- | -- text4.txt \n
|
||||
| -- | -- img.jpg \n
|
||||
| -- text1.txt \n
|
||||
| -- text2.txt \n
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
text1 = Path(tmp_dir) / 'text1.txt'
|
||||
text1.open('w').write('text1')
|
||||
text2 = Path(tmp_dir) / 'text2.txt'
|
||||
text2.open('w').write('text2')
|
||||
dir1 = Path(tmp_dir) / 'dir1'
|
||||
dir1.mkdir()
|
||||
text3 = dir1 / 'text3.txt'
|
||||
text3.open('w').write('text3')
|
||||
dir2 = Path(tmp_dir) / 'dir2'
|
||||
dir2.mkdir()
|
||||
jpg1 = dir2 / 'img.jpg'
|
||||
jpg1.open('wb').write(b'img')
|
||||
dir3 = dir2 / 'dir3'
|
||||
dir3.mkdir()
|
||||
text4 = dir3 / 'text4.txt'
|
||||
text4.open('w').write('text4')
|
||||
yield tmp_dir
|
||||
|
||||
|
||||
@contextmanager
|
||||
def delete_and_reset_method(obj, method):
|
||||
method_obj = deepcopy(getattr(type(obj), method))
|
||||
try:
|
||||
delattr(type(obj), method)
|
||||
yield
|
||||
finally:
|
||||
setattr(type(obj), method, method_obj)
|
||||
|
||||
|
||||
class MockS3Client:
|
||||
|
||||
def __init__(self, enable_mc=True):
|
||||
|
@ -24,6 +75,37 @@ class MockS3Client:
|
|||
return content
|
||||
|
||||
|
||||
class MockPetrelClient:
|
||||
|
||||
def __init__(self, enable_mc=True, enable_multi_cluster=False):
|
||||
self.enable_mc = enable_mc
|
||||
self.enable_multi_cluster = enable_multi_cluster
|
||||
|
||||
def Get(self, filepath):
|
||||
with open(filepath, 'rb') as f:
|
||||
content = f.read()
|
||||
return content
|
||||
|
||||
def put(self):
|
||||
pass
|
||||
|
||||
def delete(self):
|
||||
pass
|
||||
|
||||
def contains(self):
|
||||
pass
|
||||
|
||||
def isdir(self):
|
||||
pass
|
||||
|
||||
def list(self, dir_path):
|
||||
for entry in os.scandir(dir_path):
|
||||
if not entry.name.startswith('.') and entry.is_file():
|
||||
yield entry.name
|
||||
elif osp.isdir(entry.path):
|
||||
yield entry.name + '/'
|
||||
|
||||
|
||||
class MockMemcachedClient:
|
||||
|
||||
def __init__(self, server_list_cfg, client_cfg):
|
||||
|
@ -50,6 +132,7 @@ class TestFileClient:
|
|||
def test_disk_backend(self):
|
||||
disk_backend = FileClient('disk')
|
||||
|
||||
# test `get`
|
||||
# input path is Path object
|
||||
img_bytes = disk_backend.get(self.img_path)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
|
@ -61,6 +144,7 @@ class TestFileClient:
|
|||
assert self.img_path.open('rb').read() == img_bytes
|
||||
assert img.shape == self.img_shape
|
||||
|
||||
# test `get_text`
|
||||
# input path is Path object
|
||||
value_buf = disk_backend.get_text(self.text_path)
|
||||
assert self.text_path.open('r').read() == value_buf
|
||||
|
@ -68,6 +152,118 @@ class TestFileClient:
|
|||
value_buf = disk_backend.get_text(str(self.text_path))
|
||||
assert self.text_path.open('r').read() == value_buf
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# test `put`
|
||||
filepath1 = Path(tmp_dir) / 'test.jpg'
|
||||
disk_backend.put(b'disk', filepath1)
|
||||
assert filepath1.open('rb').read() == b'disk'
|
||||
|
||||
# test `put_text`
|
||||
filepath2 = Path(tmp_dir) / 'test.txt'
|
||||
disk_backend.put_text('disk', filepath2)
|
||||
assert filepath2.open('r').read() == 'disk'
|
||||
|
||||
# test `isfile`
|
||||
assert disk_backend.isfile(filepath2)
|
||||
assert not disk_backend.isfile(Path(tmp_dir) / 'not/existed/path')
|
||||
|
||||
# test `remove`
|
||||
disk_backend.remove(filepath2)
|
||||
|
||||
# test `exists`
|
||||
assert not disk_backend.exists(filepath2)
|
||||
|
||||
# test `get_local_path`
|
||||
# if the backend is disk, `get_local_path` just return the input
|
||||
with disk_backend.get_local_path(filepath1) as path:
|
||||
assert str(filepath1) == path
|
||||
assert osp.isfile(filepath1)
|
||||
|
||||
# test `concat_paths`
|
||||
disk_dir = '/path/of/your/directory'
|
||||
assert disk_backend.concat_paths(disk_dir, 'file') == \
|
||||
osp.join(disk_dir, 'file')
|
||||
assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \
|
||||
osp.join(disk_dir, 'dir', 'file')
|
||||
|
||||
# test `list_dir_or_file`
|
||||
with build_temporary_directory() as tmp_dir:
|
||||
# 1. list directories and files
|
||||
assert set(disk_backend.list_dir_or_file(tmp_dir)) == set(
|
||||
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
|
||||
# 2. list directories and files recursively
|
||||
assert set(disk_backend.list_dir_or_file(
|
||||
tmp_dir, recursive=True)) == set([
|
||||
'dir1',
|
||||
osp.join('dir1', 'text3.txt'), 'dir2',
|
||||
osp.join('dir2', 'dir3'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
# 3. only list directories
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='`suffix` should be None when `list_dir` is True'):
|
||||
# Exception is raised among the `list_dir_or_file` of client,
|
||||
# so we need to invode the client to trigger the exception
|
||||
disk_backend.client.list_dir_or_file(
|
||||
tmp_dir, list_file=False, suffix='.txt')
|
||||
# 4. only list directories recursively
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False, recursive=True)) == set(
|
||||
['dir1', 'dir2',
|
||||
osp.join('dir2', 'dir3')])
|
||||
# 5. only list files
|
||||
assert set(disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False)) == set(['text1.txt', 'text2.txt'])
|
||||
# 6. only list files recursively
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, recursive=True)) == set([
|
||||
osp.join('dir1', 'text3.txt'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix=('.txt',
|
||||
'.jpg'))) == set(['text1.txt', 'text2.txt'])
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='`suffix` must be a string or tuple of strings'):
|
||||
disk_backend.client.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])
|
||||
# 8. only list files ending with suffix recursively
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, suffix='.txt',
|
||||
recursive=True)) == set([
|
||||
osp.join('dir1', 'text3.txt'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
|
||||
'text2.txt'
|
||||
])
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
disk_backend.list_dir_or_file(
|
||||
tmp_dir,
|
||||
list_dir=False,
|
||||
suffix=('.txt', '.jpg'),
|
||||
recursive=True)) == set([
|
||||
osp.join('dir1', 'text3.txt'),
|
||||
osp.join('dir2', 'dir3', 'text4.txt'),
|
||||
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
|
||||
@patch('ceph.S3Client', MockS3Client)
|
||||
def test_ceph_backend(self):
|
||||
ceph_backend = FileClient('ceph')
|
||||
|
@ -103,16 +299,11 @@ class TestFileClient:
|
|||
ceph_backend.client._client.Get.assert_called_with(
|
||||
str(self.img_path).replace(str(self.test_data_dir), ceph_path))
|
||||
|
||||
@patch('petrel_client.client.Client', MockS3Client)
|
||||
def test_petrel_backend(self):
|
||||
petrel_backend = FileClient('petrel')
|
||||
|
||||
# input path is Path object
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.get_text(self.text_path)
|
||||
# input path is str
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.get_text(str(self.text_path))
|
||||
@patch('petrel_client.client.Client', MockPetrelClient)
|
||||
@pytest.mark.parametrize('backend,prefix', [('petrel', None),
|
||||
(None, 's3')])
|
||||
def test_petrel_backend(self, backend, prefix):
|
||||
petrel_backend = FileClient(backend=backend, prefix=prefix)
|
||||
|
||||
# input path is Path object
|
||||
img_bytes = petrel_backend.get(self.img_path)
|
||||
|
@ -126,17 +317,209 @@ class TestFileClient:
|
|||
# `path_mapping` is either None or dict
|
||||
with pytest.raises(AssertionError):
|
||||
FileClient('petrel', path_mapping=1)
|
||||
# test `path_mapping`
|
||||
petrel_path = 's3://user/data'
|
||||
|
||||
# test `_map_path`
|
||||
petrel_dir = 's3://user/data'
|
||||
petrel_backend = FileClient(
|
||||
'petrel', path_mapping={str(self.test_data_dir): petrel_path})
|
||||
petrel_backend.client._client.Get = MagicMock(
|
||||
return_value=petrel_backend.client._client.Get(self.img_path))
|
||||
img_bytes = petrel_backend.get(self.img_path)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == self.img_shape
|
||||
petrel_backend.client._client.Get.assert_called_with(
|
||||
str(self.img_path).replace(str(self.test_data_dir), petrel_path))
|
||||
'petrel', path_mapping={str(self.test_data_dir): petrel_dir})
|
||||
assert petrel_backend.client._map_path(str(self.img_path)) == \
|
||||
str(self.img_path).replace(str(self.test_data_dir), petrel_dir)
|
||||
|
||||
petrel_path = f'{petrel_dir}/test.jpg'
|
||||
petrel_backend = FileClient('petrel')
|
||||
|
||||
# test `_format_path`
|
||||
assert petrel_backend.client._format_path('s3://user\\data\\test.jpg')\
|
||||
== petrel_path
|
||||
|
||||
# test `get`
|
||||
with patch.object(
|
||||
petrel_backend.client._client, 'Get',
|
||||
return_value=b'petrel') as mock_get:
|
||||
assert petrel_backend.get(petrel_path) == b'petrel'
|
||||
mock_get.assert_called_once_with(petrel_path)
|
||||
|
||||
# test `get_text`
|
||||
with patch.object(
|
||||
petrel_backend.client._client, 'Get',
|
||||
return_value=b'petrel') as mock_get:
|
||||
assert petrel_backend.get_text(petrel_path) == 'petrel'
|
||||
mock_get.assert_called_once_with(petrel_path)
|
||||
|
||||
# test `put`
|
||||
with patch.object(petrel_backend.client._client, 'put') as mock_put:
|
||||
petrel_backend.put(b'petrel', petrel_path)
|
||||
mock_put.assert_called_once_with(petrel_path, b'petrel')
|
||||
|
||||
# test `put_text`
|
||||
with patch.object(petrel_backend.client._client, 'put') as mock_put:
|
||||
petrel_backend.put_text('petrel', petrel_path)
|
||||
mock_put.assert_called_once_with(petrel_path, b'petrel')
|
||||
|
||||
# test `remove`
|
||||
assert has_method(petrel_backend.client._client, 'delete')
|
||||
# raise Exception if `delete` is not implemented
|
||||
with delete_and_reset_method(petrel_backend.client._client, 'delete'):
|
||||
assert not has_method(petrel_backend.client._client, 'delete')
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.remove(petrel_path)
|
||||
|
||||
with patch.object(petrel_backend.client._client,
|
||||
'delete') as mock_delete:
|
||||
petrel_backend.remove(petrel_path)
|
||||
mock_delete.assert_called_once_with(petrel_path)
|
||||
|
||||
# test `exists`
|
||||
assert has_method(petrel_backend.client._client, 'contains')
|
||||
assert has_method(petrel_backend.client._client, 'isdir')
|
||||
# raise Exception if `contains` or `isdir` is not implemented
|
||||
with delete_and_reset_method(petrel_backend.client._client,
|
||||
'contains'), delete_and_reset_method(
|
||||
petrel_backend.client._client,
|
||||
'isdir'):
|
||||
assert not has_method(petrel_backend.client._client, 'contains')
|
||||
assert not has_method(petrel_backend.client._client, 'isdir')
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.exists(petrel_path)
|
||||
|
||||
with patch.object(
|
||||
petrel_backend.client._client, 'contains',
|
||||
return_value=True) as mock_contains:
|
||||
assert petrel_backend.exists(petrel_path)
|
||||
mock_contains.assert_called_once_with(petrel_path)
|
||||
|
||||
# test `isdir`
|
||||
assert has_method(petrel_backend.client._client, 'isdir')
|
||||
with delete_and_reset_method(petrel_backend.client._client, 'isdir'):
|
||||
assert not has_method(petrel_backend.client._client, 'isdir')
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.isdir(petrel_path)
|
||||
|
||||
with patch.object(
|
||||
petrel_backend.client._client, 'isdir',
|
||||
return_value=True) as mock_isdir:
|
||||
assert petrel_backend.isdir(petrel_dir)
|
||||
mock_isdir.assert_called_once_with(petrel_dir)
|
||||
|
||||
# test `isfile`
|
||||
assert has_method(petrel_backend.client._client, 'contains')
|
||||
with delete_and_reset_method(petrel_backend.client._client,
|
||||
'contains'):
|
||||
assert not has_method(petrel_backend.client._client, 'contains')
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.isfile(petrel_path)
|
||||
|
||||
with patch.object(
|
||||
petrel_backend.client._client, 'contains',
|
||||
return_value=True) as mock_contains:
|
||||
assert petrel_backend.isfile(petrel_path)
|
||||
mock_contains.assert_called_once_with(petrel_path)
|
||||
|
||||
# test `concat_paths`
|
||||
assert petrel_backend.concat_paths(petrel_dir, 'file') == \
|
||||
f'{petrel_dir}/file'
|
||||
assert petrel_backend.concat_paths(f'{petrel_dir}/', 'file') == \
|
||||
f'{petrel_dir}/file'
|
||||
assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \
|
||||
f'{petrel_dir}/dir/file'
|
||||
|
||||
# test `get_local_path`
|
||||
with patch.object(petrel_backend.client._client, 'Get',
|
||||
return_value=b'petrel') as mock_get, \
|
||||
patch.object(petrel_backend.client._client, 'contains',
|
||||
return_value=True) as mock_contains:
|
||||
with petrel_backend.get_local_path(petrel_path) as path:
|
||||
assert Path(path).open('rb').read() == b'petrel'
|
||||
# exiting the with block releases the path
|
||||
assert not osp.isfile(path)
|
||||
mock_get.assert_called_once_with(petrel_path)
|
||||
mock_contains.assert_called_once_with(petrel_path)
|
||||
|
||||
# test `list_dir_or_file`
|
||||
assert has_method(petrel_backend.client._client, 'list')
|
||||
with delete_and_reset_method(petrel_backend.client._client, 'list'):
|
||||
assert not has_method(petrel_backend.client._client, 'list')
|
||||
with pytest.raises(NotImplementedError):
|
||||
list(petrel_backend.list_dir_or_file(petrel_dir))
|
||||
|
||||
with build_temporary_directory() as tmp_dir:
|
||||
# 1. list directories and files
|
||||
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == set(
|
||||
['dir1', 'dir2', 'text1.txt', 'text2.txt'])
|
||||
# 2. list directories and files recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, recursive=True)) == set([
|
||||
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2',
|
||||
'/'.join(('dir2', 'dir3')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
# 3. only list directories
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False)) == set(['dir1', 'dir2'])
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('`list_dir` should be False when `suffix` is not '
|
||||
'None')):
|
||||
# Exception is raised inside `list_dir_or_file` of the client,
# so we need to invoke the client to trigger the exception
|
||||
petrel_backend.client.list_dir_or_file(
|
||||
tmp_dir, list_file=False, suffix='.txt')
|
||||
# 4. only list directories recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_file=False, recursive=True)) == set(
|
||||
['dir1', 'dir2', '/'.join(('dir2', 'dir3'))])
|
||||
# 5. only list files
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(tmp_dir,
|
||||
list_dir=False)) == set(
|
||||
['text1.txt', 'text2.txt'])
|
||||
# 6. only list files recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, recursive=True)) == set([
|
||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix='.txt')) == set(['text1.txt', 'text2.txt'])
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False,
|
||||
suffix=('.txt',
|
||||
'.jpg'))) == set(['text1.txt', 'text2.txt'])
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='`suffix` must be a string or tuple of strings'):
|
||||
petrel_backend.client.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])
|
||||
# 8. only list files ending with suffix recursively
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir, list_dir=False, suffix='.txt',
|
||||
recursive=True)) == set([
|
||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), 'text1.txt',
|
||||
'text2.txt'
|
||||
])
|
||||
# 7. only list files ending with suffix
|
||||
assert set(
|
||||
petrel_backend.list_dir_or_file(
|
||||
tmp_dir,
|
||||
list_dir=False,
|
||||
suffix=('.txt', '.jpg'),
|
||||
recursive=True)) == set([
|
||||
'/'.join(('dir1', 'text3.txt')), '/'.join(
|
||||
('dir2', 'dir3', 'text4.txt')), '/'.join(
|
||||
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
|
||||
])
|
||||
|
||||
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
|
||||
@patch('mc.pyvector', MagicMock)
|
||||
|
@ -182,8 +565,10 @@ class TestFileClient:
|
|||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == (120, 125, 3)
|
||||
|
||||
def test_http_backend(self):
|
||||
http_backend = FileClient('http')
|
||||
@pytest.mark.parametrize('backend,prefix', [('http', None),
|
||||
(None, 'http')])
|
||||
def test_http_backend(self, backend, prefix):
|
||||
http_backend = FileClient(backend=backend, prefix=prefix)
|
||||
img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
|
||||
'master/tests/data/color.jpg'
|
||||
text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
|
||||
|
@ -208,6 +593,84 @@ class TestFileClient:
|
|||
value_buf = http_backend.get_text(text_url)
|
||||
assert self.text_path.open('r').read() == value_buf
|
||||
|
||||
# test `get_local_path`
# exiting the with block releases the path
|
||||
with http_backend.get_local_path(img_url) as path:
|
||||
assert mmcv.imread(path).shape == self.img_shape
|
||||
assert not osp.isfile(path)
|
||||
|
||||
def test_new_magic_method(self):
|
||||
|
||||
class DummyBackend1(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return filepath
|
||||
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return filepath
|
||||
|
||||
FileClient.register_backend('dummy_backend', DummyBackend1)
|
||||
client1 = FileClient(backend='dummy_backend')
|
||||
client2 = FileClient(backend='dummy_backend')
|
||||
assert client1 is client2
|
||||
|
||||
# if a backend is overwrote, it will disable the singleton pattern for
|
||||
# the backend
|
||||
class DummyBackend2(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
pass
|
||||
|
||||
def get_text(self, filepath):
|
||||
pass
|
||||
|
||||
FileClient.register_backend('dummy_backend', DummyBackend2, force=True)
|
||||
client3 = FileClient(backend='dummy_backend')
|
||||
client4 = FileClient(backend='dummy_backend')
|
||||
assert client3 is not client4
|
||||
|
||||
def test_parse_uri_prefix(self):
|
||||
# input path is None
|
||||
with pytest.raises(AssertionError):
|
||||
FileClient.parse_uri_prefix(None)
|
||||
# input path is list
|
||||
with pytest.raises(AssertionError):
|
||||
FileClient.parse_uri_prefix([])
|
||||
|
||||
# input path is Path object
|
||||
assert FileClient.parse_uri_prefix(self.img_path) is None
|
||||
# input path is str
|
||||
assert FileClient.parse_uri_prefix(str(self.img_path)) is None
|
||||
|
||||
# input path starts with https
|
||||
img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
|
||||
'master/tests/data/color.jpg'
|
||||
assert FileClient.parse_uri_prefix(img_url) == 'https'
|
||||
|
||||
# input path starts with s3
|
||||
img_url = 's3://your_bucket/img.png'
|
||||
assert FileClient.parse_uri_prefix(img_url) == 's3'
|
||||
|
||||
# input path starts with clusterName:s3
|
||||
img_url = 'clusterName:s3://your_bucket/img.png'
|
||||
assert FileClient.parse_uri_prefix(img_url) == 's3'
|
||||
|
||||
def test_infer_client(self):
|
||||
# HardDiskBackend
|
||||
file_client_args = {'backend': 'disk'}
|
||||
client = FileClient.infer_client(file_client_args)
|
||||
assert client.backend_name == 'disk'
|
||||
client = FileClient.infer_client(uri=self.img_path)
|
||||
assert client.backend_name == 'disk'
|
||||
|
||||
# PetrelBackend
|
||||
file_client_args = {'backend': 'petrel'}
|
||||
client = FileClient.infer_client(file_client_args)
|
||||
assert client.backend_name == 'petrel'
|
||||
uri = 's3://user_data'
|
||||
client = FileClient.infer_client(uri=uri)
|
||||
assert client.backend_name == 'petrel'
|
||||
|
||||
def test_register_backend(self):
|
||||
|
||||
# name must be a string
|
||||
|
@ -235,7 +698,7 @@ class TestFileClient:
|
|||
def get(self, filepath):
|
||||
return filepath
|
||||
|
||||
def get_text(self, filepath):
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return filepath
|
||||
|
||||
FileClient.register_backend('example', ExampleBackend)
|
||||
|
@ -247,9 +710,9 @@ class TestFileClient:
|
|||
class Example2Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes2'
|
||||
return b'bytes2'
|
||||
|
||||
def get_text(self, filepath):
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text2'
|
||||
|
||||
# force=False
|
||||
|
@ -258,20 +721,20 @@ class TestFileClient:
|
|||
|
||||
FileClient.register_backend('example', Example2Backend, force=True)
|
||||
example_backend = FileClient('example')
|
||||
assert example_backend.get(self.img_path) == 'bytes2'
|
||||
assert example_backend.get(self.img_path) == b'bytes2'
|
||||
assert example_backend.get_text(self.text_path) == 'text2'
|
||||
|
||||
@FileClient.register_backend(name='example3')
|
||||
class Example3Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes3'
|
||||
return b'bytes3'
|
||||
|
||||
def get_text(self, filepath):
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text3'
|
||||
|
||||
example_backend = FileClient('example3')
|
||||
assert example_backend.get(self.img_path) == 'bytes3'
|
||||
assert example_backend.get(self.img_path) == b'bytes3'
|
||||
assert example_backend.get_text(self.text_path) == 'text3'
|
||||
assert 'example3' in FileClient._backends
|
||||
|
||||
|
@ -282,20 +745,89 @@ class TestFileClient:
|
|||
class Example4Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes4'
|
||||
return b'bytes4'
|
||||
|
||||
def get_text(self, filepath):
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text4'
|
||||
|
||||
@FileClient.register_backend(name='example3', force=True)
|
||||
class Example5Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes5'
|
||||
return b'bytes5'
|
||||
|
||||
def get_text(self, filepath):
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text5'
|
||||
|
||||
example_backend = FileClient('example3')
|
||||
assert example_backend.get(self.img_path) == 'bytes5'
|
||||
assert example_backend.get(self.img_path) == b'bytes5'
|
||||
assert example_backend.get_text(self.text_path) == 'text5'
|
||||
|
||||
# prefixes is a str
|
||||
class Example6Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return b'bytes6'
|
||||
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text6'
|
||||
|
||||
FileClient.register_backend(
|
||||
'example4',
|
||||
Example6Backend,
|
||||
force=True,
|
||||
prefixes='example4_prefix')
|
||||
example_backend = FileClient('example4')
|
||||
assert example_backend.get(self.img_path) == b'bytes6'
|
||||
assert example_backend.get_text(self.text_path) == 'text6'
|
||||
example_backend = FileClient(prefix='example4_prefix')
|
||||
assert example_backend.get(self.img_path) == b'bytes6'
|
||||
assert example_backend.get_text(self.text_path) == 'text6'
|
||||
example_backend = FileClient('example4', prefix='example4_prefix')
|
||||
assert example_backend.get(self.img_path) == b'bytes6'
|
||||
assert example_backend.get_text(self.text_path) == 'text6'
|
||||
|
||||
# prefixes is a list of str
|
||||
class Example7Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return b'bytes7'
|
||||
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text7'
|
||||
|
||||
FileClient.register_backend(
|
||||
'example5',
|
||||
Example7Backend,
|
||||
force=True,
|
||||
prefixes=['example5_prefix1', 'example5_prefix2'])
|
||||
example_backend = FileClient('example5')
|
||||
assert example_backend.get(self.img_path) == b'bytes7'
|
||||
assert example_backend.get_text(self.text_path) == 'text7'
|
||||
example_backend = FileClient(prefix='example5_prefix1')
|
||||
assert example_backend.get(self.img_path) == b'bytes7'
|
||||
assert example_backend.get_text(self.text_path) == 'text7'
|
||||
example_backend = FileClient(prefix='example5_prefix2')
|
||||
assert example_backend.get(self.img_path) == b'bytes7'
|
||||
assert example_backend.get_text(self.text_path) == 'text7'
|
||||
|
||||
# backend has a higher priority than prefixes
|
||||
class Example8Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return b'bytes8'
|
||||
|
||||
def get_text(self, filepath, encoding='utf-8'):
|
||||
return 'text8'
|
||||
|
||||
FileClient.register_backend(
|
||||
'example6',
|
||||
Example8Backend,
|
||||
force=True,
|
||||
prefixes='example6_prefix')
|
||||
example_backend = FileClient('example6')
|
||||
assert example_backend.get(self.img_path) == b'bytes8'
|
||||
assert example_backend.get_text(self.text_path) == 'text8'
|
||||
example_backend = FileClient('example6', prefix='example4_prefix')
|
||||
assert example_backend.get(self.img_path) == b'bytes8'
|
||||
assert example_backend.get_text(self.text_path) == 'text8'
|
||||
|
|
|
@ -1,11 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import mmcv
|
||||
from mmcv.fileio.file_client import HTTPBackend, PetrelBackend
|
||||
|
||||
sys.modules['petrel_client'] = MagicMock()
|
||||
sys.modules['petrel_client.client'] = MagicMock()
|
||||
|
||||
|
||||
def _test_handler(file_format, test_obj, str_checker, mode='r+'):
|
||||
|
@ -13,7 +19,7 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'):
|
|||
dump_str = mmcv.dump(test_obj, file_format=file_format)
|
||||
str_checker(dump_str)
|
||||
|
||||
# load/dump with filenames
|
||||
# load/dump with filenames from disk
|
||||
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_dump')
|
||||
mmcv.dump(test_obj, tmp_filename, file_format=file_format)
|
||||
assert osp.isfile(tmp_filename)
|
||||
|
@ -21,6 +27,13 @@ def _test_handler(file_format, test_obj, str_checker, mode='r+'):
|
|||
assert load_obj == test_obj
|
||||
os.remove(tmp_filename)
|
||||
|
||||
# load/dump with filename from petrel
|
||||
method = 'put' if 'b' in mode else 'put_text'
|
||||
with patch.object(PetrelBackend, method, return_value=None) as mock_method:
|
||||
filename = 's3://path/of/your/file'
|
||||
mmcv.dump(test_obj, filename, file_format=file_format)
|
||||
mock_method.assert_called()
|
||||
|
||||
# json load/dump with a file-like object
|
||||
with tempfile.NamedTemporaryFile(mode, delete=False) as f:
|
||||
tmp_filename = f.name
|
||||
|
@ -122,6 +135,7 @@ def test_register_handler():
|
|||
|
||||
|
||||
def test_list_from_file():
|
||||
# get list from disk
|
||||
filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
|
||||
filelist = mmcv.list_from_file(filename)
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg']
|
||||
|
@ -134,10 +148,64 @@ def test_list_from_file():
|
|||
filelist = mmcv.list_from_file(filename, offset=3, max_num=3)
|
||||
assert filelist == ['4.jpg', '5.jpg']
|
||||
|
||||
# get list from http
|
||||
with patch.object(
|
||||
HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
|
||||
filename = 'http://path/of/your/file'
|
||||
filelist = mmcv.list_from_file(
|
||||
filename, file_client_args={'backend': 'http'})
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
|
||||
filelist = mmcv.list_from_file(
|
||||
filename, file_client_args={'prefix': 'http'})
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
|
||||
filelist = mmcv.list_from_file(filename)
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
|
||||
|
||||
# get list from petrel
|
||||
with patch.object(
|
||||
PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
|
||||
filename = 's3://path/of/your/file'
|
||||
filelist = mmcv.list_from_file(
|
||||
filename, file_client_args={'backend': 'petrel'})
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
|
||||
filelist = mmcv.list_from_file(
|
||||
filename, file_client_args={'prefix': 's3'})
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
|
||||
filelist = mmcv.list_from_file(filename)
|
||||
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
|
||||
|
||||
|
||||
def test_dict_from_file():
|
||||
# get dict from disk
|
||||
filename = osp.join(osp.dirname(__file__), 'data/mapping.txt')
|
||||
mapping = mmcv.dict_from_file(filename)
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
mapping = mmcv.dict_from_file(filename, key_type=int)
|
||||
assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
|
||||
|
||||
# get dict from http
|
||||
with patch.object(
|
||||
HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'):
|
||||
filename = 'http://path/of/your/file'
|
||||
mapping = mmcv.dict_from_file(
|
||||
filename, file_client_args={'backend': 'http'})
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
mapping = mmcv.dict_from_file(
|
||||
filename, file_client_args={'prefix': 'http'})
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
mapping = mmcv.dict_from_file(filename)
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
|
||||
# get dict from petrel
|
||||
with patch.object(
|
||||
PetrelBackend, 'get_text',
|
||||
return_value='1 cat\n2 dog cow\n3 panda'):
|
||||
filename = 's3://path/of/your/file'
|
||||
mapping = mmcv.dict_from_file(
|
||||
filename, file_client_args={'backend': 'petrel'})
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
mapping = mmcv.dict_from_file(
|
||||
filename, file_client_args={'prefix': 's3'})
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
mapping = mmcv.dict_from_file(filename)
|
||||
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
|
||||
|
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
|
||||
import mmcv
|
||||
from mmcv import deprecated_api_warning
|
||||
from mmcv.utils.misc import has_method
|
||||
|
||||
|
||||
def test_to_ntuple():
|
||||
|
@ -193,6 +194,21 @@ def test_is_method_overridden():
|
|||
mmcv.is_method_overridden('foo1', base_instance, sub_instance)
|
||||
|
||||
|
||||
def test_has_method():
|
||||
|
||||
class Foo:
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def print_name(self):
|
||||
print(self.name)
|
||||
|
||||
foo = Foo('foo')
|
||||
assert not has_method(foo, 'name')
|
||||
assert has_method(foo, 'print_name')
|
||||
|
||||
|
||||
def test_deprecated_api_warning():
|
||||
|
||||
@deprecated_api_warning(name_dict=dict(old_key='new_key'))
|
||||
|
|