mirror of https://github.com/open-mmlab/mmcv.git
add PetrelBackend (#294)
* add PetrelBackend * update docs * add path_maps for Ceph * rename to path_mapping * fixed importpull/287/head^2
parent
3bf3e8ef24
commit
0946feabe3
mmcv/fileio
tests
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
|
@ -20,18 +21,63 @@ class BaseStorageBackend(metaclass=ABCMeta):
|
|||
|
||||
|
||||
class CephBackend(BaseStorageBackend):
|
||||
"""Ceph storage backend."""
|
||||
"""Ceph storage backend.
|
||||
|
||||
def __init__(self):
|
||||
Args:
|
||||
path_mapping (dict|None): path mapping dict from local path to Petrel
|
||||
path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will
|
||||
be replaced by `dst`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, path_mapping=None):
|
||||
try:
|
||||
import ceph
|
||||
warnings.warn('Ceph is deprecate in favor of Petrel.')
|
||||
except ImportError:
|
||||
raise ImportError('Please install ceph to enable CephBackend.')
|
||||
|
||||
self._client = ceph.S3Client()
|
||||
assert isinstance(path_mapping, dict) or path_mapping is None
|
||||
self.path_mapping = path_mapping
|
||||
|
||||
def get(self, filepath):
|
||||
filepath = str(filepath)
|
||||
if self.path_mapping is not None:
|
||||
for k, v in self.path_mapping.items():
|
||||
filepath = filepath.replace(k, v)
|
||||
value = self._client.Get(filepath)
|
||||
value_buf = memoryview(value)
|
||||
return value_buf
|
||||
|
||||
def get_text(self, filepath):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PetrelBackend(BaseStorageBackend):
|
||||
"""Petrel storage backend (for internal use).
|
||||
|
||||
Args:
|
||||
path_mapping (dict|None): path mapping dict from local path to Petrel
|
||||
path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will
|
||||
be replaced by `dst`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, path_mapping=None):
|
||||
try:
|
||||
from petrel_client import client
|
||||
except ImportError:
|
||||
raise ImportError('Please install petrel_client to enable '
|
||||
'PetrelBackend.')
|
||||
|
||||
self._client = client.Client()
|
||||
assert isinstance(path_mapping, dict) or path_mapping is None
|
||||
self.path_mapping = path_mapping
|
||||
|
||||
def get(self, filepath):
|
||||
filepath = str(filepath)
|
||||
if self.path_mapping is not None:
|
||||
for k, v in self.path_mapping.items():
|
||||
filepath = filepath.replace(k, v)
|
||||
value = self._client.Get(filepath)
|
||||
value_buf = memoryview(value)
|
||||
return value_buf
|
||||
|
@ -164,6 +210,7 @@ class FileClient(object):
|
|||
'ceph': CephBackend,
|
||||
'memcached': MemcachedBackend,
|
||||
'lmdb': LmdbBackend,
|
||||
'petrel': PetrelBackend,
|
||||
}
|
||||
|
||||
def __init__(self, backend='disk', **kwargs):
|
||||
|
|
|
@ -8,6 +8,8 @@ import mmcv
|
|||
from mmcv import BaseStorageBackend, FileClient
|
||||
|
||||
sys.modules['ceph'] = MagicMock()
|
||||
sys.modules['petrel_client'] = MagicMock()
|
||||
sys.modules['petrel_client.client'] = MagicMock()
|
||||
sys.modules['mc'] = MagicMock()
|
||||
|
||||
|
||||
|
@ -61,6 +63,9 @@ class TestFileClient(object):
|
|||
|
||||
@patch('ceph.S3Client', MockS3Client)
|
||||
def test_ceph_backend(self):
|
||||
with pytest.warns(
|
||||
Warning, match='Ceph is deprecate in favor of Petrel.'):
|
||||
FileClient('ceph')
|
||||
ceph_backend = FileClient('ceph')
|
||||
|
||||
# input path is Path object
|
||||
|
@ -79,6 +84,56 @@ class TestFileClient(object):
|
|||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == self.img_shape
|
||||
|
||||
# `path_mapping` is either None or dict
|
||||
with pytest.raises(AssertionError):
|
||||
FileClient('ceph', path_mapping=1)
|
||||
# test `path_mapping`
|
||||
ceph_path = 's3://user/data'
|
||||
ceph_backend = FileClient(
|
||||
'ceph', path_mapping={str(self.test_data_dir): ceph_path})
|
||||
ceph_backend.client._client.Get = MagicMock(
|
||||
return_value=ceph_backend.client._client.Get(self.img_path))
|
||||
img_bytes = ceph_backend.get(self.img_path)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == self.img_shape
|
||||
ceph_backend.client._client.Get.assert_called_with(
|
||||
str(self.img_path).replace(str(self.test_data_dir), ceph_path))
|
||||
|
||||
@patch('petrel_client.client.Client', MockS3Client)
|
||||
def test_petrel_backend(self):
|
||||
petrel_backend = FileClient('petrel')
|
||||
|
||||
# input path is Path object
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.get_text(self.text_path)
|
||||
# input path is str
|
||||
with pytest.raises(NotImplementedError):
|
||||
petrel_backend.get_text(str(self.text_path))
|
||||
|
||||
# input path is Path object
|
||||
img_bytes = petrel_backend.get(self.img_path)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == self.img_shape
|
||||
# input path is str
|
||||
img_bytes = petrel_backend.get(str(self.img_path))
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == self.img_shape
|
||||
|
||||
# `path_mapping` is either None or dict
|
||||
with pytest.raises(AssertionError):
|
||||
FileClient('petrel', path_mapping=1)
|
||||
# test `path_mapping`
|
||||
petrel_path = 's3://user/data'
|
||||
petrel_backend = FileClient(
|
||||
'petrel', path_mapping={str(self.test_data_dir): petrel_path})
|
||||
petrel_backend.client._client.Get = MagicMock(
|
||||
return_value=petrel_backend.client._client.Get(self.img_path))
|
||||
img_bytes = petrel_backend.get(self.img_path)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
assert img.shape == self.img_shape
|
||||
petrel_backend.client._client.Get.assert_called_with(
|
||||
str(self.img_path).replace(str(self.test_data_dir), petrel_path))
|
||||
|
||||
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
|
||||
@patch('mc.pyvector', MagicMock)
|
||||
@patch('mc.ConvertBuffer', lambda x: x.content)
|
||||
|
|
Loading…
Reference in New Issue