mirror of https://github.com/open-mmlab/mmcv.git
Allow registering new backends with decorators (#307)
* allow registering new backends with decorators * add a docstring * minor update to the docstringpull/309/head
parent
4150d0ac2f
commit
213156cebb
|
@ -222,16 +222,73 @@ class FileClient(object):
|
|||
self.client = self._backends[backend](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def register_backend(cls, name, backend):
|
||||
def _register_backend(cls, name, backend, force=False):
|
||||
if not isinstance(name, str):
|
||||
raise TypeError('the backend name should be a string, '
|
||||
f'but got {type(name)}')
|
||||
if not inspect.isclass(backend):
|
||||
raise TypeError(
|
||||
f'backend should be a class but got {type(backend)}')
|
||||
if not issubclass(backend, BaseStorageBackend):
|
||||
raise TypeError(
|
||||
f'backend {backend} is not a subclass of BaseStorageBackend')
|
||||
if not force and name in cls._backends:
|
||||
raise KeyError(
|
||||
f'{name} is already registered as a storage backend, '
|
||||
'add "force=True" if you want to override it')
|
||||
|
||||
cls._backends[name] = backend
|
||||
|
||||
@classmethod
|
||||
def register_backend(cls, name, backend=None, force=False):
|
||||
"""Register a backend to FileClient.
|
||||
|
||||
This method can be used as a normal class method or a decorator.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class NewBackend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return filepath
|
||||
|
||||
def get_text(self, filepath):
|
||||
return filepath
|
||||
|
||||
FileClient.register_backend('new', NewBackend)
|
||||
|
||||
or
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@FileClient.register_backend('new')
|
||||
class NewBackend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return filepath
|
||||
|
||||
def get_text(self, filepath):
|
||||
return filepath
|
||||
|
||||
Args:
|
||||
name (str): The name of the registered backend.
|
||||
backend (class, optional): The backend class to be registered,
|
||||
which must be a subclass of :class:`BaseStorageBackend`.
|
||||
When this method is used as a decorator, backend is None.
|
||||
Defaults to None.
|
||||
force (bool, optional): Whether to override the backend if the name
|
||||
has already been registered. Defaults to False.
|
||||
"""
|
||||
if backend is not None:
|
||||
cls._register_backend(name, backend, force=force)
|
||||
return
|
||||
|
||||
def _register(backend_cls):
|
||||
cls._register_backend(name, backend_cls, force=force)
|
||||
return backend_cls
|
||||
|
||||
return _register
|
||||
|
||||
def get(self, filepath):
|
||||
return self.client.get(filepath)
|
||||
|
||||
|
|
|
@ -40,6 +40,10 @@ class TestFileClient(object):
|
|||
cls.img_shape = (300, 400, 3)
|
||||
cls.text_path = cls.test_data_dir / 'filelist.txt'
|
||||
|
||||
def test_error(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileClient('hadoop')
|
||||
|
||||
def test_disk_backend(self):
|
||||
disk_backend = FileClient('disk')
|
||||
|
||||
|
@ -179,6 +183,20 @@ class TestFileClient(object):
|
|||
assert img.shape == (120, 125, 3)
|
||||
|
||||
def test_register_backend(self):
|
||||
|
||||
# name must be a string
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class TestClass1(object):
|
||||
pass
|
||||
|
||||
FileClient.register_backend(1, TestClass1)
|
||||
|
||||
# module must be a class
|
||||
with pytest.raises(TypeError):
|
||||
FileClient.register_backend('int', 0)
|
||||
|
||||
# module must be a subclass of BaseStorageBackend
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class TestClass1(object):
|
||||
|
@ -186,9 +204,6 @@ class TestFileClient(object):
|
|||
|
||||
FileClient.register_backend('TestClass1', TestClass1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
FileClient.register_backend('int', 0)
|
||||
|
||||
class ExampleBackend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
|
@ -203,6 +218,58 @@ class TestFileClient(object):
|
|||
assert example_backend.get_text(self.text_path) == self.text_path
|
||||
assert 'example' in FileClient._backends
|
||||
|
||||
def test_error(self):
|
||||
with pytest.raises(ValueError):
|
||||
FileClient('hadoop')
|
||||
class Example2Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes2'
|
||||
|
||||
def get_text(self, filepath):
|
||||
return 'text2'
|
||||
|
||||
# force=False
|
||||
with pytest.raises(KeyError):
|
||||
FileClient.register_backend('example', Example2Backend)
|
||||
|
||||
FileClient.register_backend('example', Example2Backend, force=True)
|
||||
example_backend = FileClient('example')
|
||||
assert example_backend.get(self.img_path) == 'bytes2'
|
||||
assert example_backend.get_text(self.text_path) == 'text2'
|
||||
|
||||
@FileClient.register_backend(name='example3')
|
||||
class Example3Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes3'
|
||||
|
||||
def get_text(self, filepath):
|
||||
return 'text3'
|
||||
|
||||
example_backend = FileClient('example3')
|
||||
assert example_backend.get(self.img_path) == 'bytes3'
|
||||
assert example_backend.get_text(self.text_path) == 'text3'
|
||||
assert 'example3' in FileClient._backends
|
||||
|
||||
# force=False
|
||||
with pytest.raises(KeyError):
|
||||
|
||||
@FileClient.register_backend(name='example3')
|
||||
class Example4Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes4'
|
||||
|
||||
def get_text(self, filepath):
|
||||
return 'text4'
|
||||
|
||||
@FileClient.register_backend(name='example3', force=True)
|
||||
class Example5Backend(BaseStorageBackend):
|
||||
|
||||
def get(self, filepath):
|
||||
return 'bytes5'
|
||||
|
||||
def get_text(self, filepath):
|
||||
return 'text5'
|
||||
|
||||
example_backend = FileClient('example3')
|
||||
assert example_backend.get(self.img_path) == 'bytes5'
|
||||
assert example_backend.get_text(self.text_path) == 'text5'
|
||||
|
|
Loading…
Reference in New Issue