Remove test dependency on tools

Signed-off-by: lizz <lizz@sensetime.com>
This commit is contained in:
lizz 2021-04-05 21:00:41 +08:00
parent cc1f103e1c
commit 09ffd284ee
4 changed files with 52 additions and 49 deletions

View File

@ -4,9 +4,10 @@ from mmdet.utils import get_root_logger
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list, from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list,
is_none_or_type, is_type_list, valid_boundary) is_none_or_type, is_type_list, valid_boundary)
from .collect_env import collect_env from .collect_env import collect_env
from .lmdb_util import lmdb_converter
__all__ = [ __all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type', 'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type',
'equal_len', 'is_2dlist', 'valid_boundary' 'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter'
] ]

46
mmocr/utils/lmdb_util.py Normal file
View File

@ -0,0 +1,46 @@
import shutil
import sys
import time
from pathlib import Path
import lmdb
def lmdb_converter(imglist,
                   output,
                   batch_size=1000,
                   coding='utf-8',
                   lmdb_map_size=1099511627776):
    """Convert a line-based annotation file into an lmdb database.

    Every line of ``imglist`` (including its trailing newline, as returned
    by ``readlines``) is stored under its zero-based line index written as
    a decimal string. The number of lines is stored under the key
    ``total_number``. Keys and values are encoded with ``coding``.

    If ``output`` already exists as a directory, the user is asked on stdin
    whether to delete it: ``Y``/``y`` removes it and continues, ``N``/``n``
    returns without writing anything, any other answer repeats the question.

    Args:
        imglist (str): Path of the text file to convert.
        output (str): Directory of the lmdb database to create.
        batch_size (int): Number of lines written per transaction.
        coding (str): Encoding used for the stored keys and values.
        lmdb_map_size (int): ``map_size`` passed to ``lmdb.open``. The
            default is the value that used to be hard-coded (1 TiB).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching the filesystem: a negative step would make
    # ``range`` yield nothing (empty database but non-zero total_number)
    # and a zero step would only fail after ``output`` may have been
    # deleted.
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got %r' % (batch_size, ))

    # read imglist
    with open(imglist) as f:
        lines = f.readlines()
    total = len(lines)

    # create lmdb database
    output_dir = Path(output)
    if output_dir.is_dir():
        while True:
            print('%s already exist, delete or not? [Y/n]' % output)
            answer = input().strip()
            if answer in ['Y', 'y']:
                shutil.rmtree(output)
                break
            elif answer in ['N', 'n']:
                return
    print('create database %s' % output)
    output_dir.mkdir(parents=True, exist_ok=False)
    env = lmdb.open(output, map_size=lmdb_map_size)

    # Close the environment even when a write fails, so the handle is not
    # left open for the remaining lifetime of the process.
    try:
        # build lmdb
        beg_time = time.strftime('%H:%M:%S')
        for beg_index in range(0, total, batch_size):
            end_index = min(beg_index + batch_size, total)
            sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
                             (beg_time, time.strftime('%H:%M:%S'), beg_index,
                              end_index, total))
            sys.stdout.flush()
            batch = [(str(index).encode(coding), lines[index].encode(coding))
                     for index in range(beg_index, end_index)]
            # One write transaction per batch.
            with env.begin(write=True) as txn:
                cursor = txn.cursor()
                cursor.putmulti(batch, dupdata=False, overwrite=True)
        sys.stdout.write('\n')

        # Record the number of entries under a dedicated key.
        with env.begin(write=True) as txn:
            key = 'total_number'.encode(coding)
            value = str(total).encode(coding)
            txn.put(key, value)
    finally:
        env.close()
    print('done', flush=True)

View File

@ -3,9 +3,9 @@ import os.path as osp
import tempfile import tempfile
import pytest import pytest
from tools.data.utils.txt2lmdb import converter
from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader
from mmocr.utils import lmdb_converter
def _create_dummy_line_str_file(ann_file): def _create_dummy_line_str_file(ann_file):
@ -63,7 +63,7 @@ def test_loader():
# test lmdb loader and line str parser # test lmdb loader and line str parser
_create_dummy_line_str_file(ann_file) _create_dummy_line_str_file(ann_file)
lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb') lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb')
converter(ann_file, lmdb_file) lmdb_converter(ann_file, lmdb_file)
lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1) lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1)
assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'} assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}

View File

@ -1,50 +1,6 @@
import argparse import argparse
import shutil
import sys
import time
from pathlib import Path
import lmdb from mmocr.utils import lmdb_converter
def converter(imglist,
              output,
              batch_size=1000,
              coding='utf-8',
              lmdb_map_size=1099511627776):
    """Convert a line-based annotation file into an lmdb database.

    Every line of ``imglist`` (including its trailing newline, as returned
    by ``readlines``) is stored under its zero-based line index written as
    a decimal string. The number of lines is stored under the key
    ``total_number``. Keys and values are encoded with ``coding``.

    If ``output`` already exists as a directory, the user is asked on stdin
    whether to delete it: ``Y``/``y`` removes it and continues, ``N``/``n``
    returns without writing anything, any other answer repeats the question.

    Args:
        imglist (str): Path of the text file to convert.
        output (str): Directory of the lmdb database to create.
        batch_size (int): Number of lines written per transaction.
        coding (str): Encoding used for the stored keys and values.
        lmdb_map_size (int): ``map_size`` passed to ``lmdb.open``. The
            default is the value that used to be hard-coded (1 TiB).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Reject unusable batch sizes up front: a negative step silently
    # writes no entries, and a zero step would only fail after ``output``
    # may already have been removed.
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got %r' % (batch_size, ))

    # read imglist
    with open(imglist) as f:
        lines = f.readlines()
    num_lines = len(lines)

    # create lmdb database
    if Path(output).is_dir():
        while True:
            print('%s already exist, delete or not? [Y/n]' % output)
            reply = input().strip()
            if reply in ['Y', 'y']:
                shutil.rmtree(output)
                break
            elif reply in ['N', 'n']:
                return
    print('create database %s' % output)
    Path(output).mkdir(parents=True, exist_ok=False)
    env = lmdb.open(output, map_size=lmdb_map_size)

    # Guarantee the environment is closed whether or not writing succeeds.
    try:
        # build lmdb
        beg_time = time.strftime('%H:%M:%S')
        for beg_index in range(0, num_lines, batch_size):
            end_index = min(beg_index + batch_size, num_lines)
            sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
                             (beg_time, time.strftime('%H:%M:%S'), beg_index,
                              end_index, num_lines))
            sys.stdout.flush()
            batch = [(str(index).encode(coding), lines[index].encode(coding))
                     for index in range(beg_index, end_index)]
            # One write transaction per batch.
            with env.begin(write=True) as txn:
                cursor = txn.cursor()
                cursor.putmulti(batch, dupdata=False, overwrite=True)
        sys.stdout.write('\n')

        # Store the entry count under its own key.
        with env.begin(write=True) as txn:
            key = 'total_number'.encode(coding)
            value = str(num_lines).encode(coding)
            txn.put(key, value)
    finally:
        env.close()
    print('done', flush=True)
def main(): def main():
@ -66,7 +22,7 @@ def main():
help='bytes coding scheme, default utf8') help='bytes coding scheme, default utf8')
opt = parser.parse_args() opt = parser.parse_args()
converter( lmdb_converter(
opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding) opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding)