mmcv/tests/test_fileclient.py

276 lines
9.2 KiB
Python
Raw Normal View History

import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import mmcv
from mmcv import BaseStorageBackend, FileClient
# Register stand-in modules under the names of the optional storage client
# packages, so importing them yields a mock (presumably because they are not
# installed in the test environment -- confirm). Each name gets its own
# independent MagicMock instance.
sys.modules.update({
    stub_name: MagicMock()
    for stub_name in ('ceph', 'petrel_client', 'petrel_client.client', 'mc')
})
class MockS3Client:
    """Local-file stand-in patched over ``ceph.S3Client`` and
    ``petrel_client.client.Client`` in the tests of this module.

    Args:
        enable_mc (bool): Stored on the instance as ``enable_mc``; it is not
            used by ``Get``.
    """

    def __init__(self, enable_mc=True):
        self.enable_mc = enable_mc

    def Get(self, filepath):
        """Return the raw bytes of the local file at ``filepath``."""
        with open(filepath, 'rb') as stream:
            return stream.read()
class MockMemcachedClient:
    """Local-file stand-in patched over ``mc.MemcachedClient.GetInstance``.

    Instead of returning data, ``Get`` stores the file contents on the
    caller-supplied buffer object as its ``content`` attribute.
    """

    def __init__(self, server_list_cfg, client_cfg):
        """Accept both configuration arguments and ignore them."""

    def Get(self, filepath, buffer):
        """Read the local file at ``filepath`` into ``buffer.content``."""
        with open(filepath, 'rb') as stream:
            payload = stream.read()
        buffer.content = payload
