diff --git a/mmcv/fileio/file_client.py b/mmcv/fileio/file_client.py index bce01c850..479146144 100644 --- a/mmcv/fileio/file_client.py +++ b/mmcv/fileio/file_client.py @@ -222,16 +222,73 @@ class FileClient(object): self.client = self._backends[backend](**kwargs) @classmethod - def register_backend(cls, name, backend): + def _register_backend(cls, name, backend, force=False): + if not isinstance(name, str): + raise TypeError('the backend name should be a string, ' + f'but got {type(name)}') if not inspect.isclass(backend): raise TypeError( f'backend should be a class but got {type(backend)}') if not issubclass(backend, BaseStorageBackend): raise TypeError( f'backend {backend} is not a subclass of BaseStorageBackend') + if not force and name in cls._backends: + raise KeyError( + f'{name} is already registered as a storage backend, ' + 'add "force=True" if you want to override it') cls._backends[name] = backend + @classmethod + def register_backend(cls, name, backend=None, force=False): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + """ + if backend is not None: + cls._register_backend(name, backend, force=force) + return + + def _register(backend_cls): + cls._register_backend(name, backend_cls, force=force) + return backend_cls + + return _register + def get(self, filepath): return self.client.get(filepath) diff --git a/tests/test_fileclient.py b/tests/test_fileclient.py index 922bb535b..bdd07b7ff 100644 --- a/tests/test_fileclient.py +++ b/tests/test_fileclient.py @@ -40,6 +40,10 @@ class TestFileClient(object): cls.img_shape = (300, 400, 3) cls.text_path = cls.test_data_dir / 'filelist.txt' + def test_error(self): + with pytest.raises(ValueError): + FileClient('hadoop') + def test_disk_backend(self): disk_backend = FileClient('disk') @@ -179,6 +183,20 @@ class TestFileClient(object): assert img.shape == (120, 125, 3) def test_register_backend(self): + + # name must be a string + with pytest.raises(TypeError): + + class TestClass1(object): + pass + + FileClient.register_backend(1, TestClass1) + + # module must be a class + with pytest.raises(TypeError): + FileClient.register_backend('int', 0) + + # module must be a subclass of BaseStorageBackend with pytest.raises(TypeError): class TestClass1(object): @@ -186,9 +204,6 @@ class TestFileClient(object): FileClient.register_backend('TestClass1', TestClass1) - with pytest.raises(TypeError): - FileClient.register_backend('int', 0) - class ExampleBackend(BaseStorageBackend): def get(self, filepath): @@ -203,6 +218,58 @@ class TestFileClient(object): assert example_backend.get_text(self.text_path) == self.text_path assert 'example' in FileClient._backends - def test_error(self): - with pytest.raises(ValueError): - FileClient('hadoop') + class Example2Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes2' + + def get_text(self, filepath): + return 'text2' + + # force=False + with pytest.raises(KeyError): + FileClient.register_backend('example', Example2Backend) + + FileClient.register_backend('example', Example2Backend, force=True) + example_backend = FileClient('example') + assert example_backend.get(self.img_path) == 'bytes2' + assert example_backend.get_text(self.text_path) == 'text2' + + @FileClient.register_backend(name='example3') + class Example3Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes3' + + def get_text(self, filepath): + return 'text3' + + example_backend = FileClient('example3') + assert example_backend.get(self.img_path) == 'bytes3' + assert example_backend.get_text(self.text_path) == 'text3' + assert 'example3' in FileClient._backends + + # force=False + with pytest.raises(KeyError): + + @FileClient.register_backend(name='example3') + class Example4Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes4' + + def get_text(self, filepath): + return 'text4' + + @FileClient.register_backend(name='example3', force=True) + class Example5Backend(BaseStorageBackend): + + def get(self, filepath): + return 'bytes5' + + def get_text(self, filepath): + return 'text5' + + example_backend = FileClient('example3') + assert example_backend.get(self.img_path) == 'bytes5' + assert example_backend.get_text(self.text_path) == 'text5'