[Utils] Migrate datasets/utils

pull/1178/head
wangxinyu 2022-07-05 03:19:21 +00:00 committed by gaotongxiao
parent 9e9c34d74c
commit ab6e897c6b
8 changed files with 5 additions and 344 deletions

View File

@ -6,7 +6,6 @@ from .ocr_seg_dataset import OCRSegDataset
from .pipelines import * # NOQA
from .recog_lmdb_dataset import RecogLMDBDataset
from .recog_text_dataset import RecogTextDataset
from .utils import * # NOQA
__all__ = [
'IcdarDataset', 'OCRDataset', 'OCRSegDataset', 'PARSERS', 'LOADERS',

View File

@ -1,8 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .loader import AnnFileLoader, HardDiskLoader, LmdbLoader
from .parser import LineJsonParser, LineStrParser
__all__ = [
'HardDiskLoader', 'LmdbLoader', 'AnnFileLoader', 'LineStrParser',
'LineJsonParser'
]

View File

@ -1,197 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
import shutil
import warnings
import mmcv
from mmocr import digit_version
from mmocr.utils import list_from_file
# TODO: remove
class LmdbAnnFileBackend:
    """Lmdb storage backend for annotation file.

    Args:
        lmdb_path (str): Lmdb file path.
        encoding (str): Encoding used to decode the values read from lmdb.
            Defaults to 'utf8'.
    """

    def __init__(self, lmdb_path, encoding='utf8'):
        """Currently we support two lmdb formats, one is the lmdb file with
        only labels generated by txt2lmdb (deprecated), and one is the lmdb
        file generated by recog2lmdb.

        The former stores string in 'filename text' format directly in lmdb,
        while the latter uses a more reasonable image_key as well as label_key
        for querying.
        """
        self.lmdb_path = lmdb_path
        self.encoding = encoding
        self.deprecated_format = False
        # This environment is only used for probing the file layout. The one
        # used for reading items is opened lazily in ``__getitem__``, so the
        # probing handle is released explicitly once we are done with it.
        env = self._get_env()
        try:
            with env.begin(write=False) as txn:
                num_samples = txn.get(b'num-samples')
                if num_samples is not None:
                    self.total_number = int(num_samples.decode(self.encoding))
                else:
                    # No 'num-samples' key: fall back to the txt2lmdb layout.
                    warnings.warn(
                        'DeprecationWarning: The lmdb dataset generated with '
                        'txt2lmdb will be deprecate, please use the latest '
                        'tools/data/utils/recog2lmdb to generate lmdb dataset. '
                        'See https://mmocr.readthedocs.io/en/latest/tools.html#'
                        'convert-text-recognition-dataset-to-lmdb-format for '
                        'details.')
                    self.total_number = int(
                        txn.get(b'total_number').decode(self.encoding))
                    self.deprecated_format = True
                # The lmdb file may contain only the label, or it may contain
                # both the label and the image, so we use image_key here for
                # probing.
                image_key = f'image-{1:09d}'
                self.label_only = txn.get(
                    image_key.encode(self.encoding)) is None
        finally:
            env.close()

    def __getitem__(self, index):
        """Retrieve one line from lmdb file by index.

        In order to support space reading, the returned lines are in the form
        of json, such as '{'filename': 'image1.jpg' ,'text':'HELLO'}'.

        Args:
            index (int): Zero-based index of the sample.

        Returns:
            str: A json string describing the sample. For the deprecated
            format, a record that has neither 2 nor 4 space-separated fields
            is returned as the raw decoded line.
        """
        if not hasattr(self, 'env'):
            self.env = self._get_env()

        with self.env.begin(write=False) as txn:
            if self.deprecated_format:
                line = txn.get(str(index).encode('utf-8')).decode(
                    self.encoding)
                # Strip the newline only. The previous ``strip('/n')``
                # removed the characters '/' and 'n' from both ends and so
                # corrupted e.g. 'nice.jpg brown'.
                keys = line.strip('\n').split(' ')
                if len(keys) == 4:
                    filename, height, width, annotations = keys
                    line = json.dumps(
                        dict(
                            filename=filename,
                            height=height,
                            width=width,
                            annotations=annotations),
                        ensure_ascii=False)
                elif len(keys) == 2:
                    filename, text = keys
                    line = json.dumps(
                        dict(filename=filename, text=text), ensure_ascii=False)
            else:
                # Keys written in the new layout are 1-based.
                index = index + 1
                label_key = f'label-{index:09d}'
                if self.label_only:
                    line = txn.get(label_key.encode('utf-8')).decode(
                        self.encoding)
                else:
                    img_key = f'image-{index:09d}'
                    text = txn.get(label_key.encode('utf-8')).decode(
                        self.encoding)
                    line = json.dumps(
                        dict(filename=img_key, text=text), ensure_ascii=False)
        return line

    def __len__(self):
        """Return the number of samples recorded in the lmdb file."""
        return self.total_number

    def _get_env(self):
        """Open ``self.lmdb_path`` as a read-only lmdb environment.

        Raises:
            ImportError: If the ``lmdb`` package is not installed.
        """
        try:
            import lmdb
        except ImportError:
            raise ImportError(
                'Please install lmdb to enable LmdbAnnFileBackend.')
        return lmdb.open(
            self.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def close(self):
        """Close the lazily opened environment, if there is one.

        Safe to call when no item has been read yet and safe to call more
        than once. A later ``__getitem__`` reopens the environment.
        """
        env = self.__dict__.pop('env', None)
        if env is not None:
            env.close()
class HardDiskAnnFileBackend:
    """Annotation file backend that reads from the local file system.

    Args:
        file_format (str): Format of the annotation file, either 'txt' or
            'lmdb'. Defaults to 'txt'.
    """

    def __init__(self, file_format='txt'):
        supported_formats = ('txt', 'lmdb')
        assert file_format in supported_formats
        self.file_format = file_format

    def __call__(self, ann_file):
        """Load ``ann_file``.

        Returns an ``LmdbAnnFileBackend`` for the 'lmdb' format, otherwise
        the result of ``list_from_file``.
        """
        reader = (
            LmdbAnnFileBackend
            if self.file_format == 'lmdb' else list_from_file)
        return reader(ann_file)
class PetrelAnnFileBackend:
    """Load annotation file with petrel storage backend.

    Args:
        file_format (str): Format of the annotation file, either 'txt' or
            'lmdb'. Defaults to 'txt'.
        save_dir (str): Local directory under which a remote lmdb directory
            is copied before being opened. Defaults to 'tmp_dir'.
    """

    def __init__(self, file_format='txt', save_dir='tmp_dir'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format
        self.save_dir = save_dir

    def __call__(self, ann_file):
        """Load ``ann_file`` through the petrel file client.

        For the 'lmdb' format the remote directory is copied to
        ``save_dir`` (unless a local copy already exists) and opened with
        ``LmdbAnnFileBackend``. Otherwise the file content is decoded as
        utf-8 and returned as a list of its non-blank lines.

        Raises:
            Exception: If the 'lmdb' format is requested with mmcv older
                than 1.3.16.
        """
        file_client = mmcv.FileClient(backend='petrel')

        if self.file_format == 'lmdb':
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            assert file_client.isdir(ann_file)
            files = file_client.list_dir_or_file(ann_file)

            # Mirror the remote path (without the 's3://' scheme) below
            # ``save_dir``.
            ann_file_rel_path = ann_file.split('s3://')[-1]
            ann_file_dir = osp.dirname(ann_file_rel_path)
            ann_file_name = osp.basename(ann_file_rel_path)
            local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)
            if osp.exists(local_dir):
                # Every fragment containing a placeholder must be an
                # f-string, otherwise '{ann_file}' is emitted literally.
                warnings.warn(
                    f'local_ann_file: {local_dir} is already existed and '
                    'will be used. If it is not the correct ann_file '
                    f'corresponding to {ann_file}, please remove it or '
                    'change "save_dir" first then try again.')
            else:
                os.makedirs(local_dir, exist_ok=True)
                print(f'Fetching {ann_file} to {local_dir}...')
                for each_file in files:
                    tmp_file_path = file_client.join_path(ann_file, each_file)
                    with file_client.get_local_path(
                            tmp_file_path) as local_path:
                        shutil.copy(local_path, osp.join(local_dir, each_file))

            return LmdbAnnFileBackend(local_dir)

        lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
        return [x for x in lines if x.strip() != '']
class HTTPAnnFileBackend:
    """Annotation file backend that fetches the file over http.

    Args:
        file_format (str): Format of the annotation file, either 'txt' or
            'lmdb'. Defaults to 'txt'. Only 'txt' can actually be loaded.
    """

    def __init__(self, file_format='txt'):
        allowed_formats = ('txt', 'lmdb')
        assert file_format in allowed_formats
        self.file_format = file_format

    def __call__(self, ann_file):
        """Fetch ``ann_file`` and return its non-blank lines.

        Raises:
            NotImplementedError: If the backend was built for 'lmdb'.
        """
        http_client = mmcv.FileClient(backend='http')

        if self.file_format == 'lmdb':
            raise NotImplementedError(
                'Loading lmdb file on http is not supported yet.')

        content = str(http_client.get(ann_file), encoding='utf-8')
        return [row for row in content.split('\n') if row.strip()]

View File

@ -1,113 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmocr.datasets.builder import LOADERS, PARSERS
from .backend import (HardDiskAnnFileBackend, HTTPAnnFileBackend,
PetrelAnnFileBackend)
# TODO: remove
@LOADERS.register_module()
class AnnFileLoader:
    """Annotation file loader to load annotations from ann_file, and parse raw
    annotation to dict format with certain parser.

    Args:
        ann_file (str): Annotation file path.
        parser (dict): Dictionary to construct parser
            to parse original annotation infos.
        repeat (int|float): Repeated times of dataset.
        file_storage_backend (str): The storage backend type for annotation
            file. Options are "disk", "http" and "petrel". Default: "disk".
        file_format (str): The format of annotation file. Options are
            "txt" and "lmdb". Default: "txt".
        **kwargs: Extra keyword arguments forwarded to the constructor of the
            selected storage backend.
    """

    _backends = {
        'disk': HardDiskAnnFileBackend,
        'petrel': PetrelAnnFileBackend,
        'http': HTTPAnnFileBackend
    }

    def __init__(self,
                 ann_file,
                 parser,
                 repeat=1,
                 file_storage_backend='disk',
                 file_format='txt',
                 **kwargs):
        assert isinstance(ann_file, str)
        assert isinstance(repeat, (int, float))
        assert isinstance(parser, dict)
        assert repeat > 0
        assert file_storage_backend in ['disk', 'http', 'petrel']
        assert file_format in ['txt', 'lmdb']

        if file_format == 'lmdb' and parser['type'] == 'LineStrParser':
            raise ValueError('We only support using LineJsonParser '
                             'to parse lmdb file. Please use LineJsonParser '
                             'in the dataset config')
        self.parser = PARSERS.build(parser)
        self.repeat = repeat
        self.ann_file_backend = self._backends[file_storage_backend](
            file_format, **kwargs)
        self.ori_data_infos = self._load(ann_file)

    def __len__(self):
        """Return the number of loaded annotations scaled by ``repeat``."""
        return int(len(self.ori_data_infos) * self.repeat)

    def _load(self, ann_file):
        """Load annotation file."""
        return self.ann_file_backend(ann_file)

    def __getitem__(self, index):
        """Retrieve anno info of one instance with dict format."""
        return self.parser.get_item(self.ori_data_infos, index)

    def __iter__(self):
        # The cursor lives on the loader itself, so all iterations over the
        # same loader share it.
        self._n = 0
        return self

    def __next__(self):
        if self._n < len(self):
            data = self[self._n]
            self._n += 1
            return data
        raise StopIteration

    def close(self):
        """Release the underlying annotation storage, if it can be closed.

        Only lmdb-backed annotations hold a handle. Annotations loaded from
        txt are a plain list without ``close``, in which case this is a
        no-op instead of raising ``AttributeError``.
        """
        closer = getattr(self.ori_data_infos, 'close', None)
        if callable(closer):
            closer()
@LOADERS.register_module()
class HardDiskLoader(AnnFileLoader):
    """Deprecated loader for txt annotation files stored on local disk.

    Emits a ``UserWarning`` and delegates to ``AnnFileLoader`` with the
    'disk' backend and the 'txt' format.
    """

    def __init__(self, ann_file, parser, repeat=1):
        deprecation_msg = ('HardDiskLoader is deprecated, please use '
                           'AnnFileLoader instead.')
        warnings.warn(deprecation_msg, UserWarning)
        fixed_options = dict(file_storage_backend='disk', file_format='txt')
        super().__init__(ann_file, parser, repeat, **fixed_options)
@LOADERS.register_module()
class LmdbLoader(AnnFileLoader):
    """Deprecated loader for lmdb annotation files stored on local disk.

    Emits a ``UserWarning`` and delegates to ``AnnFileLoader`` with the
    'disk' backend and the 'lmdb' format.
    """

    def __init__(self, ann_file, parser, repeat=1):
        deprecation_msg = ('LmdbLoader is deprecated, please use '
                           'AnnFileLoader instead.')
        warnings.warn(deprecation_msg, UserWarning)
        fixed_options = dict(file_storage_backend='disk', file_format='lmdb')
        super().__init__(ann_file, parser, repeat, **fixed_options)

View File

@ -17,6 +17,7 @@ from .fileio import list_from_file, list_to_file
from .lmdb_util import recog2lmdb
from .logger import get_root_logger
from .model import revert_sync_batchnorm
from .parsers import LineJsonParser, LineStrParser
from .point_utils import dist_points2line, point_distance, points_center
from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
offset_polygon, poly2bbox, poly2shapely,
@ -41,5 +42,6 @@ __all__ = [
'disable_text_recog_aug_test', 'box_center_distance', 'box_diag',
'compute_hmean', 'filter_2dlist_result', 'many2one_match_ic13',
'one2one_match_ic13', 'select_top_boundary', 'boundary_iou',
'point_distance', 'points_center', 'fill_hole'
'point_distance', 'points_center', 'fill_hole', 'LineJsonParser',
'LineStrParser'
]

View File

@ -4,7 +4,7 @@ import warnings
from typing import Dict, Tuple
from mmocr.registry import TASK_UTILS
from mmocr.utils import StringStrip
from mmocr.utils.string_util import StringStrip
@TASK_UTILS.register_module()

View File

@ -1,22 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from unittest import TestCase
from mmocr.datasets import LineJsonParser, LineStrParser
class TestParser(TestCase):
    """Checks that both line parsers yield the filename and text fields."""

    def test_line_json_parser(self):
        # A json-encoded record should be decoded into its two fields.
        raw_line = json.dumps({'filename': 'test.jpg', 'text': 'mmocr'})
        json_parser = LineJsonParser()
        parsed = json_parser(raw_line)
        self.assertEqual(parsed['filename'], 'test.jpg')
        self.assertEqual(parsed['text'], 'mmocr')

    def test_line_str_parser(self):
        # A space-separated record should be split into the same fields.
        raw_line = 'test.jpg mmocr'
        str_parser = LineStrParser()
        parsed = str_parser(raw_line)
        self.assertEqual(parsed['filename'], 'test.jpg')
        self.assertEqual(parsed['text'], 'mmocr')

View File

@ -2,7 +2,7 @@
import json
from unittest import TestCase
from mmocr.datasets import LineJsonParser, LineStrParser
from mmocr.utils import LineJsonParser, LineStrParser
class TestParser(TestCase):