class TestFileClient:
    """Tests for ``FileClient``: each storage backend and the registry.

    The ceph, petrel and memcached tests patch the corresponding client
    classes with ``MockS3Client`` / ``MockMemcachedClient``, which read from
    the local filesystem, so every backend is checked against the same
    fixture files under ``data/``.
    """

    @classmethod
    def setup_class(cls):
        """Resolve the fixture paths shared by all tests in this class."""
        cls.test_data_dir = Path(__file__).parent / 'data'
        cls.img_path = cls.test_data_dir / 'color.jpg'
        cls.img_shape = (300, 400, 3)
        cls.text_path = cls.test_data_dir / 'filelist.txt'

    def test_error(self):
        """An unknown backend name is rejected with ``ValueError``."""
        with pytest.raises(ValueError):
            FileClient('hadoop')

    def test_disk_backend(self):
        """The disk backend returns file contents for Path and str input."""
        disk_backend = FileClient('disk')

        # ``read_bytes``/``read_text`` open and close the file themselves; a
        # bare ``path.open(...).read()`` would leave the handle open until
        # it is garbage collected.
        expected_bytes = self.img_path.read_bytes()
        expected_text = self.text_path.read_text()

        # input path is Path object
        img_bytes = disk_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert expected_bytes == img_bytes
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = disk_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert expected_bytes == img_bytes
        assert img.shape == self.img_shape

        # input path is Path object
        value_buf = disk_backend.get_text(self.text_path)
        assert expected_text == value_buf
        # input path is str
        value_buf = disk_backend.get_text(str(self.text_path))
        assert expected_text == value_buf

    @patch('ceph.S3Client', MockS3Client)
    def test_ceph_backend(self):
        """The ceph backend serves bytes only and applies ``path_mapping``."""
        ceph_backend = FileClient('ceph')

        # input path is Path object
        with pytest.raises(NotImplementedError):
            ceph_backend.get_text(self.text_path)
        # input path is str
        with pytest.raises(NotImplementedError):
            ceph_backend.get_text(str(self.text_path))

        # input path is Path object
        img_bytes = ceph_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = ceph_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape

        # `path_mapping` is either None or dict
        with pytest.raises(AssertionError):
            FileClient('ceph', path_mapping=1)
        # test `path_mapping`
        ceph_path = 's3://user/data'
        ceph_backend = FileClient(
            'ceph', path_mapping={str(self.test_data_dir): ceph_path})
        # Read the real image once through the mock client, then replace
        # ``Get`` with a recording mock that returns those bytes, so the
        # argument the backend passes can be inspected afterwards.
        ceph_backend.client._client.Get = MagicMock(
            return_value=ceph_backend.client._client.Get(self.img_path))
        img_bytes = ceph_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        ceph_backend.client._client.Get.assert_called_with(
            str(self.img_path).replace(str(self.test_data_dir), ceph_path))

    @patch('petrel_client.client.Client', MockS3Client)
    def test_petrel_backend(self):
        """The petrel backend serves bytes only and applies
        ``path_mapping``."""
        petrel_backend = FileClient('petrel')

        # input path is Path object
        with pytest.raises(NotImplementedError):
            petrel_backend.get_text(self.text_path)
        # input path is str
        with pytest.raises(NotImplementedError):
            petrel_backend.get_text(str(self.text_path))

        # input path is Path object
        img_bytes = petrel_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = petrel_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape

        # `path_mapping` is either None or dict
        with pytest.raises(AssertionError):
            FileClient('petrel', path_mapping=1)
        # test `path_mapping`
        petrel_path = 's3://user/data'
        petrel_backend = FileClient(
            'petrel', path_mapping={str(self.test_data_dir): petrel_path})
        # Same recording-mock technique as in the ceph test: capture the
        # path that reaches the underlying client after mapping.
        petrel_backend.client._client.Get = MagicMock(
            return_value=petrel_backend.client._client.Get(self.img_path))
        img_bytes = petrel_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        petrel_backend.client._client.Get.assert_called_with(
            str(self.img_path).replace(str(self.test_data_dir), petrel_path))

    @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
    @patch('mc.pyvector', MagicMock)
    @patch('mc.ConvertBuffer', lambda x: x.content)
    def test_memcached_backend(self):
        """The memcached backend serves bytes only, for Path and str input."""
        mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None)
        mc_backend = FileClient('memcached', **mc_cfg)

        # input path is Path object
        with pytest.raises(NotImplementedError):
            mc_backend.get_text(self.text_path)
        # input path is str
        with pytest.raises(NotImplementedError):
            mc_backend.get_text(str(self.text_path))

        # input path is Path object
        img_bytes = mc_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = mc_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape

    def test_lmdb_backend(self):
        """The lmdb backend accepts ``db_path`` as Path or str and serves
        bytes by key."""
        lmdb_path = self.test_data_dir / 'demo.lmdb'

        # db_path is Path object
        lmdb_backend = FileClient('lmdb', db_path=lmdb_path)

        with pytest.raises(NotImplementedError):
            lmdb_backend.get_text(self.text_path)

        img_bytes = lmdb_backend.get('baboon')
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == (120, 125, 3)

        # db_path is str
        lmdb_backend = FileClient('lmdb', db_path=str(lmdb_path))

        with pytest.raises(NotImplementedError):
            lmdb_backend.get_text(str(self.text_path))

        img_bytes = lmdb_backend.get('baboon')
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == (120, 125, 3)

    def test_register_backend(self):
        """Check ``FileClient.register_backend`` as a call and as a decorator.

        The registrations made by this test are stored in
        ``FileClient._backends`` on the class, so they are removed again
        afterwards. Left in place, they would make a second run in the same
        process fail with the ``KeyError`` raised for duplicate names.
        """
        try:
            self._check_register_backend()
        finally:
            for name in ('example', 'example3'):
                FileClient._backends.pop(name, None)

    def _check_register_backend(self):
        """Run the registration assertions for ``test_register_backend``."""
        # name must be a string
        with pytest.raises(TypeError):

            class TestClass1:
                pass

            FileClient.register_backend(1, TestClass1)

        # module must be a class
        with pytest.raises(TypeError):
            FileClient.register_backend('int', 0)

        # module must be a subclass of BaseStorageBackend
        with pytest.raises(TypeError):

            class TestClass1:
                pass

            FileClient.register_backend('TestClass1', TestClass1)

        class ExampleBackend(BaseStorageBackend):

            def get(self, filepath):
                return filepath

            def get_text(self, filepath):
                return filepath

        FileClient.register_backend('example', ExampleBackend)
        example_backend = FileClient('example')
        assert example_backend.get(self.img_path) == self.img_path
        assert example_backend.get_text(self.text_path) == self.text_path
        assert 'example' in FileClient._backends

        class Example2Backend(BaseStorageBackend):

            def get(self, filepath):
                return 'bytes2'

            def get_text(self, filepath):
                return 'text2'

        # force=False
        with pytest.raises(KeyError):
            FileClient.register_backend('example', Example2Backend)

        FileClient.register_backend('example', Example2Backend, force=True)
        example_backend = FileClient('example')
        assert example_backend.get(self.img_path) == 'bytes2'
        assert example_backend.get_text(self.text_path) == 'text2'

        @FileClient.register_backend(name='example3')
        class Example3Backend(BaseStorageBackend):

            def get(self, filepath):
                return 'bytes3'

            def get_text(self, filepath):
                return 'text3'

        example_backend = FileClient('example3')
        assert example_backend.get(self.img_path) == 'bytes3'
        assert example_backend.get_text(self.text_path) == 'text3'
        assert 'example3' in FileClient._backends

        # force=False
        with pytest.raises(KeyError):

            @FileClient.register_backend(name='example3')
            class Example4Backend(BaseStorageBackend):

                def get(self, filepath):
                    return 'bytes4'

                def get_text(self, filepath):
                    return 'text4'

        @FileClient.register_backend(name='example3', force=True)
        class Example5Backend(BaseStorageBackend):

            def get(self, filepath):
                return 'bytes5'

            def get_text(self, filepath):
                return 'text5'

        example_backend = FileClient('example3')
        assert example_backend.get(self.img_path) == 'bytes5'
        assert example_backend.get_text(self.text_path) == 'text5'