From e760dcd1dd89d6c136db2bec8c56bab00109d201 Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Mon, 22 Aug 2022 16:43:25 +0800 Subject: [PATCH] [Enhancement] Support loading different lmdb datasets in LoadImageFromLMDB (#1293) * [Enhancement] Support loading different lmdb datasets in LoadImageFromLMDB * add docstr --- mmocr/datasets/transforms/loading.py | 75 +++++++++++------- .../data/rec_toy_dataset/broken.lmdb/data.mdb | Bin 0 -> 12288 bytes .../data/rec_toy_dataset/broken.lmdb/lock.mdb | Bin 0 -> 8192 bytes .../test_transforms/test_loading.py | 33 ++++++-- 4 files changed, 73 insertions(+), 35 deletions(-) create mode 100644 tests/data/rec_toy_dataset/broken.lmdb/data.mdb create mode 100644 tests/data/rec_toy_dataset/broken.lmdb/lock.mdb diff --git a/mmocr/datasets/transforms/loading.py b/mmocr/datasets/transforms/loading.py index 05450322..034af68e 100644 --- a/mmocr/datasets/transforms/loading.py +++ b/mmocr/datasets/transforms/loading.py @@ -44,8 +44,8 @@ class LoadImageFromFile(MMCV_LoadImageFromFile): ignore_empty (bool): Whether to allow loading empty image or file path not existent. Defaults to False. min_size (int): The minimum size of the image to be loaded. If the - image is smaller than the minimum size, it will be ignored. - Defaults to 0. + image is smaller than the minimum size, it will be regarded as a + broken image. Defaults to 0. """ def __init__(self, @@ -89,18 +89,15 @@ class LoadImageFromFile(MMCV_LoadImageFromFile): return None else: raise e - if self.ignore_empty and img is None: - warnings.warn(f'Ignore broken image: {filename}') - return None - elif img is None: + if img is None or min(img.shape[:2]) < self.min_size: + if self.ignore_empty: + warnings.warn(f'Ignore broken image: {filename}') + return None raise IOError(f'{filename} is broken') if self.to_float32: img = img.astype(np.float32) - if min(img.shape[:2]) < self.min_size: - return None - results['img'] = img results['img_shape'] = img.shape[:2] results['ori_shape'] = img.shape[:2] @@ -473,25 +470,42 @@ class LoadImageFromLMDB(BaseTransform): argument for :func:``mmcv.imfrombytes``. See :func:``mmcv.imfrombytes`` for details. Defaults to 'cv2'. - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmengine.fileio.FileClient` for details. - Defaults to ``dict(backend='lmdb', db_path='')``. + file_client_args (dict): Arguments to instantiate a FileClient except + for ``backend`` and ``db_path``. See + :class:`mmengine.fileio.FileClient` for details. + Defaults to ``dict()``. ignore_empty (bool): Whether to allow loading empty image or file path not existent. Defaults to False. """ - def __init__(self, - to_float32: bool = False, - color_type: str = 'color', - imdecode_backend: str = 'cv2', - file_client_args: dict = dict(backend='lmdb', db_path=''), - ignore_empty: bool = False) -> None: + def __init__( + self, + to_float32: bool = False, + color_type: str = 'color', + imdecode_backend: str = 'cv2', + file_client_args: dict = dict(), + ignore_empty: bool = False, + ) -> None: self.ignore_empty = ignore_empty self.to_float32 = to_float32 self.color_type = color_type self.imdecode_backend = imdecode_backend - self.file_client_args = file_client_args.copy() - self.file_client = mmengine.FileClient(**self.file_client_args) + self.file_clients = {} + if 'backend' in file_client_args or 'db_path' in file_client_args: + raise ValueError( + '"file_client_args" should not contain "backend" and "db_path"' + ) + self.file_client_args = file_client_args + + def _get_client(self, db_path: str) -> mmengine.FileClient: + """Get a FileClient bound to the given db_path. + + If the client for this db_path is not initialized, initialize it. + """ + if self.file_clients.get(db_path) is None: + self.file_clients[db_path] = mmengine.FileClient( + backend='lmdb', db_path=db_path, **self.file_client_args) + return self.file_clients.get(db_path) def transform(self, results: dict) -> Optional[dict]: """Functions to load image from LMDB file. @@ -505,14 +519,21 @@ class LoadImageFromLMDB(BaseTransform): filename = results['img_path'] lmdb_path = os.path.dirname(filename) image_key = os.path.basename(filename) - self.file_client.client.db_path = lmdb_path - img_bytes = self.file_client.get(image_key) + file_client = self._get_client(lmdb_path) + img_bytes = file_client.get(image_key) + if img_bytes is None: - return None - try: - img = mmcv.imfrombytes(img_bytes, flag=self.color_type) - except OSError: - return None + if self.ignore_empty: + return None + raise KeyError(f'Image not found in lmdb: {filename}') + + img = mmcv.imfrombytes(img_bytes, flag=self.color_type) + + if img is None: + if self.ignore_empty: + return None + raise IOError(f'{filename} is broken') + if self.to_float32: img = img.astype(np.float32) diff --git a/tests/data/rec_toy_dataset/broken.lmdb/data.mdb b/tests/data/rec_toy_dataset/broken.lmdb/data.mdb new file mode 100644 index 0000000000000000000000000000000000000000..e0f32582d80daac31749135ce7225a7db9268bb3 GIT binary patch literal 12288 zcmZQzfB_CLfQkXKSO42VtOh<_WzCl4;Tcz+&tYF z7=eZ|0wKfy+YHVOJRF?doE$vdT-PKLX@4FtRd&0W1m_nV4Bv+1NQaxwwG}whAyXF)}kVu`si;vakSE z*8=4kSOi&x6b&8OgaZ@Vl?p|S8YeE~Pwh=DOELf4NWZ* zQ!{f5ODks=S2uSLPp{yR(6I1`$f)F$)U@=B%&g*)(z5c3%Btp;*0%PJ&aO$5r%atT zea6gLixw|gx@`H1m8&*w-m-Pu_8mKS9XfpE=&|D`PM*4S`O4L6*Kgds_3+W-Cr_U} zfAR9w$4{TXeEs(Q$Io9Ne=#yJL%anfAwEO%mmttzOe`$SEbJhEF*20{F|!~GtD+&B ukYgZwVxh2-Q6q=9aq$@ z&9M`;+`jv-ckM6FRQ5~6$uRZZHpZk9AV7cs0RjXF5FkK+009C72oNAZfB*pk1PBly LK!5-N0)GTfLu3`A literal 0 HcmV?d00001 diff --git a/tests/test_datasets/test_transforms/test_loading.py b/tests/test_datasets/test_transforms/test_loading.py index eda9c4e3..92780170 100644 --- a/tests/test_datasets/test_transforms/test_loading.py +++ b/tests/test_datasets/test_transforms/test_loading.py @@ -36,9 +36,11 @@ class TestLoadImageFromFile(TestCase): self.assertEquals(results['img'].dtype, np.float32) # min_size + transform = LoadImageFromFile(min_size=26, ignore_empty=True) + self.assertIsNone(transform(copy.deepcopy(results))) transform = LoadImageFromFile(min_size=26) - results = transform(copy.deepcopy(results)) - self.assertIsNone(results) + with self.assertRaises(IOError): + transform(copy.deepcopy(results)) # test load empty fake_img_path = osp.join(data_prefix, 'fake.jpg') @@ -197,12 +199,19 @@ class TestLoadImageFromLMDB(TestCase): self.results1 = { 'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}' } + self.broken_results = { + 'img_path': f'tests/data/rec_toy_dataset/broken.lmdb/{img_key}' + } img_key = 'image-%09d' % 100 self.results2 = { 'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}' } + def test_init(self): + with self.assertRaises(ValueError): + LoadImageFromLMDB(file_client_args=dict(backend='disk')) + def test_transform(self): transform = LoadImageFromLMDB() results = transform(copy.deepcopy(self.results1)) @@ -212,9 +221,18 @@ class TestLoadImageFromLMDB(TestCase): self.assertEqual(results['ori_shape'], results['img_shape']) def test_invalid_key(self): + # This test also tests its capability of implicitly switching between + # different backends (due to different lmdb path) transform = LoadImageFromLMDB() + with self.assertRaises(KeyError): + results = transform(copy.deepcopy(self.results2)) + with self.assertRaises(IOError): + transform(copy.deepcopy(self.broken_results)) + transform = LoadImageFromLMDB(ignore_empty=True) results = transform(copy.deepcopy(self.results2)) - self.assertEqual(results, None) + self.assertIsNone(results) + results = transform(copy.deepcopy(self.broken_results)) + self.assertIsNone(results) def test_to_float32(self): transform = LoadImageFromLMDB(to_float32=True) @@ -227,8 +245,7 @@ class TestLoadImageFromLMDB(TestCase): def test_repr(self): transform = LoadImageFromLMDB() - assert repr(transform) == ( - 'LoadImageFromLMDB(ignore_empty=False, ' - "to_float32=False, color_type='color', " - "imdecode_backend='cv2', " - "file_client_args={'backend': 'lmdb', 'db_path': ''})") + assert repr(transform) == ('LoadImageFromLMDB(ignore_empty=False, ' + "to_float32=False, color_type='color', " + "imdecode_backend='cv2', " + 'file_client_args={})')