pre-commit
parent
dc7bfe8a84
commit
807dd10636
ppocr
modeling/backbones
postprocess
ppstructure
|
@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.model(
|
||||
input_ids=x[0],
|
||||
bbox=x[1],
|
||||
attention_mask=x[2],
|
||||
token_type_ids=x[3],
|
||||
image=x[4],
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
labels=None)
|
||||
input_ids=x[0],
|
||||
bbox=x[1],
|
||||
attention_mask=x[2],
|
||||
token_type_ids=x[3],
|
||||
image=x[4],
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
labels=None)
|
||||
if not self.training:
|
||||
return x
|
||||
return x[0]
|
||||
|
|
|
@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
|
|||
def _infer(self, preds, segment_offset_ids, ocr_infos):
|
||||
results = []
|
||||
|
||||
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids, ocr_infos):
|
||||
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
|
||||
ocr_infos):
|
||||
pred = np.argmax(pred, axis=1)
|
||||
pred = [self.id2label_map[idx] for idx in pred]
|
||||
|
||||
|
|
|
@ -40,7 +40,6 @@ def init_args():
|
|||
type=ast.literal_eval,
|
||||
default=None,
|
||||
help='label map according to ppstructure/layout/README_ch.md')
|
||||
|
||||
# params for vqa
|
||||
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
|
||||
parser.add_argument("--ser_model_dir", type=str)
|
||||
|
@ -73,7 +72,7 @@ def init_args():
|
|||
"--recovery",
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Whether to enable layout of recovery')
|
||||
help='Whether to enable layout of recovery')
|
||||
return parser
|
||||
|
||||
|
||||
|
|
|
@ -97,8 +97,9 @@ def export_single_model(model,
|
|||
shape=[None, 1, 32, 100], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
|
||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||
input_spec=[
|
||||
input_spec = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 512], dtype="int64"), # input_ids
|
||||
paddle.static.InputSpec(
|
||||
|
|
|
@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
|
|||
# create predictor
|
||||
predictor = inference.create_predictor(config)
|
||||
input_names = predictor.get_input_names()
|
||||
if mode in ['ser','re']:
|
||||
if mode in ['ser', 're']:
|
||||
input_tensor = []
|
||||
for name in input_names:
|
||||
input_tensor.append(predictor.get_input_handle(name))
|
||||
|
|
|
@ -44,7 +44,7 @@ def to_tensor(data):
|
|||
from collections import defaultdict
|
||||
data_dict = defaultdict(list)
|
||||
to_tensor_idxs = []
|
||||
|
||||
|
||||
for idx, v in enumerate(data):
|
||||
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
||||
if idx not in to_tensor_idxs:
|
||||
|
@ -72,7 +72,10 @@ class SerPredictor(object):
|
|||
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False, use_gpu=global_config['use_gpu'])
|
||||
self.ocr_engine = PaddleOCR(
|
||||
use_angle_cls=False,
|
||||
show_log=False,
|
||||
use_gpu=global_config['use_gpu'])
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
|
@ -82,8 +85,8 @@ class SerPredictor(object):
|
|||
op[op_name]['ocr_engine'] = self.ocr_engine
|
||||
elif op_name == 'KeepKeys':
|
||||
op[op_name]['keep_keys'] = [
|
||||
'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels',
|
||||
'segment_offset_id', 'ocr_info',
|
||||
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
|
||||
'image', 'labels', 'segment_offset_id', 'ocr_info',
|
||||
'entities'
|
||||
]
|
||||
|
||||
|
@ -103,11 +106,9 @@ class SerPredictor(object):
|
|||
preds = self.model(batch)
|
||||
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
|
||||
preds = preds[0]
|
||||
|
||||
|
||||
post_result = self.post_process_class(
|
||||
preds,
|
||||
segment_offset_ids=batch[6],
|
||||
ocr_infos=batch[7])
|
||||
preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
|
||||
return post_result, batch
|
||||
|
||||
|
||||
|
@ -154,4 +155,3 @@ if __name__ == '__main__':
|
|||
|
||||
logger.info("process: [{}/{}], save result to {}".format(
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
|
||||
|
|
|
@ -192,6 +192,6 @@ if __name__ == '__main__':
|
|||
}, ensure_ascii=False) + "\n")
|
||||
img_res = draw_re_results(img_path, result)
|
||||
cv2.imwrite(save_img_path, img_res)
|
||||
|
||||
|
||||
logger.info("process: [{}/{}], save result to {}".format(
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
|
|
Loading…
Reference in New Issue