[Enhancement] Support loading different lmdb datasets in LoadImageFromLMDB (#1293)

* [Enhancement] Support loading different lmdb datasets in LoadImageFromLMDB

* add docstr
pull/1308/head
Tong Gao 2022-08-22 16:43:25 +08:00 committed by GitHub
parent d27b2fd84f
commit e760dcd1dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 35 deletions

View File

@ -44,8 +44,8 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
ignore_empty (bool): Whether to allow loading empty image or file path
not existent. Defaults to False.
min_size (int): The minimum size of the image to be loaded. If the
image is smaller than the minimum size, it will be ignored.
Defaults to 0.
image is smaller than the minimum size, it will be regarded as a
broken image. Defaults to 0.
"""
def __init__(self,
@ -89,18 +89,15 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
return None
else:
raise e
if self.ignore_empty and img is None:
warnings.warn(f'Ignore broken image: {filename}')
return None
elif img is None:
if img is None or min(img.shape[:2]) < self.min_size:
if self.ignore_empty:
warnings.warn(f'Ignore broken image: {filename}')
return None
raise IOError(f'{filename} is broken')
if self.to_float32:
img = img.astype(np.float32)
if min(img.shape[:2]) < self.min_size:
return None
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
@ -473,25 +470,42 @@ class LoadImageFromLMDB(BaseTransform):
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to ``dict(backend='lmdb', db_path='')``.
file_client_args (dict): Arguments to instantiate a FileClient except
for ``backend`` and ``db_path``. See
:class:`mmengine.fileio.FileClient` for details.
Defaults to ``dict()``.
ignore_empty (bool): Whether to allow loading empty image or file path
not existent. Defaults to False.
"""
def __init__(self,
to_float32: bool = False,
color_type: str = 'color',
imdecode_backend: str = 'cv2',
file_client_args: dict = dict(backend='lmdb', db_path=''),
ignore_empty: bool = False) -> None:
def __init__(
self,
to_float32: bool = False,
color_type: str = 'color',
imdecode_backend: str = 'cv2',
file_client_args: dict = dict(),
ignore_empty: bool = False,
) -> None:
self.ignore_empty = ignore_empty
self.to_float32 = to_float32
self.color_type = color_type
self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy()
self.file_client = mmengine.FileClient(**self.file_client_args)
self.file_clients = {}
if 'backend' in file_client_args or 'db_path' in file_client_args:
raise ValueError(
'"file_client_args" should not contain "backend" and "db_path"'
)
self.file_client_args = file_client_args
def _get_client(self, db_path: str) -> mmengine.FileClient:
"""Get a FileClient bound to the given db_path.
If the client for this db_path is not initialized, initialize it.
"""
if self.file_clients.get(db_path) is None:
self.file_clients[db_path] = mmengine.FileClient(
backend='lmdb', db_path=db_path, **self.file_client_args)
return self.file_clients.get(db_path)
def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image from LMDB file.
@ -505,14 +519,21 @@ class LoadImageFromLMDB(BaseTransform):
filename = results['img_path']
lmdb_path = os.path.dirname(filename)
image_key = os.path.basename(filename)
self.file_client.client.db_path = lmdb_path
img_bytes = self.file_client.get(image_key)
file_client = self._get_client(lmdb_path)
img_bytes = file_client.get(image_key)
if img_bytes is None:
return None
try:
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
except OSError:
return None
if self.ignore_empty:
return None
raise KeyError(f'Image not found in lmdb: {filename}')
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
if img is None:
if self.ignore_empty:
return None
raise IOError(f'{filename} is broken')
if self.to_float32:
img = img.astype(np.float32)

View File

@ -36,9 +36,11 @@ class TestLoadImageFromFile(TestCase):
self.assertEquals(results['img'].dtype, np.float32)
# min_size
transform = LoadImageFromFile(min_size=26, ignore_empty=True)
self.assertIsNone(transform(copy.deepcopy(results)))
transform = LoadImageFromFile(min_size=26)
results = transform(copy.deepcopy(results))
self.assertIsNone(results)
with self.assertRaises(IOError):
transform(copy.deepcopy(results))
# test load empty
fake_img_path = osp.join(data_prefix, 'fake.jpg')
@ -197,12 +199,19 @@ class TestLoadImageFromLMDB(TestCase):
self.results1 = {
'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}'
}
self.broken_results = {
'img_path': f'tests/data/rec_toy_dataset/broken.lmdb/{img_key}'
}
img_key = 'image-%09d' % 100
self.results2 = {
'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}'
}
def test_init(self):
with self.assertRaises(ValueError):
LoadImageFromLMDB(file_client_args=dict(backend='disk'))
def test_transform(self):
transform = LoadImageFromLMDB()
results = transform(copy.deepcopy(self.results1))
@ -212,9 +221,18 @@ class TestLoadImageFromLMDB(TestCase):
self.assertEqual(results['ori_shape'], results['img_shape'])
def test_invalid_key(self):
# This test also tests its capability of implicitly switching between
# different backends (due to different lmdb path)
transform = LoadImageFromLMDB()
with self.assertRaises(KeyError):
results = transform(copy.deepcopy(self.results2))
with self.assertRaises(IOError):
transform(copy.deepcopy(self.broken_results))
transform = LoadImageFromLMDB(ignore_empty=True)
results = transform(copy.deepcopy(self.results2))
self.assertEqual(results, None)
self.assertIsNone(results)
results = transform(copy.deepcopy(self.broken_results))
self.assertIsNone(results)
def test_to_float32(self):
transform = LoadImageFromLMDB(to_float32=True)
@ -227,8 +245,7 @@ class TestLoadImageFromLMDB(TestCase):
def test_repr(self):
transform = LoadImageFromLMDB()
assert repr(transform) == (
'LoadImageFromLMDB(ignore_empty=False, '
"to_float32=False, color_type='color', "
"imdecode_backend='cv2', "
"file_client_args={'backend': 'lmdb', 'db_path': ''})")
assert repr(transform) == ('LoadImageFromLMDB(ignore_empty=False, '
"to_float32=False, color_type='color', "
"imdecode_backend='cv2', "
'file_client_args={})')