mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Support loading different lmdb datasets in LoadImageFromLMDB (#1293)
* [Enhancement] Support loading different lmdb datasets in LoadImageFromLMDB * add docstr
pull/1308/head
parent
d27b2fd84f
commit
e760dcd1dd
|
@ -44,8 +44,8 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
|
|||
ignore_empty (bool): Whether to allow loading empty image or file path
|
||||
not existent. Defaults to False.
|
||||
min_size (int): The minimum size of the image to be loaded. If the
|
||||
image is smaller than the minimum size, it will be ignored.
|
||||
Defaults to 0.
|
||||
image is smaller than the minimum size, it will be regarded as a
|
||||
broken image. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -89,18 +89,15 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
|
|||
return None
|
||||
else:
|
||||
raise e
|
||||
if self.ignore_empty and img is None:
|
||||
warnings.warn(f'Ignore broken image: {filename}')
|
||||
return None
|
||||
elif img is None:
|
||||
if img is None or min(img.shape[:2]) < self.min_size:
|
||||
if self.ignore_empty:
|
||||
warnings.warn(f'Ignore broken image: {filename}')
|
||||
return None
|
||||
raise IOError(f'{filename} is broken')
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
if min(img.shape[:2]) < self.min_size:
|
||||
return None
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
|
@ -473,25 +470,42 @@ class LoadImageFromLMDB(BaseTransform):
|
|||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :func:``mmcv.imfrombytes`` for details.
|
||||
Defaults to 'cv2'.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmengine.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='lmdb', db_path='')``.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient except
|
||||
for ``backend`` and ``db_path``. See
|
||||
:class:`mmengine.fileio.FileClient` for details.
|
||||
Defaults to ``dict()``.
|
||||
ignore_empty (bool): Whether to allow loading empty image or file path
|
||||
not existent. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
to_float32: bool = False,
|
||||
color_type: str = 'color',
|
||||
imdecode_backend: str = 'cv2',
|
||||
file_client_args: dict = dict(backend='lmdb', db_path=''),
|
||||
ignore_empty: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
to_float32: bool = False,
|
||||
color_type: str = 'color',
|
||||
imdecode_backend: str = 'cv2',
|
||||
file_client_args: dict = dict(),
|
||||
ignore_empty: bool = False,
|
||||
) -> None:
|
||||
self.ignore_empty = ignore_empty
|
||||
self.to_float32 = to_float32
|
||||
self.color_type = color_type
|
||||
self.imdecode_backend = imdecode_backend
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = mmengine.FileClient(**self.file_client_args)
|
||||
self.file_clients = {}
|
||||
if 'backend' in file_client_args or 'db_path' in file_client_args:
|
||||
raise ValueError(
|
||||
'"file_client_args" should not contain "backend" and "db_path"'
|
||||
)
|
||||
self.file_client_args = file_client_args
|
||||
|
||||
def _get_client(self, db_path: str) -> mmengine.FileClient:
|
||||
"""Get a FileClient bound to the given db_path.
|
||||
|
||||
If the client for this db_path is not initialized, initialize it.
|
||||
"""
|
||||
if self.file_clients.get(db_path) is None:
|
||||
self.file_clients[db_path] = mmengine.FileClient(
|
||||
backend='lmdb', db_path=db_path, **self.file_client_args)
|
||||
return self.file_clients.get(db_path)
|
||||
|
||||
def transform(self, results: dict) -> Optional[dict]:
|
||||
"""Functions to load image from LMDB file.
|
||||
|
@ -505,14 +519,21 @@ class LoadImageFromLMDB(BaseTransform):
|
|||
filename = results['img_path']
|
||||
lmdb_path = os.path.dirname(filename)
|
||||
image_key = os.path.basename(filename)
|
||||
self.file_client.client.db_path = lmdb_path
|
||||
img_bytes = self.file_client.get(image_key)
|
||||
file_client = self._get_client(lmdb_path)
|
||||
img_bytes = file_client.get(image_key)
|
||||
|
||||
if img_bytes is None:
|
||||
return None
|
||||
try:
|
||||
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
|
||||
except OSError:
|
||||
return None
|
||||
if self.ignore_empty:
|
||||
return None
|
||||
raise KeyError(f'Image not found in lmdb: {filename}')
|
||||
|
||||
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
|
||||
|
||||
if img is None:
|
||||
if self.ignore_empty:
|
||||
return None
|
||||
raise IOError(f'{filename} is broken')
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -36,9 +36,11 @@ class TestLoadImageFromFile(TestCase):
|
|||
self.assertEquals(results['img'].dtype, np.float32)
|
||||
|
||||
# min_size
|
||||
transform = LoadImageFromFile(min_size=26, ignore_empty=True)
|
||||
self.assertIsNone(transform(copy.deepcopy(results)))
|
||||
transform = LoadImageFromFile(min_size=26)
|
||||
results = transform(copy.deepcopy(results))
|
||||
self.assertIsNone(results)
|
||||
with self.assertRaises(IOError):
|
||||
transform(copy.deepcopy(results))
|
||||
|
||||
# test load empty
|
||||
fake_img_path = osp.join(data_prefix, 'fake.jpg')
|
||||
|
@ -197,12 +199,19 @@ class TestLoadImageFromLMDB(TestCase):
|
|||
self.results1 = {
|
||||
'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}'
|
||||
}
|
||||
self.broken_results = {
|
||||
'img_path': f'tests/data/rec_toy_dataset/broken.lmdb/{img_key}'
|
||||
}
|
||||
|
||||
img_key = 'image-%09d' % 100
|
||||
self.results2 = {
|
||||
'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}'
|
||||
}
|
||||
|
||||
def test_init(self):
|
||||
with self.assertRaises(ValueError):
|
||||
LoadImageFromLMDB(file_client_args=dict(backend='disk'))
|
||||
|
||||
def test_transform(self):
|
||||
transform = LoadImageFromLMDB()
|
||||
results = transform(copy.deepcopy(self.results1))
|
||||
|
@ -212,9 +221,18 @@ class TestLoadImageFromLMDB(TestCase):
|
|||
self.assertEqual(results['ori_shape'], results['img_shape'])
|
||||
|
||||
def test_invalid_key(self):
|
||||
# This test also tests its capability of implicitly switching between
|
||||
# different backends (due to different lmdb path)
|
||||
transform = LoadImageFromLMDB()
|
||||
with self.assertRaises(KeyError):
|
||||
results = transform(copy.deepcopy(self.results2))
|
||||
with self.assertRaises(IOError):
|
||||
transform(copy.deepcopy(self.broken_results))
|
||||
transform = LoadImageFromLMDB(ignore_empty=True)
|
||||
results = transform(copy.deepcopy(self.results2))
|
||||
self.assertEqual(results, None)
|
||||
self.assertIsNone(results)
|
||||
results = transform(copy.deepcopy(self.broken_results))
|
||||
self.assertIsNone(results)
|
||||
|
||||
def test_to_float32(self):
|
||||
transform = LoadImageFromLMDB(to_float32=True)
|
||||
|
@ -227,8 +245,7 @@ class TestLoadImageFromLMDB(TestCase):
|
|||
|
||||
def test_repr(self):
|
||||
transform = LoadImageFromLMDB()
|
||||
assert repr(transform) == (
|
||||
'LoadImageFromLMDB(ignore_empty=False, '
|
||||
"to_float32=False, color_type='color', "
|
||||
"imdecode_backend='cv2', "
|
||||
"file_client_args={'backend': 'lmdb', 'db_path': ''})")
|
||||
assert repr(transform) == ('LoadImageFromLMDB(ignore_empty=False, '
|
||||
"to_float32=False, color_type='color', "
|
||||
"imdecode_backend='cv2', "
|
||||
'file_client_args={})')
|
||||
|
|
Loading…
Reference in New Issue