mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
fix infer_rec for benchmark
This commit is contained in:
parent
6de43fbb47
commit
dd0112f52b
@ -10,3 +10,4 @@ EvalReader:
|
|||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
||||||
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
||||||
|
infer_img:
|
||||||
|
@ -42,9 +42,11 @@ class LMDBReader(object):
|
|||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
self.batch_size = params['train_batch_size_per_card']
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
else:
|
elif params['mode'] == "eval":
|
||||||
self.batch_size = params['test_batch_size_per_card']
|
self.batch_size = params['test_batch_size_per_card']
|
||||||
|
elif params['mode'] == "test":
|
||||||
|
self.batch_size = 1
|
||||||
|
self.infer_img = params["infer_img"]
|
||||||
def load_hierarchical_lmdb_dataset(self):
|
def load_hierarchical_lmdb_dataset(self):
|
||||||
lmdb_sets = {}
|
lmdb_sets = {}
|
||||||
dataset_idx = 0
|
dataset_idx = 0
|
||||||
@ -97,6 +99,15 @@ class LMDBReader(object):
|
|||||||
process_id = 0
|
process_id = 0
|
||||||
|
|
||||||
def sample_iter_reader():
|
def sample_iter_reader():
|
||||||
|
if self.mode == 'test':
|
||||||
|
image_file_list = get_image_file_list(self.infer_img)
|
||||||
|
for single_img in image_file_list:
|
||||||
|
img = cv2.imread(single_img)
|
||||||
|
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
norm_img = process_image(img, self.image_shape)
|
||||||
|
yield norm_img
|
||||||
|
else:
|
||||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||||
if process_id == 0:
|
if process_id == 0:
|
||||||
self.print_lmdb_sets_info(lmdb_sets)
|
self.print_lmdb_sets_info(lmdb_sets)
|
||||||
@ -124,7 +135,6 @@ class LMDBReader(object):
|
|||||||
if finish_read_num == len(lmdb_sets):
|
if finish_read_num == len(lmdb_sets):
|
||||||
break
|
break
|
||||||
self.close_lmdb_dataset(lmdb_sets)
|
self.close_lmdb_dataset(lmdb_sets)
|
||||||
|
|
||||||
def batch_iter_reader():
|
def batch_iter_reader():
|
||||||
batch_outs = []
|
batch_outs = []
|
||||||
for outs in sample_iter_reader():
|
for outs in sample_iter_reader():
|
||||||
@ -135,7 +145,9 @@ class LMDBReader(object):
|
|||||||
if len(batch_outs) != 0:
|
if len(batch_outs) != 0:
|
||||||
yield batch_outs
|
yield batch_outs
|
||||||
|
|
||||||
|
if self.mode != 'test':
|
||||||
return batch_iter_reader
|
return batch_iter_reader
|
||||||
|
return sample_iter_reader
|
||||||
|
|
||||||
|
|
||||||
class SimpleReader(object):
|
class SimpleReader(object):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user