mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Discard deprecated lmdb dataset format and only support img+label now (#1681)
* [Enhance] Discard deprecated lmdb dataset format and only support img+label now
* rename
* update
* add ut
* update document
* update docs
* update test
* update test
* Update dataset.md

Co-authored-by: liukuikun <641417025@qq.com>

pull/1691/head
parent b64565c10f
commit 6992923768
@@ -230,38 +230,26 @@ Specifically, we provide three dataset classes [IcdarDataset](mmocr.datasets.Icd
     parser_cfg=dict(
         type='LineJsonParser',
         keys=['filename', 'text'],
-    pipeline=[])
+    pipeline=[]))
    ```

-3. [RecogLMDBDataset](mmocr.datasets.RecogLMDBDataset) supports LMDB format annotations for text recognition. You just need to add a new dataset config to `configs/textrecog/_base_/datasets` and specify its dataset type as `RecogLMDBDataset`. For example, the following example shows how to configure and load the **label-only lmdb** `label.lmdb` from the toy dataset.
-
-   ```python
-   data_root = 'tests/data/rec_toy_dataset/'
-
-   lmdb_dataset = dict(
-       type='RecogLMDBDataset',
-       data_root=data_root,
-       ann_file='label.lmdb',
-       data_prefix=dict(img_path='imgs'),
-       pipeline=[])
-   ```
-
-   When the `lmdb` file contains **both labels and images**, in addition to setting the dataset type to `RecogLMDBDataset` as in the above example, you also need to replace [`LoadImageFromFile`](mmocr.datasets.transforms.LoadImageFromFile) with [`LoadImageFromLMDB`](mmocr.datasets.transforms.LoadImageFromLMDB) in the data pipelines.
-
-   ```python
-   # Specify the dataset type as RecogLMDBDataset
-   data_root = 'tests/data/rec_toy_dataset/'
-
-   lmdb_dataset = dict(
-       type='RecogLMDBDataset',
-       data_root=data_root,
-       ann_file='imgs.lmdb',
-       data_prefix=dict(img_path='imgs.lmdb'),  # set img_path to the lmdb name
-       pipeline=[])
-   ```
-
-   Also, replace the image loading transforms in `train_pipeline` and `test_pipeline`, for example:
-
-   ```python
-   train_pipeline = [dict(type='LoadImageFromLMDB', color_type='grayscale', ignore_empty=True)]
-   ```
+3. [RecogLMDBDataset](mmocr.datasets.RecogLMDBDataset) supports LMDB-format datasets (images + labels) for text recognition. You just need to add a new dataset config to `configs/textrecog/_base_/datasets` and specify its dataset type as `RecogLMDBDataset`. For example, the following example shows how to configure and load `imgs.lmdb`, which contains **both labels and images**, from the toy dataset.
+
+   - set the dataset type to `RecogLMDBDataset`
+
+   ```python
+   # Specify the dataset type as RecogLMDBDataset
+   data_root = 'tests/data/rec_toy_dataset/'
+
+   lmdb_dataset = dict(
+       type='RecogLMDBDataset',
+       data_root=data_root,
+       ann_file='imgs.lmdb',
+       pipeline=None)
+   ```
+
+   - replace [`LoadImageFromFile`](mmocr.datasets.transforms.LoadImageFromFile) with [`LoadImageFromNDArray`](mmocr.datasets.transforms.LoadImageFromNDArray) in the data pipelines of `train_pipeline` and `test_pipeline`, for example:
+
+   ```python
+   train_pipeline = [dict(type='LoadImageFromNDArray')]
+   ```
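Putting the two new steps together, a complete dataset entry is just the `RecogLMDBDataset` config plus a pipeline that begins with `LoadImageFromNDArray`. The sketch below is illustrative only: it reuses the toy-dataset paths from this hunk, while the variable names and the placeholder comment for the remaining transforms are assumptions, not part of the change itself.

```python
# Minimal sketch (assumed variable names) combining the migrated dataset and
# pipeline head for the toy data shipped with the repository.
data_root = 'tests/data/rec_toy_dataset/'

train_pipeline = [
    # RecogLMDBDataset already decodes the lmdb image bytes into results['img'],
    # so the pipeline starts from an ndarray instead of a file path.
    dict(type='LoadImageFromNDArray'),
    # ... model-specific transforms (resize, augmentation, packing) go here ...
]

toy_lmdb_train = dict(
    type='RecogLMDBDataset',
    data_root=data_root,
    ann_file='imgs.lmdb',  # lmdb file holding both images and labels
    pipeline=train_pipeline)
```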
@@ -232,20 +232,7 @@ python tools/dataset_converters/textrecog/data_migrator.py ${IN_PATH} ${OUT_PATH
     pipeline=[])
    ```

-3. [RecogLMDBDataset](mmocr.datasets.RecogLMDBDataset) supports the `LMDB` annotation format used by 0.x text recognition tasks. You only need to add a new dataset config file to `configs/textrecog/_base_/datasets` and set its dataset type to `RecogLMDBDataset`. For example, the following example shows how to configure and load `label.lmdb` from the toy dataset, an `lmdb` file that contains **labels only**.
-
-   ```python
-   data_root = 'tests/data/rec_toy_dataset/'
-
-   lmdb_dataset = dict(
-       type='RecogLMDBDataset',
-       data_root=data_root,
-       ann_file='label.lmdb',
-       data_prefix=dict(img_path='imgs'),
-       pipeline=[])
-   ```
-
-   When the `lmdb` file contains both labels and images, besides setting the dataset type to `RecogLMDBDataset`, you also need to replace the image loading method in the data pipeline from [`LoadImageFromFile`](mmocr.datasets.transforms.LoadImageFromFile) to [`LoadImageFromLMDB`](mmocr.datasets.transforms.LoadImageFromLMDB).
+3. [RecogLMDBDataset](mmocr.datasets.RecogLMDBDataset) supports the **image + text** `LMDB` annotation format used by 0.x text recognition tasks. You only need to add a new dataset config file to `configs/textrecog/_base_/datasets` and set its dataset type to `RecogLMDBDataset`. For example, the following example shows how to configure and load `imgs.lmdb` from the toy dataset, an `lmdb` file that contains **both labels and images**.

    ```python
    # Set the dataset type to RecogLMDBDataset
@@ -255,12 +242,11 @@ python tools/dataset_converters/textrecog/data_migrator.py ${IN_PATH} ${OUT_PATH
        type='RecogLMDBDataset',
        data_root=data_root,
        ann_file='imgs.lmdb',
-       data_prefix=dict(img_path='imgs.lmdb'),  # set img_path to the lmdb file name
-       pipeline=[])
+       pipeline=None)
    ```

-   You also need to replace the data loading method in `train_pipeline` and `test_pipeline`:
+   You also need to replace the image loading transform in `train_pipeline` and `test_pipeline`, e.g. [`LoadImageFromFile`](mmocr.datasets.transforms.LoadImageFromFile) with [`LoadImageFromNDArray`](mmocr.datasets.transforms.LoadImageFromNDArray):

    ```python
-   train_pipeline = [dict(type='LoadImageFromLMDB', color_type='grayscale', ignore_empty=True)]
+   train_pipeline = [dict(type='LoadImageFromNDArray')]
    ```
@@ -1,37 +1,35 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import json
 import os.path as osp
-import warnings
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

+import mmcv
 from mmengine.dataset import BaseDataset
-from mmengine.utils import is_abs

-from mmocr.registry import DATASETS, TASK_UTILS
+from mmocr.registry import DATASETS


 @DATASETS.register_module()
 class RecogLMDBDataset(BaseDataset):
     r"""RecogLMDBDataset for text recognition.

-    The annotation format should be in lmdb format. We support two lmdb
-    formats, one is the lmdb file with only labels generated by txt2lmdb
-    (deprecated), and another one is the lmdb file generated by recog2lmdb.
-
-    The former format stores string in `filename text` format directly in lmdb,
-    while the latter uses `image_key` as well as `label_key` for querying.
+    The annotation format should be in lmdb format. The lmdb file should
+    contain three keys: 'num-samples', 'label-xxxxxxxxx' and 'image-xxxxxxxxx',
+    where 'xxxxxxxxx' is the index of the image. The value of 'num-samples' is
+    the total number of images. The value of 'label-xxxxxxx' is the text label
+    of the image, and the value of 'image-xxxxxxx' is the image data.

     Each item fetched from this dataset will be a dict containing the
     following keys:

+    - img (ndarray): The loaded image.
     - img_path (str): The image key.
     - instances (list[dict]): The list of annotations for the image.

     Args:
         ann_file (str): Annotation file path. Defaults to ''.
-        parse_cfg (dict, optional): Config of parser for parsing annotations.
-            Use ``LineJsonParser`` when the annotation file is in jsonl format
-            with keys of ``filename`` and ``text``. The keys in parse_cfg
-            should be consistent with the keys in jsonl annotations. The first
-            key in parse_cfg should be the key of the path in jsonl
-            annotations. The second key in parse_cfg should be the key of the
-            text in jsonl Use ``LineStrParser`` when the annotation file is in
-            txt format. Defaults to
-            ``dict(type='LineJsonParser', keys=['filename', 'text'])``.
+        img_color_type (str): The flag argument for :func:``mmcv.imfrombytes``,
+            which determines how the image bytes will be parsed. Defaults to
+            'color'.
         metainfo (dict, optional): Meta information for dataset, such as class
             information. Defaults to None.
         data_root (str): The root directory for ``data_prefix`` and
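For readers migrating hand-rolled lmdb files: the key layout described in the new docstring ('num-samples' plus zero-padded 'image-%09d' / 'label-%09d' pairs) can be produced with the `lmdb` package directly, as in the hedged sketch below. The helper name and the placeholder byte strings are ours; in practice the `tools/data/utils/recog2lmdb` script mentioned in the removed deprecation warning generates such files.

```python
# Hypothetical helper (not part of this commit): writes samples in the key
# layout RecogLMDBDataset expects: 'num-samples', 'image-%09d', 'label-%09d'.
import lmdb


def write_recog_lmdb(lmdb_path, samples):
    """samples: list of (image_bytes, text) pairs; keys are 1-indexed."""
    env = lmdb.open(lmdb_path, map_size=1 << 30)
    with env.begin(write=True) as txn:
        for idx, (image_bytes, text) in enumerate(samples, start=1):
            txn.put(f'image-{idx:09d}'.encode('utf-8'), image_bytes)
            txn.put(f'label-{idx:09d}'.encode('utf-8'), text.encode('utf-8'))
        txn.put(b'num-samples', str(len(samples)).encode('utf-8'))
    env.close()


# The byte strings below are placeholders, not decodable images.
write_recog_lmdb('toy_imgs.lmdb',
                 [(b'<jpeg bytes of sample 1>', 'GRAND'),
                  (b'<jpeg bytes of sample 2>', 'HOTEL')])
```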
@@ -60,50 +58,21 @@ class RecogLMDBDataset(BaseDataset):
             image. Defaults to 1000.
     """

-    def __init__(self,
-                 ann_file: str = '',
-                 parser_cfg: Optional[dict] = dict(
-                     type='LineJsonParser', keys=['filename', 'text']),
-                 metainfo: Optional[dict] = None,
-                 data_root: Optional[str] = '',
-                 data_prefix: dict = dict(img_path=''),
-                 filter_cfg: Optional[dict] = None,
-                 indices: Optional[Union[int, Sequence[int]]] = None,
-                 serialize_data: bool = True,
-                 pipeline: List[Union[dict, Callable]] = [],
-                 test_mode: bool = False,
-                 lazy_init: bool = False,
-                 max_refetch: int = 1000) -> None:
-        if parser_cfg['type'] != 'LineJsonParser':
-            raise ValueError('We only support using LineJsonParser '
-                             'to parse lmdb file. Please use LineJsonParser '
-                             'in the dataset config')
-        self.parser = TASK_UTILS.build(parser_cfg)
-        self.ann_file = ann_file
-        self.deprecated_format = False
-        env = self._get_env(root=data_root)
-        with env.begin(write=False) as txn:
-            try:
-                self.total_number = int(
-                    txn.get(b'num-samples').decode('utf-8'))
-            except AttributeError:
-                warnings.warn(
-                    'DeprecationWarning: The lmdb dataset generated with '
-                    'txt2lmdb will be deprecate, please use the latest '
-                    'tools/data/utils/recog2lmdb to generate lmdb dataset. '
-                    'See https://mmocr.readthedocs.io/en/latest/tools.html#'
-                    'convert-text-recognition-dataset-to-lmdb-format for '
-                    'details.', UserWarning)
-                self.total_number = int(
-                    txn.get(b'total_number').decode('utf-8'))
-                self.deprecated_format = True
-            # The lmdb file may contain only the label, or it may contain both
-            # the label and the image, so we use image_key here for probing.
-            image_key = f'image-{1:09d}'
-            if txn.get(image_key.encode('utf-8')) is None:
-                self.label_only = True
-            else:
-                self.label_only = False
+    def __init__(
+        self,
+        ann_file: str = '',
+        img_color_type: str = 'color',
+        metainfo: Optional[dict] = None,
+        data_root: Optional[str] = '',
+        data_prefix: dict = dict(img_path=''),
+        filter_cfg: Optional[dict] = None,
+        indices: Optional[Union[int, Sequence[int]]] = None,
+        serialize_data: bool = True,
+        pipeline: List[Union[dict, Callable]] = [],
+        test_mode: bool = False,
+        lazy_init: bool = False,
+        max_refetch: int = 1000,
+    ) -> None:

         super().__init__(
             ann_file=ann_file,
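For existing configs, the practical consequence of this signature change is that `parser_cfg` disappears and `img_color_type` becomes the only lmdb-specific knob (it is forwarded to `mmcv.imfrombytes`). A hedged before/after sketch, with the dict names invented for illustration:

```python
# Illustrative config migration only; the dict names are not from the diff.
old_lmdb_dataset = dict(
    type='RecogLMDBDataset',
    ann_file='label.lmdb',
    parser_cfg=dict(type='LineJsonParser', keys=['filename', 'text']),
    data_prefix=dict(img_path='imgs'),
    pipeline=[])

new_lmdb_dataset = dict(
    type='RecogLMDBDataset',
    ann_file='imgs.lmdb',        # img+label lmdb is now the only supported format
    img_color_type='grayscale',  # optional; replaces color_type on the removed
                                 # LoadImageFromLMDB transform
    pipeline=[])
```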
@@ -118,6 +87,8 @@ class RecogLMDBDataset(BaseDataset):
             lazy_init=lazy_init,
             max_refetch=max_refetch)

+        self.color_type = img_color_type
+
     def load_data_list(self) -> List[dict]:
         """Load annotations from an annotation file named as ``self.ann_file``
@@ -125,33 +96,25 @@ class RecogLMDBDataset(BaseDataset):
             List[dict]: A list of annotation.
         """
         if not hasattr(self, 'env'):
-            self.env = self._get_env()
+            self._make_env()
+            with self.env.begin(write=False) as txn:
+                self.total_number = int(
+                    txn.get(b'num-samples').decode('utf-8'))

         data_list = []
         with self.env.begin(write=False) as txn:
             for i in range(self.total_number):
-                if self.deprecated_format:
-                    line = txn.get(str(i).encode('utf-8')).decode('utf-8')
-                    filename, text = line.strip('/n').split(' ')
-                    line = json.dumps(
-                        dict(filename=filename, text=text), ensure_ascii=False)
-                else:
-                    i = i + 1
-                    label_key = f'label-{i:09d}'
-                    if self.label_only:
-                        line = txn.get(
-                            label_key.encode('utf-8')).decode('utf-8')
-                    else:
-                        img_key = f'image-{i:09d}'
-                        text = txn.get(
-                            label_key.encode('utf-8')).decode('utf-8')
-                        line = json.dumps(
-                            dict(filename=img_key, text=text),
-                            ensure_ascii=False)
+                idx = i + 1
+                label_key = f'label-{idx:09d}'
+                img_key = f'image-{idx:09d}'
+                text = txn.get(label_key.encode('utf-8')).decode('utf-8')
+                line = [img_key, text]
                 data_list.append(self.parse_data_info(line))
         return data_list

-    def parse_data_info(self, raw_anno_info: str) -> Union[dict, List[dict]]:
+    def parse_data_info(self,
+                        raw_anno_info: Tuple[Optional[str],
+                                             str]) -> Union[dict, List[dict]]:
         """Parse raw annotation to target format.

         Args:
@@ -162,16 +125,32 @@ class RecogLMDBDataset(BaseDataset):
             (dict): Parsed annotation.
         """
         data_info = {}
-        parsed_anno = self.parser(raw_anno_info)
-        img_path = osp.join(self.data_prefix['img_path'],
-                            parsed_anno[self.parser.keys[0]])
-
-        data_info['img_path'] = img_path
-        data_info['instances'] = [dict(text=parsed_anno[self.parser.keys[1]])]
+        img_key, text = raw_anno_info
+        data_info['img_path'] = img_key
+        data_info['instances'] = [dict(text=text)]
         return data_info

-    def _get_env(self, root=''):
-        """Get lmdb environment from self.ann_file.
+    def prepare_data(self, idx) -> Any:
+        """Get data processed by ``self.pipeline``.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+        """
+        data_info = self.get_data_info(idx)
+        with self.env.begin(write=False) as txn:
+            img_bytes = txn.get(data_info['img_path'].encode('utf-8'))
+            if img_bytes is None:
+                return None
+            data_info['img'] = mmcv.imfrombytes(
+                img_bytes, flag=self.color_type)
+        return self.pipeline(data_info)
+
+    def _make_env(self):
+        """Create lmdb environment from self.ann_file and save it to
+        ``self.env``.

         Returns:
             Lmdb environment.
@@ -181,10 +160,11 @@ class RecogLMDBDataset(BaseDataset):
         except ImportError:
             raise ImportError(
                 'Please install lmdb to enable RecogLMDBDataset.')
-        lmdb_path = self.ann_file if is_abs(self.ann_file) else osp.join(
-            root, self.ann_file)
-        return lmdb.open(
-            lmdb_path,
+        if hasattr(self, 'env'):
+            return
+
+        self.env = lmdb.open(
+            self.ann_file,
             max_readers=1,
             readonly=True,
             lock=False,
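With the refactor, fetching a sample takes two hops inside the dataset itself: `load_data_list` pairs each `label-%09d` text with its `image-%09d` key, and `prepare_data` reads and decodes the image bytes with `mmcv.imfrombytes`. A minimal usage sketch against the toy lmdb exercised by the updated test further down; the printed comments describe the expected fields rather than verified output.

```python
# Usage sketch: an empty pipeline means dataset[idx] returns the raw dict
# produced by prepare_data(), including the decoded image.
from mmocr.datasets import RecogLMDBDataset

dataset = RecogLMDBDataset(
    ann_file='tests/data/rec_toy_dataset/imgs.lmdb', pipeline=[])
dataset.full_init()

sample = dataset[0]
print(len(dataset))                    # value of the 'num-samples' key
print(sample['img_path'])              # lmdb image key, e.g. 'image-000000001'
print(sample['img'].shape)             # ndarray decoded by mmcv.imfrombytes
print(sample['instances'][0]['text'])  # ground-truth text, e.g. 'GRAND'
```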
@@ -1,9 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .adapters import MMDet2MMOCR, MMOCR2MMDet
 from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
-from .loading import (LoadImageFromFile, LoadImageFromLMDB,
-                      LoadImageFromNDArray, LoadKIEAnnotations,
-                      LoadOCRAnnotations)
+from .loading import (LoadImageFromFile, LoadImageFromNDArray,
+                      LoadKIEAnnotations, LoadOCRAnnotations)
 from .ocr_transforms import (FixInvalidPolygon, RandomCrop, RandomRotate,
                              RemoveIgnored, Resize)
 from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip,
@@ -21,7 +20,7 @@ __all__ = [
     'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
     'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
     'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
-    'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
-    'LoadImageFromNDArray', 'CropHeight', 'TextRecogGeneralAug',
-    'ImageContentJitter', 'ReversePixels', 'RemoveIgnored', 'ConditionApply'
+    'MMOCR2MMDet', 'LoadImageFromFile', 'LoadImageFromNDArray', 'CropHeight',
+    'TextRecogGeneralAug', 'ImageContentJitter', 'ReversePixels',
+    'RemoveIgnored', 'ConditionApply'
 ]
@@ -1,13 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import os
 import warnings
 from typing import Optional

 import mmcv
-import mmengine
 import numpy as np
 from mmcv.transforms import BaseTransform
 from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
 from mmcv.transforms import LoadImageFromFile as MMCV_LoadImageFromFile
@@ -487,114 +485,3 @@ class LoadKIEAnnotations(MMCV_LoadAnnotations):
         repr_str += f'with_label={self.with_label}, '
         repr_str += f'with_text={self.with_text})'
         return repr_str
-
-
-@TRANSFORMS.register_module()
-class LoadImageFromLMDB(BaseTransform):
-    """Load an image from lmdb file. Only support LMDB file at disk.
-
-    LMDB file is organized with the following structure:
-        lmdb
-            |__data.mdb
-            |__lock.mdb
-
-    Required Keys:
-
-    - img_path (In LMDB img_path is a key in the format of "image-{i:09d}".)
-
-    Modified Keys:
-
-    - img
-    - img_shape
-    - ori_shape
-
-    Args:
-        to_float32 (bool): Whether to convert the loaded image to a float32
-            numpy array. If set to False, the loaded image is an uint8 array.
-            Defaults to False.
-        color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
-            Defaults to 'color'.
-        imdecode_backend (str): The image decoding backend type. The backend
-            argument for :func:``mmcv.imfrombytes``.
-            See :func:``mmcv.imfrombytes`` for details.
-            Defaults to 'cv2'.
-        file_client_args (dict): Arguments to instantiate a FileClient except
-            for ``backend`` and ``db_path``. See
-            :class:`mmengine.fileio.FileClient` for details.
-            Defaults to ``dict()``.
-        ignore_empty (bool): Whether to allow loading empty image or file path
-            not existent. Defaults to False.
-    """
-
-    def __init__(
-        self,
-        to_float32: bool = False,
-        color_type: str = 'color',
-        imdecode_backend: str = 'cv2',
-        file_client_args: dict = dict(),
-        ignore_empty: bool = False,
-    ) -> None:
-        self.ignore_empty = ignore_empty
-        self.to_float32 = to_float32
-        self.color_type = color_type
-        self.imdecode_backend = imdecode_backend
-        self.file_clients = {}
-        if 'backend' in file_client_args or 'db_path' in file_client_args:
-            raise ValueError(
-                '"file_client_args" should not contain "backend" and "db_path"'
-            )
-        self.file_client_args = file_client_args
-
-    def _get_client(self, db_path: str) -> mmengine.FileClient:
-        """Get a FileClient bound to the given db_path.
-
-        If the client for this db_path is not initialized, initialize it.
-        """
-        if self.file_clients.get(db_path) is None:
-            self.file_clients[db_path] = mmengine.FileClient(
-                backend='lmdb', db_path=db_path, **self.file_client_args)
-        return self.file_clients.get(db_path)
-
-    def transform(self, results: dict) -> Optional[dict]:
-        """Functions to load image from LMDB file.
-
-        Args:
-            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
-
-        Returns:
-            dict: The dict contains loaded image and meta information.
-        """
-        filename = results['img_path']
-        lmdb_path = os.path.dirname(filename)
-        image_key = os.path.basename(filename)
-        file_client = self._get_client(lmdb_path)
-        img_bytes = file_client.get(image_key)
-
-        if img_bytes is None:
-            if self.ignore_empty:
-                return None
-            raise KeyError(f'Image not found in lmdb: {filename}')
-
-        img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
-
-        if img is None:
-            if self.ignore_empty:
-                return None
-            raise IOError(f'{filename} is broken')
-
-        if self.to_float32:
-            img = img.astype(np.float32)
-
-        results['img'] = img
-        results['img_shape'] = img.shape[:2]
-        results['ori_shape'] = img.shape[:2]
-        return results
-
-    def __repr__(self):
-        repr_str = (f'{self.__class__.__name__}('
-                    f'ignore_empty={self.ignore_empty}, '
-                    f'to_float32={self.to_float32}, '
-                    f"color_type='{self.color_type}', "
-                    f"imdecode_backend='{self.imdecode_backend}', "
-                    f'file_client_args={self.file_client_args})')
-        return repr_str
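The removed transform's job of turning an lmdb key into a decoded image is now handled by `RecogLMDBDataset.prepare_data`, so the pipeline only needs the generic array loader. A hedged sketch of what that looks like at the transform level; the zero image and shapes are made up for illustration, and the exact metadata handling of `LoadImageFromNDArray` should be checked against its own docstring.

```python
# Sketch: feeding an already-decoded image (as the new dataset provides it)
# through LoadImageFromNDArray instead of the removed LoadImageFromLMDB.
import numpy as np

from mmocr.datasets.transforms import LoadImageFromNDArray

results = {'img': np.zeros((32, 100, 3), dtype=np.uint8)}  # e.g. from prepare_data
results = LoadImageFromNDArray()(results)
print(results['img_shape'], results['ori_shape'])  # expected: (32, 100) (32, 100)
```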
@@ -1,80 +1,20 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
-import tempfile
 from unittest import TestCase

-import lmdb
-
 from mmocr.datasets import RecogLMDBDataset


 class TestRecogLMDBDataset(TestCase):

-    def create_deprecated_format_lmdb(self, temp_dir):
-        os.makedirs(temp_dir, exist_ok=True)
-        env = lmdb.open(temp_dir, map_size=102400)
-        cache = [(str(0).encode('utf-8'), b'test test')]
-        with env.begin(write=True) as txn:
-            cursor = txn.cursor()
-            cursor.putmulti(cache, dupdata=False, overwrite=True)
-
-        cache = []
-        cache.append((b'total_number', str(1).encode('utf-8')))
-        with env.begin(write=True) as txn:
-            cursor = txn.cursor()
-            cursor.putmulti(cache, dupdata=False, overwrite=True)
-
-    def test_label_only_dataset(self):
-
-        # test initialization
-        dataset = RecogLMDBDataset(
-            ann_file='tests/data/rec_toy_dataset/label.lmdb',
-            data_prefix=dict(img_path='imgs'),
-            pipeline=[])
-        dataset.full_init()
-        self.assertEqual(len(dataset), 10)
-        self.assertEqual(len(dataset.load_data_list()), 10)
-
-        # test load_data_list
-        anno = dataset.load_data_list()[0]
-        self.assertIn(anno['img_path'],
-                      ['imgs/1223731.jpg', 'imgs\\1223731.jpg'])
-        self.assertEqual(anno['instances'][0]['text'], 'GRAND')
-
     def test_label_and_image_dataset(self):

         # test initialization
         dataset = RecogLMDBDataset(
-            ann_file='tests/data/rec_toy_dataset/imgs.lmdb',
-            data_prefix=dict(img_path='imgs'),
-            pipeline=[])
+            ann_file='tests/data/rec_toy_dataset/imgs.lmdb', pipeline=[])
         dataset.full_init()
         self.assertEqual(len(dataset), 10)
         self.assertEqual(len(dataset.load_data_list()), 10)

         # test load_data_list
         anno = dataset.load_data_list()[0]
         self.assertIn(anno['img_path'],
                       [f'imgs/image-{1:09d}', f'imgs\\image-{1:09d}'])
         self.assertEqual(anno['instances'][0]['text'], 'GRAND')

-    def test_deprecated_format(self):
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            self.create_deprecated_format_lmdb(
-                os.path.join(tmpdirname, 'data'))
-            dataset = RecogLMDBDataset(
-                ann_file=os.path.join(tmpdirname, 'data'),
-                data_prefix=dict(img_path='imgs'),
-                pipeline=[])
-
-            warm_msg = 'DeprecationWarning: The lmdb dataset generated with '
-            warm_msg += 'txt2lmdb will be deprecate, please use the latest '
-            warm_msg += 'tools/data/utils/recog2lmdb to generate lmdb dataset.'
-            warm_msg += ' See https://mmocr.readthedocs.io/en/'
-            warm_msg += 'latest/tools.html#'
-            warm_msg += 'convert-text-recognition-dataset-to-lmdb-format for '
-            warm_msg += 'details.'
-
-            dataset.full_init()
-            self.assertWarnsRegex(UserWarning, warm_msg)
+        dataset.close()
+        self.assertEqual(dataset[0]['img'].shape, (26, 67, 3))
+        self.assertEqual(dataset[0]['instances'][0]['text'], 'GRAND')
+        self.assertEqual(dataset[1]['img'].shape, (17, 37, 3))
+        self.assertEqual(dataset[1]['instances'][0]['text'], 'HOTEL')
@@ -5,8 +5,8 @@ from unittest import TestCase

 import numpy as np

-from mmocr.datasets.transforms import (LoadImageFromFile, LoadImageFromLMDB,
-                                       LoadKIEAnnotations, LoadOCRAnnotations)
+from mmocr.datasets.transforms import (LoadImageFromFile, LoadKIEAnnotations,
+                                       LoadOCRAnnotations)


 class TestLoadImageFromFile(TestCase):
@@ -190,62 +190,3 @@ class TestLoadKIEAnnotations(TestCase):
             repr(self.load),
             'LoadKIEAnnotations(with_bbox=True, with_label=True, '
             'with_text=True)')
-
-
-class TestLoadImageFromLMDB(TestCase):
-
-    def setUp(self):
-        img_key = 'image-%09d' % 1
-        self.results1 = {
-            'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}'
-        }
-        self.broken_results = {
-            'img_path': f'tests/data/rec_toy_dataset/broken.lmdb/{img_key}'
-        }
-
-        img_key = 'image-%09d' % 100
-        self.results2 = {
-            'img_path': f'tests/data/rec_toy_dataset/imgs.lmdb/{img_key}'
-        }
-
-    def test_init(self):
-        with self.assertRaises(ValueError):
-            LoadImageFromLMDB(file_client_args=dict(backend='disk'))
-
-    def test_transform(self):
-        transform = LoadImageFromLMDB()
-        results = transform(copy.deepcopy(self.results1))
-        self.assertIn('img', results)
-        self.assertIsInstance(results['img'], np.ndarray)
-        self.assertEqual(results['img'].shape[:2], results['img_shape'])
-        self.assertEqual(results['ori_shape'], results['img_shape'])
-
-    def test_invalid_key(self):
-        # This test also tests its capability of implicitly switching between
-        # different backends (due to different lmdb path)
-        transform = LoadImageFromLMDB()
-        with self.assertRaises(KeyError):
-            results = transform(copy.deepcopy(self.results2))
-        with self.assertRaises(IOError):
-            transform(copy.deepcopy(self.broken_results))
-        transform = LoadImageFromLMDB(ignore_empty=True)
-        results = transform(copy.deepcopy(self.results2))
-        self.assertIsNone(results)
-        results = transform(copy.deepcopy(self.broken_results))
-        self.assertIsNone(results)
-
-    def test_to_float32(self):
-        transform = LoadImageFromLMDB(to_float32=True)
-        results = transform(copy.deepcopy(self.results1))
-        self.assertIn('img', results)
-        self.assertIsInstance(results['img'], np.ndarray)
-        self.assertTrue(results['img'].dtype, np.float32)
-        self.assertEqual(results['img'].shape[:2], results['img_shape'])
-        self.assertEqual(results['ori_shape'], results['img_shape'])
-
-    def test_repr(self):
-        transform = LoadImageFromLMDB()
-        assert repr(transform) == ('LoadImageFromLMDB(ignore_empty=False, '
-                                   "to_float32=False, color_type='color', "
-                                   "imdecode_backend='cv2', "
-                                   'file_client_args={})')