Allow registering new backends with decorators (#307)

* allow registering new backends with decorators

* add a docstring

* minor update to the docstring
pull/309/head
Kai Chen 2020-05-31 21:56:03 +08:00 committed by GitHub
parent 4150d0ac2f
commit 213156cebb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 7 deletions

View File

@ -222,16 +222,73 @@ class FileClient(object):
self.client = self._backends[backend](**kwargs)
@classmethod
def register_backend(cls, name, backend):
def _register_backend(cls, name, backend, force=False):
if not isinstance(name, str):
raise TypeError('the backend name should be a string, '
f'but got {type(name)}')
if not inspect.isclass(backend):
raise TypeError(
f'backend should be a class but got {type(backend)}')
if not issubclass(backend, BaseStorageBackend):
raise TypeError(
f'backend {backend} is not a subclass of BaseStorageBackend')
if not force and name in cls._backends:
raise KeyError(
f'{name} is already registered as a storage backend, '
'add "force=True" if you want to override it')
cls._backends[name] = backend
@classmethod
def register_backend(cls, name, backend=None, force=False):
"""Register a backend to FileClient.
This method can be used as a normal class method or a decorator.
.. code-block:: python
class NewBackend(BaseStorageBackend):
def get(self, filepath):
return filepath
def get_text(self, filepath):
return filepath
FileClient.register_backend('new', NewBackend)
or
.. code-block:: python
@FileClient.register_backend('new')
class NewBackend(BaseStorageBackend):
def get(self, filepath):
return filepath
def get_text(self, filepath):
return filepath
Args:
name (str): The name of the registered backend.
backend (class, optional): The backend class to be registered,
which must be a subclass of :class:`BaseStorageBackend`.
When this method is used as a decorator, backend is None.
Defaults to None.
force (bool, optional): Whether to override the backend if the name
has already been registered. Defaults to False.
"""
if backend is not None:
cls._register_backend(name, backend, force=force)
return
def _register(backend_cls):
cls._register_backend(name, backend_cls, force=force)
return backend_cls
return _register
def get(self, filepath):
return self.client.get(filepath)

View File

@ -40,6 +40,10 @@ class TestFileClient(object):
cls.img_shape = (300, 400, 3)
cls.text_path = cls.test_data_dir / 'filelist.txt'
def test_error(self):
with pytest.raises(ValueError):
FileClient('hadoop')
def test_disk_backend(self):
disk_backend = FileClient('disk')
@ -179,6 +183,20 @@ class TestFileClient(object):
assert img.shape == (120, 125, 3)
def test_register_backend(self):
# name must be a string
with pytest.raises(TypeError):
class TestClass1(object):
pass
FileClient.register_backend(1, TestClass1)
# module must be a class
with pytest.raises(TypeError):
FileClient.register_backend('int', 0)
# module must be a subclass of BaseStorageBackend
with pytest.raises(TypeError):
class TestClass1(object):
@ -186,9 +204,6 @@ class TestFileClient(object):
FileClient.register_backend('TestClass1', TestClass1)
with pytest.raises(TypeError):
FileClient.register_backend('int', 0)
class ExampleBackend(BaseStorageBackend):
def get(self, filepath):
@ -203,6 +218,58 @@ class TestFileClient(object):
assert example_backend.get_text(self.text_path) == self.text_path
assert 'example' in FileClient._backends
def test_error(self):
with pytest.raises(ValueError):
FileClient('hadoop')
class Example2Backend(BaseStorageBackend):
def get(self, filepath):
return 'bytes2'
def get_text(self, filepath):
return 'text2'
# force=False
with pytest.raises(KeyError):
FileClient.register_backend('example', Example2Backend)
FileClient.register_backend('example', Example2Backend, force=True)
example_backend = FileClient('example')
assert example_backend.get(self.img_path) == 'bytes2'
assert example_backend.get_text(self.text_path) == 'text2'
@FileClient.register_backend(name='example3')
class Example3Backend(BaseStorageBackend):
def get(self, filepath):
return 'bytes3'
def get_text(self, filepath):
return 'text3'
example_backend = FileClient('example3')
assert example_backend.get(self.img_path) == 'bytes3'
assert example_backend.get_text(self.text_path) == 'text3'
assert 'example3' in FileClient._backends
# force=False
with pytest.raises(KeyError):
@FileClient.register_backend(name='example3')
class Example4Backend(BaseStorageBackend):
def get(self, filepath):
return 'bytes4'
def get_text(self, filepath):
return 'text4'
@FileClient.register_backend(name='example3', force=True)
class Example5Backend(BaseStorageBackend):
def get(self, filepath):
return 'bytes5'
def get_text(self, filepath):
return 'text5'
example_backend = FileClient('example3')
assert example_backend.get(self.img_path) == 'bytes5'
assert example_backend.get_text(self.text_path) == 'text5'