mirror of https://github.com/open-mmlab/mmocr.git
[Utils] Migrate datasets/utils
parent
9e9c34d74c
commit
ab6e897c6b
|
@ -6,7 +6,6 @@ from .ocr_seg_dataset import OCRSegDataset
|
|||
from .pipelines import * # NOQA
|
||||
from .recog_lmdb_dataset import RecogLMDBDataset
|
||||
from .recog_text_dataset import RecogTextDataset
|
||||
from .utils import * # NOQA
|
||||
|
||||
__all__ = [
|
||||
'IcdarDataset', 'OCRDataset', 'OCRSegDataset', 'PARSERS', 'LOADERS',
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .loader import AnnFileLoader, HardDiskLoader, LmdbLoader
|
||||
from .parser import LineJsonParser, LineStrParser
|
||||
|
||||
__all__ = [
|
||||
'HardDiskLoader', 'LmdbLoader', 'AnnFileLoader', 'LineStrParser',
|
||||
'LineJsonParser'
|
||||
]
|
|
@ -1,197 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr import digit_version
|
||||
from mmocr.utils import list_from_file
|
||||
|
||||
|
||||
# TODO: remove
|
||||
class LmdbAnnFileBackend:
|
||||
"""Lmdb storage backend for annotation file.
|
||||
|
||||
Args:
|
||||
lmdb_path (str): Lmdb file path.
|
||||
"""
|
||||
|
||||
def __init__(self, lmdb_path, encoding='utf8'):
|
||||
"""Currently we support two lmdb formats, one is the lmdb file with
|
||||
only labels generated by txt2lmdb (deprecated), and one is the lmdb
|
||||
file generated by recog2lmdb.
|
||||
|
||||
The former stores string in 'filename text' format directly in lmdb,
|
||||
while the latter uses a more reasonable image_key as well as label_key
|
||||
for querying.
|
||||
"""
|
||||
self.lmdb_path = lmdb_path
|
||||
self.encoding = encoding
|
||||
self.deprecated_format = False
|
||||
env = self._get_env()
|
||||
with env.begin(write=False) as txn:
|
||||
try:
|
||||
self.total_number = int(
|
||||
txn.get(b'num-samples').decode(self.encoding))
|
||||
except AttributeError:
|
||||
warnings.warn(
|
||||
'DeprecationWarning: The lmdb dataset generated with '
|
||||
'txt2lmdb will be deprecate, please use the latest '
|
||||
'tools/data/utils/recog2lmdb to generate lmdb dataset. '
|
||||
'See https://mmocr.readthedocs.io/en/latest/tools.html#'
|
||||
'convert-text-recognition-dataset-to-lmdb-format for '
|
||||
'details.')
|
||||
self.total_number = int(
|
||||
txn.get(b'total_number').decode(self.encoding))
|
||||
self.deprecated_format = True
|
||||
# The lmdb file may contain only the label, or it may contain both
|
||||
# the label and the image, so we use image_key here for probing.
|
||||
image_key = f'image-{1:09d}'
|
||||
if txn.get(image_key.encode(encoding)) is None:
|
||||
self.label_only = True
|
||||
else:
|
||||
self.label_only = False
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Retrieve one line from lmdb file by index.
|
||||
|
||||
In order to support space
|
||||
reading, the returned lines are in the form of json, such as
|
||||
'{'filename': 'image1.jpg' ,'text':'HELLO'}'
|
||||
"""
|
||||
if not hasattr(self, 'env'):
|
||||
self.env = self._get_env()
|
||||
|
||||
with self.env.begin(write=False) as txn:
|
||||
if self.deprecated_format:
|
||||
line = txn.get(str(index).encode('utf-8')).decode(
|
||||
self.encoding)
|
||||
keys = line.strip('/n').split(' ')
|
||||
if len(keys) == 4:
|
||||
filename, height, width, annotations = keys
|
||||
line = json.dumps(
|
||||
dict(
|
||||
filename=filename,
|
||||
height=height,
|
||||
width=width,
|
||||
annotations=annotations),
|
||||
ensure_ascii=False)
|
||||
elif len(keys) == 2:
|
||||
filename, text = keys
|
||||
line = json.dumps(
|
||||
dict(filename=filename, text=text), ensure_ascii=False)
|
||||
else:
|
||||
index = index + 1
|
||||
label_key = f'label-{index:09d}'
|
||||
if self.label_only:
|
||||
line = txn.get(label_key.encode('utf-8')).decode(
|
||||
self.encoding)
|
||||
else:
|
||||
img_key = f'image-{index:09d}'
|
||||
text = txn.get(label_key.encode('utf-8')).decode(
|
||||
self.encoding)
|
||||
line = json.dumps(
|
||||
dict(filename=img_key, text=text), ensure_ascii=False)
|
||||
return line
|
||||
|
||||
def __len__(self):
|
||||
return self.total_number
|
||||
|
||||
def _get_env(self):
|
||||
try:
|
||||
import lmdb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install lmdb to enable LmdbAnnFileBackend.')
|
||||
return lmdb.open(
|
||||
self.lmdb_path,
|
||||
max_readers=1,
|
||||
readonly=True,
|
||||
lock=False,
|
||||
readahead=False,
|
||||
meminit=False,
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.env.close()
|
||||
|
||||
|
||||
class HardDiskAnnFileBackend:
|
||||
"""Load annotation file with raw hard disks storage backend."""
|
||||
|
||||
def __init__(self, file_format='txt'):
|
||||
assert file_format in ['txt', 'lmdb']
|
||||
self.file_format = file_format
|
||||
|
||||
def __call__(self, ann_file):
|
||||
if self.file_format == 'lmdb':
|
||||
return LmdbAnnFileBackend(ann_file)
|
||||
|
||||
return list_from_file(ann_file)
|
||||
|
||||
|
||||
class PetrelAnnFileBackend:
|
||||
"""Load annotation file with petrel storage backend."""
|
||||
|
||||
def __init__(self, file_format='txt', save_dir='tmp_dir'):
|
||||
assert file_format in ['txt', 'lmdb']
|
||||
self.file_format = file_format
|
||||
self.save_dir = save_dir
|
||||
|
||||
def __call__(self, ann_file):
|
||||
file_client = mmcv.FileClient(backend='petrel')
|
||||
|
||||
if self.file_format == 'lmdb':
|
||||
mmcv_version = digit_version(mmcv.__version__)
|
||||
if mmcv_version < digit_version('1.3.16'):
|
||||
raise Exception('Please update mmcv to 1.3.16 or higher '
|
||||
'to enable "get_local_path" of "FileClient".')
|
||||
assert file_client.isdir(ann_file)
|
||||
files = file_client.list_dir_or_file(ann_file)
|
||||
|
||||
ann_file_rel_path = ann_file.split('s3://')[-1]
|
||||
ann_file_dir = osp.dirname(ann_file_rel_path)
|
||||
ann_file_name = osp.basename(ann_file_rel_path)
|
||||
local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)
|
||||
if osp.exists(local_dir):
|
||||
warnings.warn(
|
||||
f'local_ann_file: {local_dir} is already existed and '
|
||||
'will be used. If it is not the correct ann_file '
|
||||
'corresponding to {ann_file}, please remove it or '
|
||||
'change "save_dir" first then try again.')
|
||||
else:
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
print(f'Fetching {ann_file} to {local_dir}...')
|
||||
for each_file in files:
|
||||
tmp_file_path = file_client.join_path(ann_file, each_file)
|
||||
with file_client.get_local_path(
|
||||
tmp_file_path) as local_path:
|
||||
shutil.copy(local_path, osp.join(local_dir, each_file))
|
||||
|
||||
return LmdbAnnFileBackend(local_dir)
|
||||
|
||||
lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
|
||||
|
||||
return [x for x in lines if x.strip() != '']
|
||||
|
||||
|
||||
class HTTPAnnFileBackend:
|
||||
"""Load annotation file with http storage backend."""
|
||||
|
||||
def __init__(self, file_format='txt'):
|
||||
assert file_format in ['txt', 'lmdb']
|
||||
self.file_format = file_format
|
||||
|
||||
def __call__(self, ann_file):
|
||||
file_client = mmcv.FileClient(backend='http')
|
||||
|
||||
if self.file_format == 'lmdb':
|
||||
raise NotImplementedError(
|
||||
'Loading lmdb file on http is not supported yet.')
|
||||
|
||||
lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
|
||||
|
||||
return [x for x in lines if x.strip() != '']
|
|
@ -1,113 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmocr.datasets.builder import LOADERS, PARSERS
|
||||
from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend,
|
||||
PetrelAnnFileBackend)
|
||||
|
||||
|
||||
# TODO: remove
|
||||
@LOADERS.register_module()
|
||||
class AnnFileLoader:
|
||||
"""Annotation file loader to load annotations from ann_file, and parse raw
|
||||
annotation to dict format with certain parser.
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path.
|
||||
parser (dict): Dictionary to construct parser
|
||||
to parse original annotation infos.
|
||||
repeat (int|float): Repeated times of dataset.
|
||||
file_storage_backend (str): The storage backend type for annotation
|
||||
file. Options are "disk", "http" and "petrel". Default: "disk".
|
||||
file_format (str): The format of annotation file. Options are
|
||||
"txt" and "lmdb". Default: "txt".
|
||||
"""
|
||||
|
||||
_backends = {
|
||||
'disk': HardDiskAnnFileBackend,
|
||||
'petrel': PetrelAnnFileBackend,
|
||||
'http': HTTPAnnFileBackend
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
ann_file,
|
||||
parser,
|
||||
repeat=1,
|
||||
file_storage_backend='disk',
|
||||
file_format='txt',
|
||||
**kwargs):
|
||||
assert isinstance(ann_file, str)
|
||||
assert isinstance(repeat, (int, float))
|
||||
assert isinstance(parser, dict)
|
||||
assert repeat > 0
|
||||
assert file_storage_backend in ['disk', 'http', 'petrel']
|
||||
assert file_format in ['txt', 'lmdb']
|
||||
|
||||
if file_format == 'lmdb' and parser['type'] == 'LineStrParser':
|
||||
raise ValueError('We only support using LineJsonParser '
|
||||
'to parse lmdb file. Please use LineJsonParser '
|
||||
'in the dataset config')
|
||||
self.parser = PARSERS.build(parser)
|
||||
self.repeat = repeat
|
||||
self.ann_file_backend = self._backends[file_storage_backend](
|
||||
file_format, **kwargs)
|
||||
self.ori_data_infos = self._load(ann_file)
|
||||
|
||||
def __len__(self):
|
||||
return int(len(self.ori_data_infos) * self.repeat)
|
||||
|
||||
def _load(self, ann_file):
|
||||
"""Load annotation file."""
|
||||
|
||||
return self.ann_file_backend(ann_file)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Retrieve anno info of one instance with dict format."""
|
||||
return self.parser.get_item(self.ori_data_infos, index)
|
||||
|
||||
def __iter__(self):
|
||||
self._n = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._n < len(self):
|
||||
data = self[self._n]
|
||||
self._n += 1
|
||||
return data
|
||||
raise StopIteration
|
||||
|
||||
def close(self):
|
||||
"""For ann_file with lmdb format only."""
|
||||
self.ori_data_infos.close()
|
||||
|
||||
|
||||
@LOADERS.register_module()
|
||||
class HardDiskLoader(AnnFileLoader):
|
||||
"""Load txt format annotation file from hard disks."""
|
||||
|
||||
def __init__(self, ann_file, parser, repeat=1):
|
||||
warnings.warn(
|
||||
'HardDiskLoader is deprecated, please use '
|
||||
'AnnFileLoader instead.', UserWarning)
|
||||
super().__init__(
|
||||
ann_file,
|
||||
parser,
|
||||
repeat,
|
||||
file_storage_backend='disk',
|
||||
file_format='txt')
|
||||
|
||||
|
||||
@LOADERS.register_module()
|
||||
class LmdbLoader(AnnFileLoader):
|
||||
"""Load lmdb format annotation file from hard disks."""
|
||||
|
||||
def __init__(self, ann_file, parser, repeat=1):
|
||||
warnings.warn(
|
||||
'LmdbLoader is deprecated, please use '
|
||||
'AnnFileLoader instead.', UserWarning)
|
||||
super().__init__(
|
||||
ann_file,
|
||||
parser,
|
||||
repeat,
|
||||
file_storage_backend='disk',
|
||||
file_format='lmdb')
|
|
@ -17,6 +17,7 @@ from .fileio import list_from_file, list_to_file
|
|||
from .lmdb_util import recog2lmdb
|
||||
from .logger import get_root_logger
|
||||
from .model import revert_sync_batchnorm
|
||||
from .parsers import LineJsonParser, LineStrParser
|
||||
from .point_utils import dist_points2line, point_distance, points_center
|
||||
from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
|
||||
offset_polygon, poly2bbox, poly2shapely,
|
||||
|
@ -41,5 +42,6 @@ __all__ = [
|
|||
'disable_text_recog_aug_test', 'box_center_distance', 'box_diag',
|
||||
'compute_hmean', 'filter_2dlist_result', 'many2one_match_ic13',
|
||||
'one2one_match_ic13', 'select_top_boundary', 'boundary_iou',
|
||||
'point_distance', 'points_center', 'fill_hole'
|
||||
'point_distance', 'points_center', 'fill_hole', 'LineJsonParser',
|
||||
'LineStrParser'
|
||||
]
|
||||
|
|
|
@ -4,7 +4,7 @@ import warnings
|
|||
from typing import Dict, Tuple
|
||||
|
||||
from mmocr.registry import TASK_UTILS
|
||||
from mmocr.utils import StringStrip
|
||||
from mmocr.utils.string_util import StringStrip
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
from unittest import TestCase
|
||||
|
||||
from mmocr.datasets import LineJsonParser, LineStrParser
|
||||
|
||||
|
||||
class TestParser(TestCase):
|
||||
|
||||
def test_line_json_parser(self):
|
||||
parser = LineJsonParser()
|
||||
line = json.dumps(dict(filename='test.jpg', text='mmocr'))
|
||||
data = parser(line)
|
||||
self.assertEqual(data['filename'], 'test.jpg')
|
||||
self.assertEqual(data['text'], 'mmocr')
|
||||
|
||||
def test_line_str_parser(self):
|
||||
parser = LineStrParser()
|
||||
line = 'test.jpg mmocr'
|
||||
data = parser(line)
|
||||
self.assertEqual(data['filename'], 'test.jpg')
|
||||
self.assertEqual(data['text'], 'mmocr')
|
|
@ -2,7 +2,7 @@
|
|||
import json
|
||||
from unittest import TestCase
|
||||
|
||||
from mmocr.datasets import LineJsonParser, LineStrParser
|
||||
from mmocr.utils import LineJsonParser, LineStrParser
|
||||
|
||||
|
||||
class TestParser(TestCase):
|
Loading…
Reference in New Issue