vqa code integrated into ppocr training system
parent 1ded2ac44a
commit a323fce66d

@@ -0,0 +1,122 @@
Global:
  use_gpu: True
  epoch_num: 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_layoutxlm/
  save_epoch_step: 2000
  # evaluation is run every 38 iterations after the 0th iteration
  eval_batch_step: [ 0, 38 ]
  cal_metric_during_train: False
  pretrained_model: &pretrained_model layoutxlm-base-uncased
  save_inference_dir:
  use_visualdl: False
  infer_img: ppstructure/vqa/images/input/zh_val_21.jpg
  save_res_path: ./output/re/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForRe
    pretrained_model: *pretrained_model
    checkpoints:

Loss:
  name: LossFromOutput
  key: loss
  reduction: mean

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    learning_rate: 0.00005
  regularizer:
    name: Const
    factor: 0.00000

PostProcess:
  name: VQAReTokenLayoutLMPostProcess

Metric:
  name: VQAReTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids', 'entities', 'relations' ] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids', 'entities', 'relations' ] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
    collate_fn: ListCollator

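The &name/*name pairs above are plain YAML anchors, so Global.pretrained_model and Architecture.Backbone.pretrained_model always stay in sync. A quick sanity check, not part of the commit (the file path is illustrative, PyYAML assumed):

import yaml

with open('configs/vqa/re_layoutxlm.yml') as f:  # hypothetical path to the config above
    cfg = yaml.safe_load(f)
assert (cfg['Architecture']['Backbone']['pretrained_model'] ==
        cfg['Global']['pretrained_model'] == 'layoutxlm-base-uncased')
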
@@ -0,0 +1,120 @@
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlm/
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  pretrained_model: &pretrained_model layoutlm-base-uncased
  save_inference_dir:
  use_visualdl: False
  infer_img: ppstructure/vqa/images/input/zh_val_0.jpg
  save_res_path: ./output/ser/predicts_layoutlm.txt

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLM"
  Transform:
  Backbone:
    name: LayoutLMForSer
    pretrained_model: *pretrained_model
    checkpoints:
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: Const
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ppstructure/vqa/labels/labels_ser.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids' ] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids' ] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

@@ -0,0 +1,121 @@
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutxlm/
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  pretrained_model: &pretrained_model layoutxlm-base-uncased
  save_inference_dir:
  use_visualdl: False
  infer_img: ppstructure/vqa/images/input/zh_val_42.jpg
  save_res_path: ./output/ser

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForSer
    pretrained_model: *pretrained_model
    checkpoints:
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: Const
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ppstructure/vqa/labels/labels_ser.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids' ] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids' ] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

@@ -86,13 +86,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
         shuffle=shuffle,
         drop_last=drop_last)

+    if 'collate_fn' in loader_config:
+        from . import collate_fn
+        collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
+    else:
+        collate_fn = None
     data_loader = DataLoader(
         dataset=dataset,
         batch_sampler=batch_sampler,
         places=device,
         num_workers=num_workers,
         return_list=True,
-        use_shared_memory=use_shared_memory)
+        use_shared_memory=use_shared_memory,
+        collate_fn=collate_fn)

     # support exit using ctrl+c
     signal.signal(signal.SIGINT, term_mp)

@@ -15,20 +15,19 @@
 import paddle
 import numbers
 import numpy as np
 from collections import defaultdict


-class DataCollator:
+class DictCollator(object):
     """
     data batch
     """

     def __call__(self, batch):
-        data_dict = {}
+        data_dict = defaultdict(list)
         to_tensor_keys = []
         for sample in batch:
             for k, v in sample.items():
-                if k not in data_dict:
-                    data_dict[k] = []
                 if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                     if k not in to_tensor_keys:
                         to_tensor_keys.append(k)

@@ -36,3 +35,22 @@ class DataCollator:
        for k in to_tensor_keys:
            data_dict[k] = paddle.to_tensor(data_dict[k])
        return data_dict


class ListCollator(object):
    """
    data batch
    """

    def __call__(self, batch):
        data_dict = defaultdict(list)
        to_tensor_idxs = []
        for sample in batch:
            for idx, v in enumerate(sample):
                if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                    if idx not in to_tensor_idxs:
                        to_tensor_idxs.append(idx)
                data_dict[idx].append(v)
        for idx in to_tensor_idxs:
            data_dict[idx] = paddle.to_tensor(data_dict[idx])
        return list(data_dict.values())

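A usage sketch for ListCollator (values are illustrative): positional fields are grouped across samples, and numeric fields are batched into tensors, which is what the RE configs above request via loader.collate_fn.

import numpy as np

batch = [
    [np.zeros((3, 224, 224), dtype='float32'), {'start': [0]}],
    [np.ones((3, 224, 224), dtype='float32'), {'start': [4]}],
]
collated = ListCollator()(batch)
# collated[0] is a paddle.Tensor of shape [2, 3, 224, 224];
# collated[1] stays a plain list of the two per-sample dicts.
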
@@ -34,6 +34,8 @@ from .sast_process import *
 from .pg_process import *
 from .gen_table_mask import *

+from .vqa import *
+

 def transform(data, ops=None):
     """ transform """

@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

+import copy
 import numpy as np
 import string
 from shapely.geometry import LineString, Point, Polygon

@@ -736,7 +737,7 @@ class TableLabelEncode(object):
                 % beg_or_end
         else:
             assert False, "Unsupport type %s in char_or_elem" \
-                          % char_or_elem
+                % char_or_elem
         return idx

@@ -782,3 +783,208 @@ class SARLabelEncode(BaseRecLabelEncode):

    def get_ignored_tokens(self):
        return [self.padding_idx]


class VQATokenLabelEncode(object):
    """
    Label encoding for NLP-style VQA tasks (SER and RE).
    """

    def __init__(self,
                 class_path,
                 contains_re=False,
                 add_special_ids=False,
                 algorithm='LayoutXLM',
                 infer_mode=False,
                 ocr_engine=None,
                 **kwargs):
        super(VQATokenLabelEncode, self).__init__()
        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
        from ppocr.utils.utility import load_vqa_bio_label_maps
        tokenizer_dict = {
            'LayoutXLM': {
                'class': LayoutXLMTokenizer,
                'pretrained_model': 'layoutxlm-base-uncased'
            },
            'LayoutLM': {
                'class': LayoutLMTokenizer,
                'pretrained_model': 'layoutlm-base-uncased'
            }
        }
        self.contains_re = contains_re
        tokenizer_config = tokenizer_dict[algorithm]
        self.tokenizer = tokenizer_config['class'].from_pretrained(
            tokenizer_config['pretrained_model'])
        self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
        self.add_special_ids = add_special_ids
        self.infer_mode = infer_mode
        self.ocr_engine = ocr_engine

    def __call__(self, data):
        if not self.infer_mode:
            return self._train(data)
        else:
            return self._infer(data)

    def _train(self, data):
        info = data['label']

        # read text info
        info_dict = json.loads(info)
        height = info_dict["height"]
        width = info_dict["width"]

        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        gt_label_list = []

        if self.contains_re:
            # for re
            entities = []
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()
        for info in info_dict["ocr_info"]:
            if self.contains_re:
                # for re
                if len(info["text"]) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                relations.extend([tuple(sorted(l)) for l in info["linking"]])

            # normalize bbox (x1, y1, x2, y2) to the 0-1000 grid
            bbox = info["bbox"]
            label = info["label"]
            bbox[0] = int(bbox[0] * 1000.0 / width)
            bbox[2] = int(bbox[2] * 1000.0 / width)
            bbox[1] = int(bbox[1] * 1000.0 / height)
            bbox[3] = int(bbox[3] * 1000.0 / height)

            text = info["text"]
            encode_res = self.tokenizer.encode(
                text, pad_to_max_seq_len=False, return_attention_mask=True)

            gt_label = []
            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
            if label.lower() == "other":
                gt_label.extend([0] * len(encode_res["input_ids"]))
            else:
                gt_label.append(self.label2id_map[("b-" + label).upper()])
                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                                (len(encode_res["input_ids"]) - 1))
            if self.contains_re:
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    entities.append({
                        "start": len(input_ids_list),
                        "end": len(input_ids_list) + len(encode_res["input_ids"]),
                        "label": label.upper(),
                    })
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
            gt_label_list.extend(gt_label)
            words_list.append(text)

        encoded_inputs = {
            "input_ids": input_ids_list,
            "labels": gt_label_list,
            "token_type_ids": token_type_ids_list,
            "bbox": bbox_list,
            "attention_mask": [1] * len(input_ids_list),
        }
        data.update(encoded_inputs)
        data['tokenizer_params'] = dict(
            padding_side=self.tokenizer.padding_side,
            pad_token_type_id=self.tokenizer.pad_token_type_id,
            pad_token_id=self.tokenizer.pad_token_id)

        if self.contains_re:
            data['entities'] = entities
            data['relations'] = relations
            data['id2label'] = id2label
            data['empty_entity'] = empty_entity
            data['entity_id_to_index_map'] = entity_id_to_index_map
        return data

    def _infer(self, data):
        def trans_poly_to_bbox(poly):
            x1 = np.min([p[0] for p in poly])
            x2 = np.max([p[0] for p in poly])
            y1 = np.min([p[1] for p in poly])
            y2 = np.max([p[1] for p in poly])
            return [x1, y1, x2, y2]

        height, width, _ = data['image'].shape
        ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
        ocr_info = []
        for res in ocr_result:
            ocr_info.append({
                "text": res[1][0],
                "bbox": trans_poly_to_bbox(res[0]),
                "poly": res[0],
            })

        segment_offset_id = []
        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        entities = []

        for info in ocr_info:
            # normalize bbox (x1, y1, x2, y2) to the 0-1000 grid
            bbox = copy.deepcopy(info["bbox"])
            bbox[0] = int(bbox[0] * 1000.0 / width)
            bbox[2] = int(bbox[2] * 1000.0 / width)
            bbox[1] = int(bbox[1] * 1000.0 / height)
            bbox[3] = int(bbox[3] * 1000.0 / height)

            text = info["text"]
            encode_res = self.tokenizer.encode(
                text, pad_to_max_seq_len=False, return_attention_mask=True)

            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

            # for re
            entities.append({
                "start": len(input_ids_list),
                "end": len(input_ids_list) + len(encode_res["input_ids"]),
                "label": "O",
            })

            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
            words_list.append(text)
            segment_offset_id.append(len(input_ids_list))

        encoded_inputs = {
            "input_ids": input_ids_list,
            "token_type_ids": token_type_ids_list,
            "bbox": bbox_list,
            "attention_mask": [1] * len(input_ids_list),
            "entities": entities,
            "labels": None,
            "segment_offset_id": segment_offset_id,
            "ocr_info": ocr_info,
        }
        data.update(encoded_inputs)
        return data

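The encoder above rescales pixel boxes onto the 0-1000 grid that the LayoutLM family expects; a quick check of that arithmetic with illustrative numbers:

width, height = 800, 600
bbox = [400, 300, 800, 600]  # x1, y1, x2, y2 in pixels
bbox = [int(bbox[0] * 1000.0 / width), int(bbox[1] * 1000.0 / height),
        int(bbox[2] * 1000.0 / width), int(bbox[3] * 1000.0 / height)]
assert bbox == [500, 500, 1000, 1000]
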
@@ -170,17 +170,19 @@ class Resize(object):

     def __call__(self, data):
         img = data['image']
-        text_polys = data['polys']
+        if 'polys' in data:
+            text_polys = data['polys']

         img_resize, [ratio_h, ratio_w] = self.resize_image(img)
-        new_boxes = []
-        for box in text_polys:
-            new_box = []
-            for cord in box:
-                new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
-            new_boxes.append(new_box)
+        if 'polys' in data:
+            new_boxes = []
+            for box in text_polys:
+                new_box = []
+                for cord in box:
+                    new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+                new_boxes.append(new_box)
+            data['polys'] = np.array(new_boxes, dtype=np.float32)
         data['image'] = img_resize
-        data['polys'] = np.array(new_boxes, dtype=np.float32)
         return data

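With the new 'polys' guard, Resize also accepts VQA samples that carry no text polygons. A minimal sketch (Resize(size=[224, 224]) mirrors the transform entries in the configs above):

import numpy as np

resizer = Resize(size=[224, 224])
data = {'image': np.zeros((600, 800, 3), dtype='uint8')}
out = resizer(data)  # no KeyError: the image is resized and no 'polys' key is added
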
@@ -0,0 +1,19 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation

__all__ = [
    'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
]

@@ -0,0 +1,17 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation

@@ -0,0 +1,117 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class VQASerTokenChunk(object):
    def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
        self.max_seq_len = max_seq_len
        self.infer_mode = infer_mode

    def __call__(self, data):
        encoded_inputs_all = []
        seq_len = len(data['input_ids'])
        for index in range(0, seq_len, self.max_seq_len):
            chunk_beg = index
            chunk_end = min(index + self.max_seq_len, seq_len)
            encoded_inputs_example = {}
            for key in data:
                if key in [
                        'label', 'input_ids', 'labels', 'token_type_ids',
                        'bbox', 'attention_mask'
                ]:
                    if self.infer_mode and key == 'labels':
                        encoded_inputs_example[key] = data[key]
                    else:
                        encoded_inputs_example[key] = data[key][chunk_beg:chunk_end]
                else:
                    encoded_inputs_example[key] = data[key]

            encoded_inputs_all.append(encoded_inputs_example)
        return encoded_inputs_all[0]


class VQAReTokenChunk(object):
    def __init__(self,
                 max_seq_len=512,
                 entities_labels=None,
                 infer_mode=False,
                 **kwargs):
        self.max_seq_len = max_seq_len
        self.entities_labels = {
            'HEADER': 0,
            'QUESTION': 1,
            'ANSWER': 2
        } if entities_labels is None else entities_labels
        self.infer_mode = infer_mode

    def __call__(self, data):
        # prepare data
        entities = data.pop('entities')
        relations = data.pop('relations')
        encoded_inputs_all = []
        for index in range(0, len(data["input_ids"]), self.max_seq_len):
            item = {}
            for key in data:
                if key in [
                        'label', 'input_ids', 'labels', 'token_type_ids',
                        'bbox', 'attention_mask'
                ]:
                    if self.infer_mode and key == 'labels':
                        item[key] = data[key]
                    else:
                        item[key] = data[key][index:index + self.max_seq_len]
                else:
                    item[key] = data[key]
            # select entities in the current chunk
            entities_in_this_span = []
            global_to_local_map = {}
            for entity_id, entity in enumerate(entities):
                if (index <= entity["start"] < index + self.max_seq_len and
                        index <= entity["end"] < index + self.max_seq_len):
                    entity["start"] = entity["start"] - index
                    entity["end"] = entity["end"] - index
                    global_to_local_map[entity_id] = len(entities_in_this_span)
                    entities_in_this_span.append(entity)

            # select relations in the current chunk
            relations_in_this_span = []
            for relation in relations:
                if (index <= relation["start_index"] < index + self.max_seq_len and
                        index <= relation["end_index"] < index + self.max_seq_len):
                    relations_in_this_span.append({
                        "head": global_to_local_map[relation["head"]],
                        "tail": global_to_local_map[relation["tail"]],
                        "start_index": relation["start_index"] - index,
                        "end_index": relation["end_index"] - index,
                    })
            item.update({
                "entities": self.reformat(entities_in_this_span),
                "relations": self.reformat(relations_in_this_span),
            })
            item['entities']['label'] = [
                self.entities_labels[x] for x in item['entities']['label']
            ]
            encoded_inputs_all.append(item)
        return encoded_inputs_all[0]

    def reformat(self, data):
        new_data = {}
        for item in data:
            for k, v in item.items():
                if k not in new_data:
                    new_data[k] = []
                new_data[k].append(v)
        return new_data

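A sketch of the chunking behavior with toy values: token-level fields are cut into max_seq_len windows, other fields are copied through, and note that, as written, only the first chunk is returned.

chunker = VQASerTokenChunk(max_seq_len=4)
data = {
    'input_ids': list(range(10)),
    'labels': [0] * 10,
    'image': 'copied through unchanged',
}
out = chunker(data)
# out['input_ids'] == [0, 1, 2, 3]; out['labels'] == [0, 0, 0, 0]
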
@@ -0,0 +1,101 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np


class VQATokenPad(object):
    def __init__(self,
                 max_seq_len=512,
                 pad_to_max_seq_len=True,
                 return_attention_mask=True,
                 return_token_type_ids=True,
                 truncation_strategy="longest_first",
                 return_overflowing_tokens=False,
                 return_special_tokens_mask=False,
                 infer_mode=False,
                 **kwargs):
        self.max_seq_len = max_seq_len
        self.pad_to_max_seq_len = pad_to_max_seq_len
        self.return_attention_mask = return_attention_mask
        self.return_token_type_ids = return_token_type_ids
        self.truncation_strategy = truncation_strategy
        self.return_overflowing_tokens = return_overflowing_tokens
        self.return_special_tokens_mask = return_special_tokens_mask
        # labels at padding positions are ignored by the loss
        self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
        self.infer_mode = infer_mode

    def __call__(self, data):
        needs_to_be_padded = self.pad_to_max_seq_len and \
            len(data["input_ids"]) < self.max_seq_len

        if needs_to_be_padded:
            if 'tokenizer_params' in data:
                tokenizer_params = data.pop('tokenizer_params')
            else:
                tokenizer_params = dict(
                    padding_side='right', pad_token_type_id=0, pad_token_id=1)

            difference = self.max_seq_len - len(data["input_ids"])
            if tokenizer_params['padding_side'] == 'right':
                if self.return_attention_mask:
                    data["attention_mask"] = [1] * len(data[
                        "input_ids"]) + [0] * difference
                if self.return_token_type_ids:
                    data["token_type_ids"] = (
                        data["token_type_ids"] +
                        [tokenizer_params['pad_token_type_id']] * difference)
                if self.return_special_tokens_mask:
                    data["special_tokens_mask"] = data[
                        "special_tokens_mask"] + [1] * difference
                data["input_ids"] = data["input_ids"] + [
                    tokenizer_params['pad_token_id']
                ] * difference
                if not self.infer_mode:
                    data["labels"] = data[
                        "labels"] + [self.pad_token_label_id] * difference
                data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
            elif tokenizer_params['padding_side'] == 'left':
                if self.return_attention_mask:
                    data["attention_mask"] = [0] * difference + [
                        1
                    ] * len(data["input_ids"])
                if self.return_token_type_ids:
                    data["token_type_ids"] = (
                        [tokenizer_params['pad_token_type_id']] * difference +
                        data["token_type_ids"])
                if self.return_special_tokens_mask:
                    data["special_tokens_mask"] = [
                        1
                    ] * difference + data["special_tokens_mask"]
                data["input_ids"] = [tokenizer_params['pad_token_id']
                                     ] * difference + data["input_ids"]
                if not self.infer_mode:
                    data["labels"] = [self.pad_token_label_id
                                      ] * difference + data["labels"]
                data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
        else:
            if self.return_attention_mask:
                data["attention_mask"] = [1] * len(data["input_ids"])

        for key in data:
            if key in [
                    'input_ids', 'labels', 'token_type_ids', 'bbox',
                    'attention_mask'
            ]:
                if self.infer_mode and key == 'labels':
                    continue
                length = min(len(data[key]), self.max_seq_len)
                data[key] = np.array(data[key][:length], dtype='int64')
        return data

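A sketch of right-padding a three-token sample to max_seq_len=6, using the defaults above (pad_token_id=1 when no tokenizer_params are attached; labels are padded with CrossEntropyLoss's ignore_index, -100):

padder = VQATokenPad(max_seq_len=6)
data = {
    'input_ids': [5, 6, 7],
    'token_type_ids': [0, 0, 0],
    'bbox': [[1, 2, 3, 4]] * 3,
    'attention_mask': [1, 1, 1],
    'labels': [2, 3, 3],
}
out = padder(data)
# out['input_ids']      -> array([5, 6, 7, 1, 1, 1])
# out['attention_mask'] -> array([1, 1, 1, 0, 0, 0])
# out['labels']         -> array([2, 3, 3, -100, -100, -100])
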
@@ -0,0 +1,67 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class VQAReTokenRelation(object):
    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        """
        build relations
        """
        entities = data['entities']
        relations = data['relations']
        id2label = data.pop('id2label')
        empty_entity = data.pop('empty_entity')
        entity_id_to_index_map = data.pop('entity_id_to_index_map')

        relations = list(set(relations))
        relations = [
            rel for rel in relations
            if rel[0] not in empty_entity and rel[1] not in empty_entity
        ]
        kv_relations = []
        for rel in relations:
            pair = [id2label[rel[0]], id2label[rel[1]]]
            if pair == ["question", "answer"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[0]],
                    "tail": entity_id_to_index_map[rel[1]]
                })
            elif pair == ["answer", "question"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[1]],
                    "tail": entity_id_to_index_map[rel[0]]
                })
            else:
                continue
        relations = sorted(
            [{
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": self.get_relation_span(rel, entities)[0],
                "end_index": self.get_relation_span(rel, entities)[1],
            } for rel in kv_relations],
            key=lambda x: x["head"], )

        data['relations'] = relations
        return data

    def get_relation_span(self, rel, entities):
        bound = []
        for entity_index in [rel["head"], rel["tail"]]:
            bound.append(entities[entity_index]["start"])
            bound.append(entities[entity_index]["end"])
        return min(bound), max(bound)

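A toy example of the relation building above: one "question" entity linked to one "answer" entity yields a single head-to-tail relation whose span covers both entities' tokens.

data = {
    'entities': [{'start': 0, 'end': 4}, {'start': 4, 'end': 9}],
    'relations': [(0, 1)],
    'id2label': {0: 'question', 1: 'answer'},
    'empty_entity': set(),
    'entity_id_to_index_map': {0: 0, 1: 1},
}
out = VQAReTokenRelation()(data)
# out['relations'] == [{'head': 0, 'tail': 1, 'start_index': 0, 'end_index': 9}]
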
@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
         ) == data_source_num, "The length of ratio_list should be the same as the file_list."
         self.data_dir = dataset_config['data_dir']
         self.do_shuffle = loader_config['shuffle']
-
         self.seed = seed
         logger.info("Initialize indexs of datasets:%s" % label_file_list)
         self.data_lines = self.get_image_info_list(label_file_list, ratio_list)

@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
             self.shuffle_data_random()
         self.ops = create_operators(dataset_config['transforms'], global_config)

+        self.need_reset = True in [x < 1 for x in ratio_list]
+
     def get_image_info_list(self, file_list, ratio_list):
         if isinstance(file_list, str):
             file_list = [file_list]

@@ -95,7 +96,7 @@ class SimpleDataSet(Dataset):
             data['image'] = img
             data = transform(data, load_data_ops)

-            if data is None or data['polys'].shape[1]!=4:
+            if data is None or data['polys'].shape[1] != 4:
                 continue
             ext_data.append(data)
         return ext_data

@@ -121,7 +122,7 @@
             self.logger.error(
                 "When parsing line {}, error happened with msg: {}".format(
                     data_line, traceback.format_exc()))
-            outs = None
+            # outs = None
         if outs is None:
             # during evaluation, fix the idx so repeated evaluations give the same results
             rnd_idx = np.random.randint(self.__len__(

@@ -16,6 +16,9 @@ import copy
 import paddle
 import paddle.nn as nn

+# basic_loss
+from .basic_loss import LossFromOutput
+
 # det loss
 from .det_db_loss import DBLoss
 from .det_east_loss import EASTLoss

@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
 # table loss
 from .table_att_loss import TableAttentionLoss

+# vqa token loss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
+

 def build_loss(config):
     support_dict = [
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
         'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
-        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
+        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')

@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):

    def forward(self, x, y):
        return self.loss_func(x, y)


class LossFromOutput(nn.Layer):
    def __init__(self, key='loss', reduction='none'):
        super().__init__()
        self.key = key
        self.reduction = reduction

    def forward(self, predicts, batch):
        loss = predicts[self.key]
        if self.reduction == 'mean':
            loss = paddle.mean(loss)
        elif self.reduction == 'sum':
            loss = paddle.sum(loss)
        return {'loss': loss}

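LossFromOutput exists because the RE backbone computes its loss internally and returns it in the prediction dict; the trainer only extracts and reduces it, exactly as the RE config above wires up (key: loss, reduction: mean). A sketch:

import paddle

loss_fn = LossFromOutput(key='loss', reduction='mean')
predicts = {'loss': paddle.to_tensor([0.5, 1.5])}  # stand-in for the model output
print(loss_fn(predicts, batch=None))  # {'loss': Tensor(1.0)}
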
@@ -1,10 +1,10 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,

@@ -12,24 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 from paddle import nn


-class SERLoss(nn.Layer):
+class VQASerTokenLayoutLMLoss(nn.Layer):
     def __init__(self, num_classes):
         super().__init__()
         self.loss_class = nn.CrossEntropyLoss()
         self.num_classes = num_classes
         self.ignore_index = self.loss_class.ignore_index

-    def forward(self, labels, outputs, attention_mask):
+    def forward(self, predicts, batch):
+        labels = batch[1]
+        attention_mask = batch[4]
         if attention_mask is not None:
             active_loss = attention_mask.reshape([-1, ]) == 1
-            active_outputs = outputs.reshape(
+            active_outputs = predicts.reshape(
                 [-1, self.num_classes])[active_loss]
             active_labels = labels.reshape([-1, ])[active_loss]
             loss = self.loss_class(active_outputs, active_labels)
         else:
             loss = self.loss_class(
-                outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ]))
-        return loss
+                predicts.reshape([-1, self.num_classes]),
+                labels.reshape([-1, ]))
+        return {'loss': loss}

@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric
 from .table_metric import TableMetric
 from .kie_metric import KIEMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric


 def build_metric(config):
     support_dict = [
         "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
-        "DistillationMetric", "TableMetric", 'KIEMetric'
+        "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
+        'VQAReTokenMetric'
     ]

     config = copy.deepcopy(config)

@@ -0,0 +1,177 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
from seqeval.metrics import f1_score, precision_score, recall_score

__all__ = ['VQAReTokenMetric']


class VQAReTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        gt_relations = []
        for b in range(len(self.relations_list)):
            rel_sent = []
            for head, tail in zip(self.relations_list[b]["head"],
                                  self.relations_list[b]["tail"]):
                rel = {}
                rel["head_id"] = head
                rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
                               self.entities_list[b]["end"][rel["head_id"]])
                rel["head_type"] = self.entities_list[b]["label"][rel["head_id"]]

                rel["tail_id"] = tail
                rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
                               self.entities_list[b]["end"][rel["tail_id"]])
                rel["tail_type"] = self.entities_list[b]["label"][rel["tail_id"]]

                rel["type"] = 1
                rel_sent.append(rel)
            gt_relations.append(rel_sent)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics

    def reset(self):
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions.

        Args:
            pred_relations (list): list of lists of predicted relations
                (several relations per sentence)
            gt_relations (list): list of lists of ground-truth relations,
                where each relation is a dict:

                rel = {"head": (start_idx (inclusive), end_idx (exclusive)),
                       "tail": (start_idx (inclusive), end_idx (exclusive)),
                       "head_type": ent_type,
                       "tail_type": ent_type,
                       "type": rel_type}

            mode (str): 'strict' or 'boundaries'
        """
        assert mode in ["strict", "boundaries"]

        relation_types = [v for v in [0, 1] if not v == 0]
        scores = {
            rel: {
                "tp": 0,
                "fp": 0,
                "fn": 0
            }
            for rel in relation_types + ["ALL"]
        }

        # Count GT relations and predicted relations
        n_sents = len(gt_relations)
        n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
        n_found = sum([len([rel for rel in sent]) for sent in pred_relations])

        # Count TP, FP and FN per type
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                # strict mode takes argument types into account
                if mode == "strict":
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}

                # boundaries mode only takes argument spans into account
                elif mode == "boundaries":
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Compute per-type precision / recall / F1
        for rel_type in scores.keys():
            if scores[rel_type]["tp"]:
                scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fp"] + scores[rel_type]["tp"])
                scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fn"] + scores[rel_type]["tp"])
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
                scores[rel_type]["f1"] = (
                    2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                    (scores[rel_type]["p"] + scores[rel_type]["r"]))
            else:
                scores[rel_type]["f1"] = 0

        # Compute micro F1 scores
        tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
        fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
        fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0

        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Compute macro F1 scores
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[ent_type]["f1"] for ent_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[ent_type]["p"] for ent_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[ent_type]["r"] for ent_type in relation_types])

        return scores

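A worked instance of the micro scores above: with 3 ground-truth relations, 2 predictions, and 1 overlap, tp=1, fp=1, fn=2.

tp, fp, fn = 1, 1, 2
precision = tp / (tp + fp)                          # 0.5
recall = tp / (tp + fn)                             # 0.3333...
f1 = 2 * precision * recall / (precision + recall)  # 0.4, reported as 'hmean'
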
@@ -0,0 +1,47 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
from seqeval.metrics import f1_score, precision_score, recall_score

__all__ = ['VQASerTokenMetric']


class VQASerTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        preds, labels = preds
        self.pred_list.extend(preds)
        self.gt_list.extend(labels)

    def get_metric(self):
        metrics = {
            "precision": precision_score(self.gt_list, self.pred_list),
            "recall": recall_score(self.gt_list, self.pred_list),
            "hmean": f1_score(self.gt_list, self.pred_list),
        }
        self.reset()
        return metrics

    def reset(self):
        self.pred_list = []
        self.gt_list = []

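The SER metric delegates to seqeval's entity-level scoring over BIO tag sequences; a tiny sanity check of that dependency:

from seqeval.metrics import f1_score

gt = [['B-QUESTION', 'I-QUESTION', 'O']]
pred = [['B-QUESTION', 'I-QUESTION', 'O']]
assert f1_score(gt, pred) == 1.0
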
@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
             in_channels = self.neck.out_channels

         # build head; the head is needed for det, rec and cls
-        config["Head"]['in_channels'] = in_channels
-        self.head = build_head(config["Head"])
+        if 'Head' not in config or config['Head'] is None:
+            self.use_head = False
+        else:
+            self.use_head = True
+            config["Head"]['in_channels'] = in_channels
+            self.head = build_head(config["Head"])

         self.return_all_feats = config.get("return_all_feats", False)

@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
         if self.use_neck:
             x = self.neck(x)
         y["neck_out"] = x
-        x = self.head(x, targets=data)
+        if self.use_head:
+            x = self.head(x, targets=data)
         if isinstance(x, dict):
             y.update(x)
         else:

@@ -43,6 +43,9 @@ def build_backbone(config, model_type):
         from .table_resnet_vd import ResNet
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
+    elif model_type == 'vqa':
+        from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
+        support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
     else:
         raise NotImplementedError

@@ -0,0 +1,123 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from paddle import nn

from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification

__all__ = ["LayoutXLMForSer", "LayoutLMForSer", "LayoutXLMForRe"]


class NLPBaseModel(nn.Layer):
    def __init__(self,
                 base_model_class,
                 model_class,
                 type='ser',
                 pretrained_model=None,
                 checkpoints=None,
                 **kwargs):
        super(NLPBaseModel, self).__init__()
        assert pretrained_model is not None or checkpoints is not None, \
            "one of pretrained_model and checkpoints must not be None"
        if checkpoints is not None:
            self.model = model_class.from_pretrained(checkpoints)
        else:
            base_model = base_model_class.from_pretrained(pretrained_model)
            if type == 'ser':
                self.model = model_class(
                    base_model, num_classes=kwargs['num_classes'], dropout=None)
            else:
                self.model = model_class(base_model, dropout=None)
        self.out_channels = 1


class LayoutXLMForSer(NLPBaseModel):
    def __init__(self,
                 num_classes,
                 pretrained_model='layoutxlm-base-uncased',
                 checkpoints=None,
                 **kwargs):
        super(LayoutXLMForSer, self).__init__(
            LayoutXLMModel,
            LayoutXLMForTokenClassification,
            'ser',
            pretrained_model,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[2],
            image=x[3],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            head_mask=None,
            labels=None)
        return x[0]


class LayoutLMForSer(NLPBaseModel):
    def __init__(self,
                 num_classes,
                 pretrained_model='layoutlm-base-uncased',
                 checkpoints=None,
                 **kwargs):
        super(LayoutLMForSer, self).__init__(
            LayoutLMModel,
            LayoutLMForTokenClassification,
            'ser',
            pretrained_model,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[2],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            output_hidden_states=False)
        return x


class LayoutXLMForRe(NLPBaseModel):
    def __init__(self,
                 pretrained_model='layoutxlm-base-uncased',
                 checkpoints=None,
                 **kwargs):
        super(LayoutXLMForRe, self).__init__(
            LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
            pretrained_model, checkpoints)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            labels=None,
            image=x[2],
            attention_mask=x[3],
            token_type_ids=x[4],
            position_ids=None,
            head_mask=None,
            entities=x[5],
            relations=x[6])
        return x

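The forward indexing above lines up with the KeepKeys order in the SER configs (input_ids, labels, bbox, image, attention_mask, token_type_ids), which is why x[1], the labels, is skipped. A sketch of a call with dummy tensors (shapes are assumptions, and constructing the model downloads the pretrained weights):

import paddle

model = LayoutXLMForSer(num_classes=7)
seq = 512
batch = [
    paddle.zeros([1, seq], dtype='int64'),            # input_ids
    paddle.zeros([1, seq], dtype='int64'),            # labels (unused in forward)
    paddle.zeros([1, seq, 4], dtype='int64'),         # bbox on the 0-1000 grid
    paddle.zeros([1, 3, 224, 224], dtype='float32'),  # image
    paddle.ones([1, seq], dtype='int64'),             # attention_mask
    paddle.zeros([1, seq], dtype='int64'),            # token_type_ids
]
logits = model(batch)  # token-level logits from the classification head
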
@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
-        reg_name = reg_config.pop('name') + 'Decay'
+        reg_name = reg_config.pop('name')
+        if not hasattr(regularizer, reg_name):
+            reg_name += 'Decay'
         reg = getattr(regularizer, reg_name)(**reg_config)()
     else:
         reg = None

@@ -158,3 +158,38 @@ class Adadelta(object):
            name=self.name,
            parameters=parameters)
        return opt


class AdamW(object):
    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-08,
                 weight_decay=0.01,
                 grad_clip=None,
                 name=None,
                 lazy_mode=False,
                 **kwargs):
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.weight_decay = 0.01 if weight_decay is None else weight_decay
        self.grad_clip = grad_clip
        self.name = name
        self.lazy_mode = lazy_mode

    def __call__(self, parameters):
        opt = optim.AdamW(
            learning_rate=self.learning_rate,
            beta1=self.beta1,
            beta2=self.beta2,
            epsilon=self.epsilon,
            weight_decay=self.weight_decay,
            grad_clip=self.grad_clip,
            name=self.name,
            lazy_mode=self.lazy_mode,
            parameters=parameters)
        return opt

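A minimal sketch of how the wrapper is used (mirroring the Optimizer sections of the configs above, where the lr schedule and regularizer are resolved separately by build_optimizer):

import paddle

layer = paddle.nn.Linear(4, 2)
optimizer = AdamW(learning_rate=0.00005, beta1=0.9, beta2=0.999,
                  weight_decay=0.0)(parameters=layer.parameters())
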
@@ -50,3 +50,18 @@ class L2Decay(object):

    def __call__(self):
        reg = paddle.regularizer.L2Decay(self.regularization_coeff)
        return reg


class ConstDecay(object):
    """
    Constant weight decay regularization: returns a fixed coefficient
    instead of a paddle.regularizer object.
    Args:
        factor(float): regularization coeff. Default: 0.0.
    """

    def __init__(self, factor=0.0):
        super(ConstDecay, self).__init__()
        self.regularization_coeff = factor

    def __call__(self):
        return self.regularization_coeff

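Together with the hasattr fallback added to build_optimizer above, a config entry of name: Const with factor: 0.00000 resolves to ConstDecay and yields a constant coefficient of 0.0, which is then presumably handed to the optimizer as its weight_decay:

reg = ConstDecay(factor=0.0)()
assert reg == 0.0
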
@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
     TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess


 def build_post_process(config, global_config=None):

@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
         'DistillationCTCLabelDecode', 'TableLabelDecode',
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
-        'SEEDLabelDecode'
+        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
+        'VQAReTokenLayoutLMPostProcess'
     ]

     if config['name'] == 'PSEPostProcess':

@@ -0,0 +1,51 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


class VQAReTokenLayoutLMPostProcess(object):
    """ Convert the model's relation predictions into OCR-info pairs """

    def __init__(self, **kwargs):
        super(VQAReTokenLayoutLMPostProcess, self).__init__()

    def __call__(self, preds, label=None, *args, **kwargs):
        if label is not None:
            return self._metric(preds, label)
        else:
            return self._infer(preds, *args, **kwargs)

    def _metric(self, preds, label):
        return preds['pred_relations'], label[6], label[5]

    def _infer(self, preds, *args, **kwargs):
        ser_results = kwargs['ser_results']
        entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
        pred_relations = preds['pred_relations']

        # map the predicted relations back to the OCR info
        results = []
        for pred_relation, ser_result, entity_idx_dict in zip(
                pred_relations, ser_results, entity_idx_dict_batch):
            result = []
            used_tail_id = []
            for relation in pred_relation:
                if relation['tail_id'] in used_tail_id:
                    continue
                used_tail_id.append(relation['tail_id'])
                ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
                ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
                result.append((ocr_info_head, ocr_info_tail))
            results.append(result)
        return results

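For reference, a sketch of the structure `_infer` returns, assuming each `ser_result` entry is an OCR-info dict of the shape used by the visualization code later in this commit (the concrete field values are hypothetical):

```python
# one list per image; each element pairs a head entity with a tail entity
results = [
    [
        ({"text": "Name:", "bbox": [60, 80, 140, 110], "pred": "QUESTION"},
         {"text": "Alice", "bbox": [150, 80, 220, 110], "pred": "ANSWER"}),
        # ... more (head, tail) pairs for the same image
    ],
    # ... more images in the batch
]
```
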
@@ -0,0 +1,93 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ppocr.utils.utility import load_vqa_bio_label_maps


class VQASerTokenLayoutLMPostProcess(object):
    """ Decode token-classification outputs into SER labels """

    def __init__(self, class_path, **kwargs):
        super(VQASerTokenLayoutLMPostProcess, self).__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        # for drawing, fold each I- label into its B- counterpart
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

        # for display, strip the B-/I- prefix
        self.id2label_map_for_show = dict()
        for key in self.label2id_map_for_draw:
            val = self.label2id_map_for_draw[key]
            if key == "O":
                self.id2label_map_for_show[val] = key
            elif key.startswith("B-") or key.startswith("I-"):
                self.id2label_map_for_show[val] = key[2:]
            else:
                self.id2label_map_for_show[val] = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        if batch is not None:
            return self._metric(preds, batch[1])
        else:
            return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        pred_idxs = preds.argmax(axis=2)
        decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
        label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]

        # tokens labeled -100 are padding/ignored positions and are skipped
        for i in range(pred_idxs.shape[0]):
            for j in range(pred_idxs.shape[1]):
                if label[i, j] != -100:
                    label_decode_out_list[i].append(
                        self.id2label_map[label[i, j]])
                    decode_out_list[i].append(
                        self.id2label_map[pred_idxs[i, j]])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
        results = []

        for pred, attention_mask, segment_offset_id, ocr_info in zip(
                preds, attention_masks, segment_offset_ids, ocr_infos):
            pred = np.argmax(pred, axis=1)
            pred = [self.id2label_map[idx] for idx in pred]

            for idx in range(len(segment_offset_id)):
                if idx == 0:
                    start_id = 0
                else:
                    start_id = segment_offset_id[idx - 1]

                end_id = segment_offset_id[idx]

                curr_pred = pred[start_id:end_id]
                curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]

                # majority vote over the tokens of one text segment
                if len(curr_pred) <= 0:
                    pred_id = 0
                else:
                    counts = np.bincount(curr_pred)
                    pred_id = np.argmax(counts)
                ocr_info[idx]["pred_id"] = int(pred_id)
                ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
            results.append(ocr_info)
        return results

@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
             raise OSError('Failed to mkdir {}'.format(path))


-def load_model(config, model, optimizer=None):
+def load_model(config, model, optimizer=None, model_type='det'):
     """
     load model from checkpoint or pretrained_model
     """

@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
+
+    if model_type == 'vqa':
+        checkpoints = config['Architecture']['Backbone']['checkpoints']
+        # load the metric saved with the vqa method
+        if checkpoints:
+            if os.path.exists(os.path.join(checkpoints, 'metric.states')):
+                with open(os.path.join(checkpoints, 'metric.states'),
+                          'rb') as f:
+                    states_dict = pickle.load(f) if six.PY2 else pickle.load(
+                        f, encoding='latin1')
+                best_model_dict = states_dict.get('best_model_dict', {})
+                if 'epoch' in states_dict:
+                    best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+            logger.info("resume from {}".format(checkpoints))
+
+            if optimizer is not None:
+                if checkpoints[-1] in ['/', '\\']:
+                    checkpoints = checkpoints[:-1]
+                if os.path.exists(checkpoints + '.pdopt'):
+                    optim_dict = paddle.load(checkpoints + '.pdopt')
+                    optimizer.set_state_dict(optim_dict)
+                else:
+                    logger.warning(
+                        "{}.pdopt does not exist, optimizer params are not loaded".
+                        format(checkpoints))
+        return best_model_dict
+
     if checkpoints:
         if checkpoints.endswith('.pdparams'):
             checkpoints = checkpoints.replace('.pdparams', '')

@@ -127,6 +154,7 @@ def save_model(model,
                optimizer,
                model_path,
                logger,
+               config,
                is_best=False,
                prefix='ppocr',
                **kwargs):

@@ -135,13 +163,20 @@ def save_model(model,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(model.state_dict(), model_prefix + '.pdparams')
     paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

+    if config['Architecture']["model_type"] != 'vqa':
+        paddle.save(model.state_dict(), model_prefix + '.pdparams')
+        metric_prefix = model_prefix
+    else:
+        if config['Global']['distributed']:
+            model._layers.backbone.model.save_pretrained(model_prefix)
+        else:
+            model.backbone.model.save_pretrained(model_prefix)
+        metric_prefix = os.path.join(model_prefix, 'metric')
     # save metric and config
-    with open(model_prefix + '.states', 'wb') as f:
-        pickle.dump(kwargs, f, protocol=2)
     if is_best:
+        with open(metric_prefix + '.states', 'wb') as f:
+            pickle.dump(kwargs, f, protocol=2)
         logger.info('save best model to {}'.format(model_prefix))
     else:
         logger.info("save model in {}".format(model_prefix))

@@ -77,4 +77,22 @@ def check_and_read_gif(img_path):
             frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
         imgvalue = frame[:, :, ::-1]
         return imgvalue, True
     return None, False
+
+
+def load_vqa_bio_label_maps(label_map_path):
+    with open(label_map_path, "r", encoding='utf-8') as fin:
+        lines = fin.readlines()
+    lines = [line.strip() for line in lines]
+    if "O" not in lines:
+        lines.insert(0, "O")
+    labels = []
+    for line in lines:
+        if line == "O":
+            labels.append("O")
+        else:
+            labels.append("B-" + line)
+            labels.append("I-" + line)
+    label2id_map = {label: idx for idx, label in enumerate(labels)}
+    id2label_map = {idx: label for idx, label in enumerate(labels)}
+    return label2id_map, id2label_map

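To make the BIO mapping concrete, a quick sketch of what `load_vqa_bio_label_maps` produces for a hypothetical labels file containing the two lines `QUESTION` and `ANSWER`:

```python
# labels file:            resulting BIO label list:
#   QUESTION                ["O", "B-QUESTION", "I-QUESTION",
#   ANSWER                   "B-ANSWER", "I-ANSWER"]
label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2,
                "B-ANSWER": 3, "I-ANSWER": 4}
id2label_map = {v: k for k, v in label2id_map.items()}
```
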
@@ -0,0 +1,98 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont


def draw_ser_results(image,
                     ocr_results,
                     font_path="doc/fonts/simfang.ttf",
                     font_size=18):
    np.random.seed(2021)
    color = (np.random.permutation(range(255)),
             np.random.permutation(range(255)),
             np.random.permutation(range(255)))
    color_map = {
        idx: (color[0][idx], color[1][idx], color[2][idx])
        for idx in range(1, 255)
    }
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str) and os.path.isfile(image):
        image = Image.open(image).convert('RGB')
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    for ocr_info in ocr_results:
        if ocr_info["pred_id"] not in color_map:
            continue
        color = color_map[ocr_info["pred_id"]]
        text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])

        draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


def draw_box_txt(bbox, text, draw, font, font_size, color):
    # draw ocr results outline
    bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
    draw.rectangle(bbox, fill=color)

    # draw ocr results
    start_y = max(0, bbox[0][1] - font_size)
    tw = font.getsize(text)[0]
    draw.rectangle(
        [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
        fill=(0, 0, 255))
    draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)


def draw_re_results(image,
                    result,
                    font_path="doc/fonts/simfang.ttf",
                    font_size=18):
    np.random.seed(0)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str) and os.path.isfile(image):
        image = Image.open(image).convert('RGB')
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    for ocr_info_head, ocr_info_tail in result:
        draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
                     font_size, color_head)
        draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
                     font_size, color_tail)

        center_head = (
            (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
            (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
        center_tail = (
            (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
            (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)

        draw.line([center_head, center_tail], fill=color_line, width=5)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)

@@ -24,8 +24,8 @@

 |Model name|Description|Model size|Download|
 | --- | --- | --- | --- |
-|PP-Layout_v1.0_ser_pretrained|SER model trained on the XFUND Chinese dataset with LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
-|PP-Layout_v1.0_re_pretrained|RE model trained on the XFUND Chinese dataset with LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+|PP-Layout_v1.0_ser_pretrained|SER model trained on the XFUND Chinese dataset with LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|PP-Layout_v1.0_re_pretrained|RE model trained on the XFUND Chinese dataset with LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |

 ## 3. KIE models

@@ -20,11 +20,11 @@ The DOC-VQA algorithms in PP-Structure are built on the PaddleNLP natural language processing

 We evaluated the algorithms on the Chinese subset of [XFUN](https://github.com/doc-analysis/XFUND); performance is as follows

-| Model | Task | f1 | Download |
+| Model | Task | hmean | Download |
 |:---:|:---:|:---:| :---:|
-| LayoutXLM | RE | 0.7113 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
-| LayoutXLM | SER | 0.9056 | [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
-| LayoutLM | SER | 0.78 | [link](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) |
+| LayoutXLM | RE | 0.7483 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+| LayoutXLM | SER | 0.9038 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+| LayoutLM | SER | 0.7731 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |

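For reference, the `hmean` reported here is the harmonic mean of per-entity precision and recall, i.e. the same quantity the removed scripts reported as `f1`:

$$\mathrm{hmean} = \frac{2 \cdot \mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}$$
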
@@ -65,10 +65,10 @@ The DOC-VQA algorithms in PP-Structure are built on the PaddleNLP natural language processing
 pip3 install --upgrade pip

 # GPU installation
-python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple

 # CPU installation
-python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple

 ```
 For other requirements, please follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).

@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
 - **(1) Quickly install the PaddleOCR whl package via pip (prediction only)**

 ```bash
-pip install paddleocr
+python3 -m pip install paddleocr
 ```

 - **(2) Download the VQA source code (prediction + training)**

@@ -93,18 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
 # Note: the code hosted on Gitee may not sync with this GitHub project in real time; there is a delay of 3-5 days. Please use the recommended source first.
 ```

-- **(3) Install PaddleNLP**
+- **(3) Install the VQA `requirements`**

 ```bash
-pip3 install "paddlenlp>=2.2.1"
-```
-
-
-- **(4) Install the VQA `requirements`**
-
-```bash
-cd ppstructure/vqa
-pip install -r requirements.txt
+python3 -m pip install -r ppstructure/vqa/requirements.txt
 ```

 ## 4. Usage

@@ -112,6 +104,10 @@ pip install -r requirements.txt

 ### 4.1 Data and pretrained model preparation

+If you just want to try the prediction process, you can download the pretrained models we provide, skip training, and run prediction directly.
+
+* Download the processed dataset
+
 The processed XFUN Chinese dataset can be downloaded from: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).

@@ -121,98 +117,62 @@ pip install -r requirements.txt
 wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
 ```

-If you want to convert XFUN datasets in other languages, refer to the [XFUN data conversion script](helper/trans_xfun_data.py).
+* Convert the dataset

-If you just want to try the prediction process, you can download the pretrained models we provide, skip training, and run prediction directly.
+To train on other XFUN datasets, convert them with the following command

 ```bash
 python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
 ```

 ### 4.2 SER task

-* Start training
+Before launching training, modify the following four fields (a YAML sketch follows this list)
+
+1. `Train.dataset.data_dir`: directory containing the training images
+2. `Train.dataset.label_file_list`: path to the training label file
+3. `Eval.dataset.data_dir`: directory containing the validation images
+4. `Eval.dataset.label_file_list`: path to the validation label file
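A minimal sketch of those fields as they would appear in `configs/vqa/ser/layoutxlm.yml`, using the XFUND paths from the commands in this document as placeholders:

```yaml
Train:
  dataset:
    data_dir: XFUND/zh_train/image                    # training images
    label_file_list:
      - XFUND/zh_train/xfun_normalize_train.json      # training labels
Eval:
  dataset:
    data_dir: XFUND/zh_val/image                      # validation images
    label_file_list:
      - XFUND/zh_val/xfun_normalize_val.json          # validation labels
```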

+* Start training

 ```shell
-python3.7 train_ser.py \
-    --model_name_or_path "layoutxlm-base-uncased" \
-    --ser_model_type "LayoutXLM" \
-    --train_data_dir "XFUND/zh_train/image" \
-    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --num_train_epochs 200 \
-    --eval_steps 10 \
-    --output_dir "./output/ser/" \
-    --learning_rate 5e-5 \
-    --warmup_steps 50 \
-    --evaluate_during_training \
-    --seed 2048
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
 ```

-Finally, metrics such as `precision`, `recall`, and `f1` are printed, and the model and training log are saved in the `./output/ser/` folder.
+Finally, metrics such as `precision`, `recall`, and `hmean` are printed.
+The training log, the best model, and the model of the latest epoch are saved in the `./output/ser_layoutxlm/` folder.

 * Resume training

 To resume training, assign the folder of the previously trained model to the `Architecture.Backbone.checkpoints` field.

 ```shell
-python3.7 train_ser.py \
-    --model_name_or_path "model_path" \
-    --ser_model_type "LayoutXLM" \
-    --train_data_dir "XFUND/zh_train/image" \
-    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --num_train_epochs 200 \
-    --eval_steps 10 \
-    --output_dir "./output/ser/" \
-    --learning_rate 5e-5 \
-    --warmup_steps 50 \
-    --evaluate_during_training \
-    --num_workers 8 \
-    --seed 2048 \
-    --resume
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
 ```

 * Evaluation

-```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 eval_ser.py \
-    --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
-    --ser_model_type "LayoutXLM" \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --per_gpu_eval_batch_size 8 \
-    --num_workers 8 \
-    --output_dir "output/ser/" \
-    --seed 2048
-```
-Finally, metrics such as `precision`, `recall`, and `f1` are printed

-* Predict using the OCR results provided with the evaluation set
+For evaluation, assign the folder of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.

 ```shell
-export CUDA_VISIBLE_DEVICES=0
-python3.7 infer_ser.py \
-    --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
-    --ser_model_type "LayoutXLM" \
-    --output_dir "output/ser/" \
-    --infer_imgs "XFUND/zh_val/image/" \
-    --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
+CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
 ```
+Finally, metrics such as `precision`, `recall`, and `hmean` are printed

-The visualized prediction images and the prediction text file, named `infer_results.txt`, are saved in the `output_res` directory.
+* Cascade prediction with `OCR engine + SER`

-* Cascade results with `OCR engine + SER`
+Run the following command to complete the cascade prediction with `OCR engine + SER`

 ```shell
-export CUDA_VISIBLE_DEVICES=0
-python3.7 infer_ser_e2e.py \
-    --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
-    --ser_model_type "LayoutXLM" \
-    --max_seq_length 512 \
-    --output_dir "output/ser_e2e/" \
-    --infer_imgs "images/input/zh_val_0.jpg"
+CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ Global.infer_img=ppstructure/vqa/images/input/zh_val_42.jpg
 ```

+The visualized prediction images and the prediction text file (named `infer_results.txt`) are saved in the directory configured by the `config.Global.save_res_path` field.

 * End-to-end evaluation of the `OCR engine + SER` prediction system

 First run prediction over the dataset with the `tools/infer_vqa_token_ser.py` script, then evaluate with the following command.

 ```shell
 export CUDA_VISIBLE_DEVICES=0
 python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt

@@ -223,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor

 * Start training

-```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 train_re.py \
-    --model_name_or_path "layoutxlm-base-uncased" \
-    --train_data_dir "XFUND/zh_train/image" \
-    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --label_map_path "labels/labels_ser.txt" \
-    --num_train_epochs 200 \
-    --eval_steps 10 \
-    --output_dir "output/re/" \
-    --learning_rate 5e-5 \
-    --warmup_steps 50 \
-    --per_gpu_train_batch_size 8 \
-    --per_gpu_eval_batch_size 8 \
-    --num_workers 8 \
-    --evaluate_during_training \
-    --seed 2048
+Before launching training, modify the following four fields (the override sketch after this list shows a command-line alternative)
+
+1. `Train.dataset.data_dir`: directory containing the training images
+2. `Train.dataset.label_file_list`: path to the training label file
+3. `Eval.dataset.data_dir`: directory containing the validation images
+4. `Eval.dataset.label_file_list`: path to the validation label file
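As an alternative to editing `configs/vqa/re/layoutxlm.yml`, these fields can plausibly also be overridden with `-o`, in the same way `Architecture.Backbone.checkpoints` is overridden elsewhere in this document; a hypothetical sketch:

```bash
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml \
    -o Train.dataset.data_dir=XFUND/zh_train/image \
       Eval.dataset.data_dir=XFUND/zh_val/image
```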

 ```shell
 CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
 ```

 Finally, metrics such as `precision`, `recall`, and `hmean` are printed.
 The training log, the best model, and the model of the latest epoch are saved in the `./output/re_layoutxlm/` folder.

 * Resume training

 To resume training, assign the folder of the previously trained model to the `Architecture.Backbone.checkpoints` field.

 ```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 train_re.py \
-    --model_name_or_path "model_path" \
-    --train_data_dir "XFUND/zh_train/image" \
-    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --label_map_path "labels/labels_ser.txt" \
-    --num_train_epochs 2 \
-    --eval_steps 10 \
-    --output_dir "output/re/" \
-    --learning_rate 5e-5 \
-    --warmup_steps 50 \
-    --per_gpu_train_batch_size 8 \
-    --per_gpu_eval_batch_size 8 \
-    --num_workers 8 \
-    --evaluate_during_training \
-    --seed 2048 \
-    --resume
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
 ```

-Finally, metrics such as `precision`, `recall`, and `f1` are printed, and the model and training log are saved in the `./output/re/` folder.

 * Evaluation

-```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 eval_re.py \
-    --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
-    --max_seq_length 512 \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --label_map_path "labels/labels_ser.txt" \
-    --output_dir "output/re/" \
-    --per_gpu_eval_batch_size 8 \
-    --num_workers 8 \
-    --seed 2048
-```
-Finally, metrics such as `precision`, `recall`, and `f1` are printed

-* Predict using the OCR results provided with the evaluation set
+For evaluation, assign the folder of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.

 ```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 infer_re.py \
-    --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
-    --max_seq_length 512 \
-    --eval_data_dir "XFUND/zh_val/image" \
-    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
-    --label_map_path "labels/labels_ser.txt" \
-    --output_dir "output/re/" \
-    --per_gpu_eval_batch_size 1 \
-    --seed 2048
+CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
 ```
+Finally, metrics such as `precision`, `recall`, and `hmean` are printed

-The visualized prediction images and the prediction text file, named `infer_results.txt`, are saved in the `output_res` directory.

-* Cascade results with `OCR engine + SER + RE`
+* Cascade prediction with `OCR engine + SER + RE`

 Run the following command to complete the cascade prediction with `OCR engine + SER + RE`
 ```shell
-export CUDA_VISIBLE_DEVICES=0
-python3.7 infer_ser_re_e2e.py \
-    --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
-    --re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
-    --ser_model_type "LayoutXLM" \
-    --max_seq_length 512 \
-    --output_dir "output/ser_re_e2e/" \
-    --infer_imgs "images/input/zh_val_21.jpg"
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_re_pretrained/ Global.infer_img=ppstructure/vqa/images/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/
 ```

+The visualized prediction images and the prediction text file (named `infer_results.txt`) are saved in the directory configured by the `config.Global.save_res_path` field.


 ## References

 - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf

@@ -1,125 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import paddle

from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction

from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score

from ppocr.utils.logging import get_logger


def cal_metric(re_preds, re_labels, entities):
    gt_relations = []
    for b in range(len(re_labels)):
        rel_sent = []
        for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
            rel = {}
            rel["head_id"] = head
            rel["head"] = (entities[b]["start"][rel["head_id"]],
                           entities[b]["end"][rel["head_id"]])
            rel["head_type"] = entities[b]["label"][rel["head_id"]]

            rel["tail_id"] = tail
            rel["tail"] = (entities[b]["start"][rel["tail_id"]],
                           entities[b]["end"][rel["tail_id"]])
            rel["tail_type"] = entities[b]["label"][rel["tail_id"]]

            rel["type"] = 1
            rel_sent.append(rel)
        gt_relations.append(rel_sent)
    re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
    return re_metrics


def evaluate(model, eval_dataloader, logger, prefix=""):
    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))

    re_preds = []
    re_labels = []
    entities = []
    eval_loss = 0.0
    model.eval()
    for idx, batch in enumerate(eval_dataloader):
        with paddle.no_grad():
            outputs = model(**batch)
            loss = outputs['loss'].mean().item()
            if paddle.distributed.get_rank() == 0:
                logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
                    idx, len(eval_dataloader), loss))

        eval_loss += loss
        re_preds.extend(outputs['pred_relations'])
        re_labels.extend(batch['relations'])
        entities.extend(batch['entities'])
    re_metrics = cal_metric(re_preds, re_labels, entities)
    re_metrics = {
        "precision": re_metrics["ALL"]["p"],
        "recall": re_metrics["ALL"]["r"],
        "f1": re_metrics["ALL"]["f1"],
    }
    model.train()
    return re_metrics


def eval(args):
    logger = get_logger()
    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)

    model = LayoutXLMForRelationExtraction.from_pretrained(
        args.model_name_or_path)

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=DataCollator())

    results = evaluate(model, eval_dataloader, logger)
    logger.info("eval results: {}".format(results))


if __name__ == "__main__":
    args = parse_args()
    eval(args)

@@ -1,177 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import time
import copy
import logging

import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

from xfun import XFUNDataset
from losses import SERLoss
from vqa_utils import parse_args, get_bio_label_maps, print_arguments

from ppocr.utils.logging import get_logger

MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}


def eval(args):
    logger = get_logger()
    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    loss_class = SERLoss(len(label2id_map))

    results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
                          label2id_map, id2label_map, pad_token_label_id,
                          logger)

    logger.info(results)


def evaluate(args,
             model,
             tokenizer,
             loss_class,
             eval_dataloader,
             label2id_map,
             id2label_map,
             pad_token_label_id,
             logger,
             prefix=""):

    eval_loss = 0.0
    nb_eval_steps = 0
    preds = None
    out_label_ids = None
    model.eval()
    for idx, batch in enumerate(eval_dataloader):
        with paddle.no_grad():
            if args.ser_model_type == 'LayoutLM':
                if 'image' in batch:
                    batch.pop('image')
            labels = batch.pop('labels')
            outputs = model(**batch)
            if args.ser_model_type == 'LayoutXLM':
                outputs = outputs[0]
            loss = loss_class(labels, outputs, batch['attention_mask'])

            loss = loss.mean()

            if paddle.distributed.get_rank() == 0:
                logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
                    idx, len(eval_dataloader), loss.numpy()[0]))

            eval_loss += loss.item()
        nb_eval_steps += 1
        if preds is None:
            preds = outputs.numpy()
            out_label_ids = labels.numpy()
        else:
            preds = np.append(preds, outputs.numpy(), axis=0)
            out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)

    eval_loss = eval_loss / nb_eval_steps
    preds = np.argmax(preds, axis=2)

    # label_map = {i: label.upper() for i, label in enumerate(labels)}

    out_label_list = [[] for _ in range(out_label_ids.shape[0])]
    preds_list = [[] for _ in range(out_label_ids.shape[0])]

    for i in range(out_label_ids.shape[0]):
        for j in range(out_label_ids.shape[1]):
            if out_label_ids[i, j] != pad_token_label_id:
                out_label_list[i].append(id2label_map[out_label_ids[i][j]])
                preds_list[i].append(id2label_map[preds[i][j]])

    results = {
        "loss": eval_loss,
        "precision": precision_score(out_label_list, preds_list),
        "recall": recall_score(out_label_list, preds_list),
        "f1": f1_score(out_label_list, preds_list),
    }

    with open(
            os.path.join(args.output_dir, "test_gt.txt"), "w",
            encoding='utf-8') as fout:
        for lbl in out_label_list:
            for l in lbl:
                fout.write(l + "\t")
            fout.write("\n")
    with open(
            os.path.join(args.output_dir, "test_pred.txt"), "w",
            encoding='utf-8') as fout:
        for lbl in preds_list:
            for l in lbl:
                fout.write(l + "\t")
            fout.write("\n")

    report = classification_report(out_label_list, preds_list)
    logger.info("\n" + report)

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))
    model.train()
    return results, preds_list


if __name__ == "__main__":
    args = parse_args()
    eval(args)

@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None):
     print("===ok====")


-transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
+def parser_args():
+    import argparse
+    parser = argparse.ArgumentParser(description="args for the XFUN data converter")
+    parser.add_argument(
+        "--ori_gt_path", type=str, required=True, help='origin xfun gt path')
+    parser.add_argument(
+        "--output_path", type=str, required=True, help='path to save')
+    args = parser.parse_args()
+    return args
+
+
+args = parser_args()
+transfer_xfun_data(args.ori_gt_path, args.output_path)

@@ -1,61 +0,0 @@
export CUDA_VISIBLE_DEVICES=6
# python3.7 infer_ser_e2e.py \
#     --model_name_or_path "output/ser_distributed/best_model" \
#     --max_seq_length 512 \
#     --output_dir "output_res_e2e/" \
#     --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"


# python3.7 infer_ser_re_e2e.py \
#     --model_name_or_path "output/ser_distributed/best_model" \
#     --re_model_name_or_path "output/re_test/best_model" \
#     --max_seq_length 512 \
#     --output_dir "output_ser_re_e2e_train/" \
#     --infer_imgs "images/input/zh_val_21.jpg"

# python3.7 infer_ser.py \
#     --model_name_or_path "output/ser_LayoutLM/best_model" \
#     --ser_model_type "LayoutLM" \
#     --output_dir "ser_LayoutLM/" \
#     --infer_imgs "images/input/zh_val_21.jpg" \
#     --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"

python3.7 infer_ser.py \
    --model_name_or_path "output/ser_new/best_model" \
    --ser_model_type "LayoutXLM" \
    --output_dir "ser_new/" \
    --infer_imgs "images/input/zh_val_21.jpg" \
    --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"

# python3.7 infer_ser_e2e.py \
#     --model_name_or_path "output/ser_new/best_model" \
#     --ser_model_type "LayoutXLM" \
#     --max_seq_length 512 \
#     --output_dir "output/ser_new/" \
#     --infer_imgs "images/input/zh_val_0.jpg"


# python3.7 infer_ser_e2e.py \
#     --model_name_or_path "output/ser_LayoutLM/best_model" \
#     --ser_model_type "LayoutLM" \
#     --max_seq_length 512 \
#     --output_dir "output/ser_LayoutLM/" \
#     --infer_imgs "images/input/zh_val_0.jpg"

# python3 infer_re.py \
#     --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
#     --max_seq_length 512 \
#     --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
#     --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
#     --label_map_path 'labels/labels_ser.txt' \
#     --output_dir "output_res" \
#     --per_gpu_eval_batch_size 1 \
#     --seed 2048

# python3.7 infer_ser_re_e2e.py \
#     --model_name_or_path "output/ser_LayoutLM/best_model" \
#     --ser_model_type "LayoutLM" \
#     --re_model_name_or_path "output/re_new/best_model" \
#     --max_seq_length 512 \
#     --output_dir "output_ser_re_e2e/" \
#     --infer_imgs "images/input/zh_val_21.jpg"

@@ -1,165 +0,0 @@
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle

from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction

from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator

from ppocr.utils.logging import get_logger


def infer(args):
    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger()
    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)

    model = LayoutXLMForRelationExtraction.from_pretrained(
        args.model_name_or_path)

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=8,
        shuffle=False,
        collate_fn=DataCollator())

    # load the ground-truth OCR data
    ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)

    for idx, batch in enumerate(eval_dataloader):
        ocr_info = ocr_info_list[idx]
        image_path = ocr_info['image_path']
        ocr_info = ocr_info['ocr_info']

        save_img_path = os.path.join(
            args.output_dir,
            os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
        logger.info("[Infer] process: {}/{}, save result to {}".format(
            idx, len(eval_dataloader), save_img_path))
        with paddle.no_grad():
            outputs = model(**batch)
        pred_relations = outputs['pred_relations']

        # decode the tokens of each entity and drop the ocr_info entries that do not match
        ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)

        # map the predicted relations back to the OCR info
        result = []
        used_tail_id = []
        for relations in pred_relations:
            for relation in relations:
                if relation['tail_id'] in used_tail_id:
                    continue
                if relation['head_id'] not in ocr_info or relation[
                        'tail_id'] not in ocr_info:
                    continue
                used_tail_id.append(relation['tail_id'])
                ocr_info_head = ocr_info[relation['head_id']]
                ocr_info_tail = ocr_info[relation['tail_id']]
                result.append((ocr_info_head, ocr_info_tail))

        img = cv2.imread(image_path)
        img_show = draw_re_results(img, result)
        cv2.imwrite(save_img_path, img_show)


def load_ocr(img_folder, json_path):
    import json
    d = []
    with open(json_path, "r", encoding='utf-8') as fin:
        lines = fin.readlines()
        for line in lines:
            image_name, info_str = line.split("\t")
            info_dict = json.loads(info_str)
            info_dict['image_path'] = os.path.join(img_folder, image_name)
            d.append(info_dict)
    return d


def filter_bg_by_txt(ocr_info, batch, tokenizer):
    entities = batch['entities'][0]
    input_ids = batch['input_ids'][0]

    new_info_dict = {}
    for i in range(len(entities['start'])):
        entitie_head = entities['start'][i]
        entitie_tail = entities['end'][i]
        word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
        txt = tokenizer.convert_ids_to_tokens(word_input_ids)
        txt = tokenizer.convert_tokens_to_string(txt)

        for i, info in enumerate(ocr_info):
            if info['text'] == txt:
                new_info_dict[i] = info
    return new_info_dict


def post_process(pred_relations, ocr_info, img):
    result = []
    for relations in pred_relations:
        for relation in relations:
            ocr_info_head = ocr_info[relation['head_id']]
            ocr_info_tail = ocr_info[relation['tail_id']]
            result.append((ocr_info_head, ocr_info_tail))
    return result


def draw_re(result, image_path, output_folder):
    img = cv2.imread(image_path)

    from matplotlib import pyplot as plt
    for ocr_info_head, ocr_info_tail in result:
        cv2.rectangle(
            img,
            tuple(ocr_info_head['bbox'][:2]),
            tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
            thickness=2)
        cv2.rectangle(
            img,
            tuple(ocr_info_tail['bbox'][:2]),
            tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
            thickness=2)
        center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
                     (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
        center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
                     (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
        cv2.line(
            img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
    plt.imshow(img)
    plt.savefig(
        os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
    # plt.show()


if __name__ == "__main__":
    args = parse_args()
    infer(args)

@@ -1,302 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)

import json
import cv2
import numpy as np
from copy import deepcopy

import paddle

# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}


def pad_sentences(tokenizer,
                  encoded_inputs,
                  max_seq_len=512,
                  pad_to_max_seq_len=True,
                  return_attention_mask=True,
                  return_token_type_ids=True,
                  return_overflowing_tokens=False,
                  return_special_tokens_mask=False):
    # Padding with larger size, reshape is carried out
    max_seq_len = (
        len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len

    needs_to_be_padded = pad_to_max_seq_len and \
        max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len

    if needs_to_be_padded:
        difference = max_seq_len - len(encoded_inputs["input_ids"])
        if tokenizer.padding_side == 'right':
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                    "input_ids"]) + [0] * difference
            if return_token_type_ids:
                encoded_inputs["token_type_ids"] = (
                    encoded_inputs["token_type_ids"] +
                    [tokenizer.pad_token_type_id] * difference)
            if return_special_tokens_mask:
                encoded_inputs["special_tokens_mask"] = encoded_inputs[
                    "special_tokens_mask"] + [1] * difference
            encoded_inputs["input_ids"] = encoded_inputs[
                "input_ids"] + [tokenizer.pad_token_id] * difference
            encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
                                                               ] * difference
        else:
            assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
                tokenizer.padding_side)
    else:
        if return_attention_mask:
            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                "input_ids"])

    return encoded_inputs


def split_page(encoded_inputs, max_seq_len=512):
    """
    truncate is often used in training process
    """
    for key in encoded_inputs:
        encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
        if encoded_inputs[key].ndim <= 1:  # for input_ids, att_mask and so on
            encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
        else:  # for bbox
            encoded_inputs[key] = encoded_inputs[key].reshape(
                [-1, max_seq_len, 4])
    return encoded_inputs


def preprocess(
        tokenizer,
        ori_img,
        ocr_info,
        img_size=(224, 224),
        pad_token_label_id=-100,
        max_seq_len=512,
        add_special_ids=False,
        return_attention_mask=True, ):
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    img = cv2.resize(ori_img,
                     (224, 224)).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    words_list = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []

    for info in ocr_info:
        # x1, y1, x2, y2
        bbox = info["bbox"]
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
        words_list.append(text)
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    encoded_inputs = split_page(encoded_inputs)

    fake_bs = encoded_inputs["input_ids"].shape[0]

    encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id

    return encoded_inputs


def postprocess(attention_mask, preds, label_map_path):
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    preds = np.argmax(preds, axis=2)

    _, label_map = get_bio_label_maps(label_map_path)

    preds_list = [[] for _ in range(preds.shape[0])]

    # keep batch info
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            if attention_mask[i][j] == 1:
                preds_list[i].append(label_map[preds[i][j]])

    return preds_list


def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
                                   preds_list):
    # must ensure the preds_list is generated from the same image
    preds = [p for pred in preds_list for p in pred]
    label2id_map, _ = get_bio_label_maps(label_map_path)
    for key in label2id_map:
        if key.startswith("I-"):
            label2id_map[key] = label2id_map["B" + key[1:]]

    id2label_map = dict()
    for key in label2id_map:
        val = label2id_map[key]
        if key == "O":
            id2label_map[val] = key
        if key.startswith("B-") or key.startswith("I-"):
            id2label_map[val] = key[2:]
        else:
            id2label_map[val] = key

    for idx in range(len(segment_offset_id)):
        if idx == 0:
            start_id = 0
        else:
            start_id = segment_offset_id[idx - 1]

        end_id = segment_offset_id[idx]

        curr_pred = preds[start_id:end_id]
        curr_pred = [label2id_map[p] for p in curr_pred]

        if len(curr_pred) <= 0:
            pred_id = 0
        else:
            counts = np.bincount(curr_pred)
            pred_id = np.argmax(counts)
        ocr_info[idx]["pred_id"] = int(pred_id)
        ocr_info[idx]["pred"] = id2label_map[pred_id]
    return ocr_info


@paddle.no_grad()
def infer(args):
    os.makedirs(args.output_dir, exist_ok=True)

    # init token and model
    tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)

    model.eval()

    # load ocr results json
    ocr_results = dict()
    with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
        lines = fin.readlines()
        for line in lines:
            img_name, json_info = line.split("\t")
            ocr_results[os.path.basename(img_name)] = json.loads(json_info)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    with open(
            os.path.join(args.output_dir, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(args.output_dir,
                                         os.path.basename(img_path))
            print("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)

            ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
            inputs = preprocess(
                tokenizer=tokenizer,
                ori_img=img,
                ocr_info=ocr_info,
                max_seq_len=args.max_seq_length)
            if args.ser_model_type == 'LayoutLM':
                preds = model(
                    input_ids=inputs["input_ids"],
                    bbox=inputs["bbox"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"])
            elif args.ser_model_type == 'LayoutXLM':
                preds = model(
                    input_ids=inputs["input_ids"],
                    bbox=inputs["bbox"],
                    image=inputs["image"],
                    token_type_ids=inputs["token_type_ids"],
                    attention_mask=inputs["attention_mask"])
                preds = preds[0]

            preds = postprocess(inputs["attention_mask"], preds,
                                args.label_map_path)
            ocr_info = merge_preds_list_with_ocr_info(
                args.label_map_path, ocr_info, inputs["segment_offset_id"],
                preds)

            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": ocr_info,
                }, ensure_ascii=False) + "\n")

            img_res = draw_ser_results(img, ocr_info)
            cv2.imwrite(save_img_path, img_res)

    return


if __name__ == "__main__":
    args = parse_args()
    infer(args)

@ -1,156 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from PIL import Image
|
||||
|
||||
import paddle
|
||||
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
|
||||
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
|
||||
|
||||
# relative reference
|
||||
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
|
||||
|
||||
from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
|
||||
|
||||
MODELS = {
|
||||
'LayoutXLM':
|
||||
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
|
||||
'LayoutLM':
|
||||
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
|
||||
}
|
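||||
# Sketch: args.ser_model_type picks the (tokenizer, base model, task head)
|
||||
# triple, e.g. MODELS['LayoutXLM'] ->
|
||||
# (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification).
|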
||||
|
||||
|
||||
def trans_poly_to_bbox(poly):
|
||||
x1 = np.min([p[0] for p in poly])
|
||||
x2 = np.max([p[0] for p in poly])
|
||||
y1 = np.min([p[1] for p in poly])
|
||||
y2 = np.max([p[1] for p in poly])
|
||||
return [x1, y1, x2, y2]
|
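||||
# Usage sketch (hypothetical corner points): the axis-aligned box is the
|
||||
# min/max over the polygon, e.g.
|
||||
# trans_poly_to_bbox([[10, 20], [50, 20], [50, 60], [10, 60]])  # [10, 20, 50, 60]
|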
||||
|
||||
|
||||
def parse_ocr_info_for_ser(ocr_result):
|
||||
ocr_info = []
|
||||
for res in ocr_result:
|
||||
ocr_info.append({
|
||||
"text": res[1][0],
|
||||
"bbox": trans_poly_to_bbox(res[0]),
|
||||
"poly": res[0],
|
||||
})
|
||||
return ocr_info
|
||||
|
||||
|
||||
class SerPredictor(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.max_seq_length = args.max_seq_length
|
||||
|
||||
# init ser token and model
|
||||
tokenizer_class, base_model_class, model_class = MODELS[
|
||||
args.ser_model_type]
|
||||
self.tokenizer = tokenizer_class.from_pretrained(
|
||||
args.model_name_or_path)
|
||||
self.model = model_class.from_pretrained(args.model_name_or_path)
|
||||
self.model.eval()
|
||||
|
||||
# init ocr_engine
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
self.ocr_engine = PaddleOCR(
|
||||
rec_model_dir=args.rec_model_dir,
|
||||
det_model_dir=args.det_model_dir,
|
||||
use_angle_cls=False,
|
||||
show_log=False)
|
||||
# init dict
|
||||
label2id_map, self.id2label_map = get_bio_label_maps(
|
||||
args.label_map_path)
|
||||
self.label2id_map_for_draw = dict()
|
||||
for key in label2id_map:
|
||||
if key.startswith("I-"):
|
||||
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
|
||||
else:
|
||||
self.label2id_map_for_draw[key] = label2id_map[key]
|
||||
|
||||
def __call__(self, img):
|
||||
ocr_result = self.ocr_engine.ocr(img, cls=False)
|
||||
|
||||
ocr_info = parse_ocr_info_for_ser(ocr_result)
|
||||
|
||||
inputs = preprocess(
|
||||
tokenizer=self.tokenizer,
|
||||
ori_img=img,
|
||||
ocr_info=ocr_info,
|
||||
max_seq_len=self.max_seq_length)
|
||||
|
||||
if self.args.ser_model_type == 'LayoutLM':
|
||||
preds = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
bbox=inputs["bbox"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"])
|
||||
elif self.args.ser_model_type == 'LayoutXLM':
|
||||
preds = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
bbox=inputs["bbox"],
|
||||
image=inputs["image"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
attention_mask=inputs["attention_mask"])
|
||||
preds = preds[0]
|
||||
|
||||
preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
|
||||
ocr_info = merge_preds_list_with_ocr_info(
|
||||
ocr_info, inputs["segment_offset_id"], preds,
|
||||
self.label2id_map_for_draw)
|
||||
return ocr_info, inputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# get infer img list
|
||||
infer_imgs = get_image_file_list(args.infer_imgs)
|
||||
|
||||
# loop for infer
|
||||
ser_engine = SerPredictor(args)
|
||||
with open(
|
||||
os.path.join(args.output_dir, "infer_results.txt"),
|
||||
"w",
|
||||
encoding='utf-8') as fout:
|
||||
for idx, img_path in enumerate(infer_imgs):
|
||||
save_img_path = os.path.join(
|
||||
args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
|
||||
print("process: [{}/{}], save result to {}".format(
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
result, _ = ser_engine(img)
|
||||
fout.write(img_path + "\t" + json.dumps(
|
||||
{
|
||||
"ser_resule": result,
|
||||
}, ensure_ascii=False) + "\n")
|
||||
|
||||
img_res = draw_ser_results(img, result)
|
||||
cv2.imwrite(save_img_path, img_res)
|
|
@ -1,135 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from PIL import Image
|
||||
|
||||
import paddle
|
||||
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
|
||||
|
||||
# relative reference
|
||||
from vqa_utils import parse_args, get_image_file_list, draw_re_results
|
||||
from infer_ser_e2e import SerPredictor
|
||||
|
||||
|
||||
def make_input(ser_input, ser_result, max_seq_len=512):
|
||||
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
|
||||
|
||||
entities = ser_input['entities'][0]
|
||||
assert len(entities) == len(ser_result)
|
||||
|
||||
# entities
|
||||
start = []
|
||||
end = []
|
||||
label = []
|
||||
entity_idx_dict = {}
|
||||
for i, (res, entity) in enumerate(zip(ser_result, entities)):
|
||||
if res['pred'] == 'O':
|
||||
continue
|
||||
entity_idx_dict[len(start)] = i
|
||||
start.append(entity['start'])
|
||||
end.append(entity['end'])
|
||||
label.append(entities_labels[res['pred']])
|
||||
entities = dict(start=start, end=end, label=label)
|
||||
|
||||
# relations
|
||||
head = []
|
||||
tail = []
|
||||
for i in range(len(entities["label"])):
|
||||
for j in range(len(entities["label"])):
|
||||
if entities["label"][i] == 1 and entities["label"][j] == 2:
|
||||
head.append(i)
|
||||
tail.append(j)
|
||||
|
||||
relations = dict(head=head, tail=tail)
|
||||
|
||||
batch_size = ser_input["input_ids"].shape[0]
|
||||
entities_batch = []
|
||||
relations_batch = []
|
||||
for b in range(batch_size):
|
||||
entities_batch.append(entities)
|
||||
relations_batch.append(relations)
|
||||
|
||||
ser_input['entities'] = entities_batch
|
||||
ser_input['relations'] = relations_batch
|
||||
|
||||
ser_input.pop('segment_offset_id')
|
||||
return ser_input, entity_idx_dict
|
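||||
# Note on the loop above: candidate relations are the full cross product of
|
||||
# QUESTION (label 1) and ANSWER (label 2) entities; e.g. with label == [1, 2, 2]
|
||||
# the candidates are head = [0, 0], tail = [1, 2], and the RE model scores each pair.
|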
||||
|
||||
|
||||
class SerReSystem(object):
|
||||
def __init__(self, args):
|
||||
self.ser_engine = SerPredictor(args)
|
||||
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
|
||||
args.re_model_name_or_path)
|
||||
self.model = LayoutXLMForRelationExtraction.from_pretrained(
|
||||
args.re_model_name_or_path)
|
||||
self.model.eval()
|
||||
|
||||
def __call__(self, img):
|
||||
ser_result, ser_inputs = self.ser_engine(img)
|
||||
re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
|
||||
|
||||
re_result = self.model(**re_input)
|
||||
|
||||
pred_relations = re_result['pred_relations'][0]
|
||||
# convert the predicted relations back to pairs of head/tail OCR info
|
||||
result = []
|
||||
used_tail_id = []
|
||||
for relation in pred_relations:
|
||||
if relation['tail_id'] in used_tail_id:
|
||||
continue
|
||||
used_tail_id.append(relation['tail_id'])
|
||||
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
|
||||
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
|
||||
result.append((ocr_info_head, ocr_info_tail))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# get infer img list
|
||||
infer_imgs = get_image_file_list(args.infer_imgs)
|
||||
|
||||
# loop for infer
|
||||
ser_re_engine = SerReSystem(args)
|
||||
with open(
|
||||
os.path.join(args.output_dir, "infer_results.txt"),
|
||||
"w",
|
||||
encoding='utf-8') as fout:
|
||||
for idx, img_path in enumerate(infer_imgs):
|
||||
save_img_path = os.path.join(
|
||||
args.output_dir,
|
||||
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
|
||||
print("process: [{}/{}], save result to {}".format(
|
||||
idx, len(infer_imgs), save_img_path))
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
result = ser_re_engine(img)
|
||||
fout.write(img_path + "\t" + json.dumps(
|
||||
{
|
||||
"result": result,
|
||||
}, ensure_ascii=False) + "\n")
|
||||
|
||||
img_res = draw_re_results(img, result)
|
||||
cv2.imwrite(save_img_path, img_res)
|
|
@ -1,175 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
||||
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
|
||||
|
||||
|
||||
def get_last_checkpoint(folder):
|
||||
content = os.listdir(folder)
|
||||
checkpoints = [
|
||||
path for path in content
|
||||
if _re_checkpoint.search(path) is not None and os.path.isdir(
|
||||
os.path.join(folder, path))
|
||||
]
|
||||
if len(checkpoints) == 0:
|
||||
return
|
||||
return os.path.join(
|
||||
folder,
|
||||
max(checkpoints,
|
||||
key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
|
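||||
# Example (hypothetical folder): given subdirs checkpoint-900 and
|
||||
# checkpoint-1000, this returns <folder>/checkpoint-1000 because checkpoints
|
||||
# are compared numerically (plain string ordering would pick checkpoint-900);
|
||||
# it returns None when no checkpoint-* subdir exists.
|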
||||
|
||||
|
||||
def re_score(pred_relations, gt_relations, mode="strict"):
|
||||
"""Evaluate RE predictions
|
||||
|
||||
Args:
|
||||
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
|
||||
gt_relations (list) : list of list of ground truth relations
|
||||
|
||||
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
|
||||
"tail": (start_idx (inclusive), end_idx (exclusive)),
|
||||
"head_type": ent_type,
|
||||
"tail_type": ent_type,
|
||||
"type": rel_type}
|
||||
|
||||
mode (str) : one of 'strict' or 'boundaries'"""
|
||||
|
||||
assert mode in ["strict", "boundaries"]
|
||||
|
||||
relation_types = [v for v in [0, 1] if v != 0]
|
||||
scores = {
|
||||
rel: {
|
||||
"tp": 0,
|
||||
"fp": 0,
|
||||
"fn": 0
|
||||
}
|
||||
for rel in relation_types + ["ALL"]
|
||||
}
|
||||
|
||||
# Count GT relations and Predicted relations
|
||||
n_sents = len(gt_relations)
|
||||
n_rels = sum(len(sent) for sent in gt_relations)
|
||||
n_found = sum(len(sent) for sent in pred_relations)
|
||||
|
||||
# Count TP, FP and FN per type
|
||||
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
|
||||
for rel_type in relation_types:
|
||||
# strict mode takes argument types into account
|
||||
if mode == "strict":
|
||||
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
|
||||
rel["tail_type"])
|
||||
for rel in pred_sent if rel["type"] == rel_type}
|
||||
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
|
||||
rel["tail_type"])
|
||||
for rel in gt_sent if rel["type"] == rel_type}
|
||||
|
||||
# boundaries mode only takes argument spans into account
|
||||
elif mode == "boundaries":
|
||||
pred_rels = {(rel["head"], rel["tail"])
|
||||
for rel in pred_sent if rel["type"] == rel_type}
|
||||
gt_rels = {(rel["head"], rel["tail"])
|
||||
for rel in gt_sent if rel["type"] == rel_type}
|
||||
|
||||
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
|
||||
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
|
||||
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
|
||||
|
||||
# Compute per entity Precision / Recall / F1
|
||||
for rel_type in scores.keys():
|
||||
if scores[rel_type]["tp"]:
|
||||
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
|
||||
scores[rel_type]["fp"] + scores[rel_type]["tp"])
|
||||
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
|
||||
scores[rel_type]["fn"] + scores[rel_type]["tp"])
|
||||
else:
|
||||
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
|
||||
|
||||
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
|
||||
scores[rel_type]["f1"] = (
|
||||
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
|
||||
(scores[rel_type]["p"] + scores[rel_type]["r"]))
|
||||
else:
|
||||
scores[rel_type]["f1"] = 0
|
||||
|
||||
# Compute micro F1 Scores
|
||||
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
|
||||
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
|
||||
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
|
||||
|
||||
if tp:
|
||||
precision = tp / (tp + fp)
|
||||
recall = tp / (tp + fn)
|
||||
f1 = 2 * precision * recall / (precision + recall)
|
||||
|
||||
else:
|
||||
precision, recall, f1 = 0, 0, 0
|
||||
|
||||
scores["ALL"]["p"] = precision
|
||||
scores["ALL"]["r"] = recall
|
||||
scores["ALL"]["f1"] = f1
|
||||
scores["ALL"]["tp"] = tp
|
||||
scores["ALL"]["fp"] = fp
|
||||
scores["ALL"]["fn"] = fn
|
||||
|
||||
# Compute Macro F1 Scores
|
||||
scores["ALL"]["Macro_f1"] = np.mean(
|
||||
[scores[ent_type]["f1"] for ent_type in relation_types])
|
||||
scores["ALL"]["Macro_p"] = np.mean(
|
||||
[scores[ent_type]["p"] for ent_type in relation_types])
|
||||
scores["ALL"]["Macro_r"] = np.mean(
|
||||
[scores[ent_type]["r"] for ent_type in relation_types])
|
||||
|
||||
# logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
|
||||
|
||||
# logger.info(
|
||||
# "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
|
||||
# n_sents, n_rels, n_found, tp
|
||||
# )
|
||||
# )
|
||||
# logger.info(
|
||||
# "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
|
||||
# )
|
||||
# logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
|
||||
# logger.info(
|
||||
# "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
|
||||
# scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
|
||||
# )
|
||||
# )
|
||||
|
||||
# for rel_type in relation_types:
|
||||
# logger.info(
|
||||
# "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
|
||||
# rel_type,
|
||||
# scores[rel_type]["tp"],
|
||||
# scores[rel_type]["fp"],
|
||||
# scores[rel_type]["fn"],
|
||||
# scores[rel_type]["p"],
|
||||
# scores[rel_type]["r"],
|
||||
# scores[rel_type]["f1"],
|
||||
# scores[rel_type]["tp"] + scores[rel_type]["fp"],
|
||||
# )
|
||||
# )
|
||||
|
||||
return scores
|
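||||
# Usage sketch (hypothetical toy data): one sentence whose single relation is
|
||||
# predicted exactly right, so precision = recall = f1 = 1.0.
|
||||
# rel = {"head": (0, 2), "tail": (3, 5),
|
||||
#        "head_type": 1, "tail_type": 2, "type": 1}
|
||||
# scores = re_score([[rel]], [[rel]], mode="strict")
|
||||
# assert scores["ALL"]["f1"] == 1.0
|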
|
@ -1,3 +1,4 @@
|
|||
sentencepiece
|
||||
yacs
|
||||
seqeval
|
||||
paddlenlp>=2.2.1
|
|
@ -1,229 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
|
||||
|
||||
from xfun import XFUNDataset
|
||||
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
|
||||
from data_collator import DataCollator
|
||||
from eval_re import evaluate
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def train(args):
|
||||
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
|
||||
rank = paddle.distributed.get_rank()
|
||||
distributed = paddle.distributed.get_world_size() > 1
|
||||
|
||||
print_arguments(args, logger)
|
||||
|
||||
# Added here for reproducibility (even between python 2 and 3)
|
||||
set_seed(args.seed)
|
||||
|
||||
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
|
||||
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
|
||||
|
||||
# dist mode
|
||||
if distributed:
|
||||
paddle.distributed.init_parallel_env()
|
||||
|
||||
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
|
||||
if not args.resume:
|
||||
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
|
||||
model = LayoutXLMForRelationExtraction(model, dropout=None)
|
||||
logger.info('train from scratch')
|
||||
else:
|
||||
logger.info('resume from {}'.format(args.model_name_or_path))
|
||||
model = LayoutXLMForRelationExtraction.from_pretrained(
|
||||
args.model_name_or_path)
|
||||
|
||||
# dist mode
|
||||
if distributed:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
train_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir=args.train_data_dir,
|
||||
label_path=args.train_label_path,
|
||||
label2id_map=label2id_map,
|
||||
img_size=(224, 224),
|
||||
max_seq_len=args.max_seq_length,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
contains_re=True,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all')
|
||||
|
||||
eval_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir=args.eval_data_dir,
|
||||
label_path=args.eval_label_path,
|
||||
label2id_map=label2id_map,
|
||||
img_size=(224, 224),
|
||||
max_seq_len=args.max_seq_length,
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
contains_re=True,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all')
|
||||
|
||||
train_sampler = paddle.io.DistributedBatchSampler(
|
||||
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
|
||||
|
||||
train_dataloader = paddle.io.DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=args.num_workers,
|
||||
use_shared_memory=True,
|
||||
collate_fn=DataCollator())
|
||||
|
||||
eval_dataloader = paddle.io.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_gpu_eval_batch_size,
|
||||
num_workers=args.num_workers,
|
||||
shuffle=False,
|
||||
collate_fn=DataCollator())
|
||||
|
||||
t_total = len(train_dataloader) * args.num_train_epochs
|
||||
|
||||
# build linear decay with warmup lr sch
|
||||
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
|
||||
learning_rate=args.learning_rate,
|
||||
decay_steps=t_total,
|
||||
end_lr=0.0,
|
||||
power=1.0)
|
||||
if args.warmup_steps > 0:
|
||||
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
|
||||
lr_scheduler,
|
||||
args.warmup_steps,
|
||||
start_lr=0,
|
||||
end_lr=args.learning_rate, )
|
||||
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
learning_rate=args.learning_rate,
|
||||
parameters=model.parameters(),
|
||||
epsilon=args.adam_epsilon,
|
||||
grad_clip=grad_clip,
|
||||
weight_decay=args.weight_decay)
|
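||||
# Note: lr_scheduler above is built but never passed to the optimizer, and
|
||||
# lr_scheduler.step() is commented out below, so this script effectively
|
||||
# trains with a constant learning rate of args.learning_rate.
|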
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = {}".format(len(train_dataset)))
|
||||
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per GPU = {}".format(
|
||||
args.per_gpu_train_batch_size))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = {}".
|
||||
format(args.per_gpu_train_batch_size *
|
||||
paddle.distributed.get_world_size()))
|
||||
logger.info(" Total optimization steps = {}".format(t_total))
|
||||
|
||||
global_step = 0
|
||||
model.clear_gradients()
|
||||
train_dataloader_len = len(train_dataloader)
|
||||
best_metric = {'f1': 0}
|
||||
model.train()
|
||||
|
||||
train_reader_cost = 0.0
|
||||
train_run_cost = 0.0
|
||||
total_samples = 0
|
||||
reader_start = time.time()
|
||||
|
||||
print_step = 1
|
||||
|
||||
for epoch in range(int(args.num_train_epochs)):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
train_reader_cost += time.time() - reader_start
|
||||
train_start = time.time()
|
||||
outputs = model(**batch)
|
||||
train_run_cost += time.time() - train_start
|
||||
# model outputs are always tuple in ppnlp (see doc)
|
||||
loss = outputs['loss']
|
||||
loss = loss.mean()
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
# lr_scheduler.step() # Update learning rate schedule
|
||||
|
||||
global_step += 1
|
||||
total_samples += batch['image'].shape[0]
|
||||
|
||||
if rank == 0 and step % print_step == 0:
|
||||
logger.info(
|
||||
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
|
||||
format(epoch, args.num_train_epochs, step,
|
||||
train_dataloader_len, global_step,
|
||||
np.mean(loss.numpy()),
|
||||
optimizer.get_lr(), train_reader_cost / print_step, (
|
||||
train_reader_cost + train_run_cost) / print_step,
|
||||
total_samples / print_step, total_samples / (
|
||||
train_reader_cost + train_run_cost)))
|
||||
|
||||
train_reader_cost = 0.0
|
||||
train_run_cost = 0.0
|
||||
total_samples = 0
|
||||
|
||||
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
|
||||
# Log metrics
|
||||
# Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(model, eval_dataloader, logger)
|
||||
if results['f1'] >= best_metric['f1']:
|
||||
best_metric = results
|
||||
output_dir = os.path.join(args.output_dir, "best_model")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if distributed:
|
||||
model._layers.save_pretrained(output_dir)
|
||||
else:
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
paddle.save(args,
|
||||
os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to {}".format(
|
||||
output_dir))
|
||||
logger.info("eval results: {}".format(results))
|
||||
logger.info("best_metirc: {}".format(best_metirc))
|
||||
reader_start = time.time()
|
||||
|
||||
if rank == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "latest_model")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if distributed:
|
||||
model._layers.save_pretrained(output_dir)
|
||||
else:
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to {}".format(output_dir))
|
||||
logger.info("best_metirc: {}".format(best_metirc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
train(args)
|
|
@ -1,248 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import random
|
||||
import time
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import argparse
|
||||
import paddle
|
||||
import numpy as np
|
||||
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
|
||||
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
|
||||
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
|
||||
|
||||
from xfun import XFUNDataset
|
||||
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
|
||||
from eval_ser import evaluate
|
||||
from losses import SERLoss
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
MODELS = {
|
||||
'LayoutXLM':
|
||||
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
|
||||
'LayoutLM':
|
||||
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
|
||||
}
|
||||
|
||||
|
||||
def train(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
rank = paddle.distributed.get_rank()
|
||||
distributed = paddle.distributed.get_world_size() > 1
|
||||
|
||||
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
|
||||
print_arguments(args, logger)
|
||||
|
||||
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
|
||||
loss_class = SERLoss(len(label2id_map))
|
||||
|
||||
pad_token_label_id = loss_class.ignore_index
|
||||
|
||||
# dist mode
|
||||
if distributed:
|
||||
paddle.distributed.init_parallel_env()
|
||||
|
||||
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
if not args.resume:
|
||||
base_model = base_model_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class(
|
||||
base_model, num_classes=len(label2id_map), dropout=None)
|
||||
logger.info('train from scratch')
|
||||
else:
|
||||
logger.info('resume from {}'.format(args.model_name_or_path))
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
|
||||
# dist mode
|
||||
if distributed:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
train_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir=args.train_data_dir,
|
||||
label_path=args.train_label_path,
|
||||
label2id_map=label2id_map,
|
||||
img_size=(224, 224),
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
contains_re=False,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all')
|
||||
eval_dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir=args.eval_data_dir,
|
||||
label_path=args.eval_label_path,
|
||||
label2id_map=label2id_map,
|
||||
img_size=(224, 224),
|
||||
pad_token_label_id=pad_token_label_id,
|
||||
contains_re=False,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all')
|
||||
|
||||
train_sampler = paddle.io.DistributedBatchSampler(
|
||||
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
|
||||
|
||||
train_dataloader = paddle.io.DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=args.num_workers,
|
||||
use_shared_memory=True,
|
||||
collate_fn=None, )
|
||||
|
||||
eval_dataloader = paddle.io.DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.per_gpu_eval_batch_size,
|
||||
num_workers=args.num_workers,
|
||||
use_shared_memory=True,
|
||||
collate_fn=None, )
|
||||
|
||||
t_total = len(train_dataloader) * args.num_train_epochs
|
||||
|
||||
# build linear decay with warmup lr sch
|
||||
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
|
||||
learning_rate=args.learning_rate,
|
||||
decay_steps=t_total,
|
||||
end_lr=0.0,
|
||||
power=1.0)
|
||||
if args.warmup_steps > 0:
|
||||
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
|
||||
lr_scheduler,
|
||||
args.warmup_steps,
|
||||
start_lr=0,
|
||||
end_lr=args.learning_rate, )
|
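||||
# Sketch of the schedule: linear warmup from 0 to args.learning_rate over
|
||||
# warmup_steps, then PolynomialDecay with power=1.0 (i.e. linear decay) down
|
||||
# to 0 over the full t_total optimization steps.
|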
||||
|
||||
optimizer = paddle.optimizer.AdamW(
|
||||
learning_rate=lr_scheduler,
|
||||
parameters=model.parameters(),
|
||||
epsilon=args.adam_epsilon,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d",
|
||||
args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed) = %d",
|
||||
args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss = 0.0
|
||||
set_seed(args.seed)
|
||||
best_metrics = None
|
||||
|
||||
train_reader_cost = 0.0
|
||||
train_run_cost = 0.0
|
||||
total_samples = 0
|
||||
reader_start = time.time()
|
||||
|
||||
print_step = 1
|
||||
model.train()
|
||||
for epoch_id in range(args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
train_reader_cost += time.time() - reader_start
|
||||
|
||||
if args.ser_model_type == 'LayoutLM':
|
||||
if 'image' in batch:
|
||||
batch.pop('image')
|
||||
labels = batch.pop('labels')
|
||||
|
||||
train_start = time.time()
|
||||
outputs = model(**batch)
|
||||
train_run_cost += time.time() - train_start
|
||||
if args.ser_model_type == 'LayoutXLM':
|
||||
outputs = outputs[0]
|
||||
loss = loss_class(labels, outputs, batch['attention_mask'])
|
||||
|
||||
# model outputs are always tuple in ppnlp (see doc)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
tr_loss += loss.item()
|
||||
optimizer.step()
|
||||
lr_scheduler.step() # Update learning rate schedule
|
||||
optimizer.clear_grad()
|
||||
global_step += 1
|
||||
total_samples += batch['input_ids'].shape[0]
|
||||
|
||||
if rank == 0 and step % print_step == 0:
|
||||
logger.info(
|
||||
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
|
||||
format(epoch_id, args.num_train_epochs, step,
|
||||
len(train_dataloader), global_step,
|
||||
loss.numpy()[0],
|
||||
lr_scheduler.get_lr(), train_reader_cost /
|
||||
print_step, (train_reader_cost + train_run_cost) /
|
||||
print_step, total_samples / print_step, total_samples
|
||||
/ (train_reader_cost + train_run_cost)))
|
||||
|
||||
train_reader_cost = 0.0
|
||||
train_run_cost = 0.0
|
||||
total_samples = 0
|
||||
|
||||
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
|
||||
# Log metrics
|
||||
# Only evaluate when single GPU otherwise metrics may not average well
|
||||
results, _ = evaluate(args, model, tokenizer, loss_class,
|
||||
eval_dataloader, label2id_map,
|
||||
id2label_map, pad_token_label_id, logger)
|
||||
|
||||
if best_metrics is None or results["f1"] >= best_metrics["f1"]:
|
||||
best_metrics = copy.deepcopy(results)
|
||||
output_dir = os.path.join(args.output_dir, "best_model")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if distributed:
|
||||
model._layers.save_pretrained(output_dir)
|
||||
else:
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
paddle.save(args,
|
||||
os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to {}".format(
|
||||
output_dir))
|
||||
|
||||
logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
|
||||
epoch_id, args.num_train_epochs, step,
|
||||
len(train_dataloader), results))
|
||||
if best_metrics is not None:
|
||||
logger.info("best metrics: {}".format(best_metrics))
|
||||
reader_start = time.time()
|
||||
if rank == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "latest_model")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if distributed:
|
||||
model._layers.save_pretrained(output_dir)
|
||||
else:
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||
logger.info("Saving model checkpoint to {}".format(output_dir))
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
train(args)
|
|
@ -1,400 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import cv2
|
||||
import random
|
||||
import numpy as np
|
||||
import imghdr
|
||||
from copy import deepcopy
|
||||
|
||||
import paddle
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
paddle.seed(seed)
|
||||
|
||||
|
||||
def get_bio_label_maps(label_map_path):
|
||||
with open(label_map_path, "r", encoding='utf-8') as fin:
|
||||
lines = fin.readlines()
|
||||
lines = [line.strip() for line in lines]
|
||||
if "O" not in lines:
|
||||
lines.insert(0, "O")
|
||||
labels = []
|
||||
for line in lines:
|
||||
if line == "O":
|
||||
labels.append("O")
|
||||
else:
|
||||
labels.append("B-" + line)
|
||||
labels.append("I-" + line)
|
||||
label2id_map = {label: idx for idx, label in enumerate(labels)}
|
||||
id2label_map = {idx: label for idx, label in enumerate(labels)}
|
||||
return label2id_map, id2label_map
|
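||||
# Example (assuming the label file lists QUESTION and ANSWER): the maps follow
|
||||
# standard BIO tagging, i.e. label2id_map == {'O': 0, 'B-QUESTION': 1,
|
||||
# 'I-QUESTION': 2, 'B-ANSWER': 3, 'I-ANSWER': 4} plus the inverse id2label_map.
|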
||||
|
||||
|
||||
def get_image_file_list(img_file):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
|
||||
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
|
||||
imgs_lists.append(file_path)
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
imgs_lists = sorted(imgs_lists)
|
||||
return imgs_lists
|
||||
|
||||
|
||||
def draw_ser_results(image,
|
||||
ocr_results,
|
||||
font_path="../../doc/fonts/simfang.ttf",
|
||||
font_size=18):
|
||||
np.random.seed(2021)
|
||||
color = (np.random.permutation(range(255)),
|
||||
np.random.permutation(range(255)),
|
||||
np.random.permutation(range(255)))
|
||||
color_map = {
|
||||
idx: (color[0][idx], color[1][idx], color[2][idx])
|
||||
for idx in range(1, 255)
|
||||
}
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
img_new = image.copy()
|
||||
draw = ImageDraw.Draw(img_new)
|
||||
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
for ocr_info in ocr_results:
|
||||
if ocr_info["pred_id"] not in color_map:
|
||||
continue
|
||||
color = color_map[ocr_info["pred_id"]]
|
||||
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
|
||||
|
||||
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
|
||||
|
||||
img_new = Image.blend(image, img_new, 0.5)
|
||||
return np.array(img_new)
|
||||
|
||||
|
||||
def draw_box_txt(bbox, text, draw, font, font_size, color):
|
||||
# draw ocr results outline
|
||||
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
|
||||
draw.rectangle(bbox, fill=color)
|
||||
|
||||
# draw ocr results
|
||||
start_y = max(0, bbox[0][1] - font_size)
|
||||
tw = font.getsize(text)[0]
|
||||
draw.rectangle(
|
||||
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
|
||||
fill=(0, 0, 255))
|
||||
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
|
||||
|
||||
|
||||
def draw_re_results(image,
|
||||
result,
|
||||
font_path="../../doc/fonts/simfang.ttf",
|
||||
font_size=18):
|
||||
np.random.seed(0)
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
img_new = image.copy()
|
||||
draw = ImageDraw.Draw(img_new)
|
||||
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
color_head = (0, 0, 255)
|
||||
color_tail = (255, 0, 0)
|
||||
color_line = (0, 255, 0)
|
||||
|
||||
for ocr_info_head, ocr_info_tail in result:
|
||||
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
|
||||
font_size, color_head)
|
||||
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
|
||||
font_size, color_tail)
|
||||
|
||||
center_head = (
|
||||
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
|
||||
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
|
||||
center_tail = (
|
||||
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
|
||||
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
|
||||
|
||||
draw.line([center_head, center_tail], fill=color_line, width=5)
|
||||
|
||||
img_new = Image.blend(image, img_new, 0.5)
|
||||
return np.array(img_new)
|
||||
|
||||
|
||||
# pad sentences
|
||||
def pad_sentences(tokenizer,
|
||||
encoded_inputs,
|
||||
max_seq_len=512,
|
||||
pad_to_max_seq_len=True,
|
||||
return_attention_mask=True,
|
||||
return_token_type_ids=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False):
|
||||
# Pad the length up to the next multiple of max_seq_len so the sequence can be reshaped into pages
|
||||
max_seq_len = (
|
||||
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
|
||||
|
||||
needs_to_be_padded = pad_to_max_seq_len and \
|
||||
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_seq_len - len(encoded_inputs["input_ids"])
|
||||
if tokenizer.padding_side == 'right':
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
|
||||
"input_ids"]) + [0] * difference
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
encoded_inputs["token_type_ids"] +
|
||||
[tokenizer.pad_token_type_id] * difference)
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs[
|
||||
"special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs["input_ids"] = encoded_inputs[
|
||||
"input_ids"] + [tokenizer.pad_token_id] * difference
|
||||
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
|
||||
] * difference
|
||||
else:
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
|
||||
"input_ids"])
|
||||
|
||||
return encoded_inputs
|
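||||
# Example: a 600-token sample is padded up to 1024 tokens (the next multiple
|
||||
# of max_seq_len), so split_page below can reshape it into two 512-token
|
||||
# pages without dropping anything.
|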
||||
|
||||
|
||||
def split_page(encoded_inputs, max_seq_len=512):
|
||||
"""
|
||||
reshape the flat padded inputs into pages of max_seq_len (training typically truncates instead)
|
||||
"""
|
||||
for key in encoded_inputs:
|
||||
if key == 'entities':
|
||||
encoded_inputs[key] = [encoded_inputs[key]]
|
||||
continue
|
||||
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
|
||||
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
|
||||
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
|
||||
else: # for bbox
|
||||
encoded_inputs[key] = encoded_inputs[key].reshape(
|
||||
[-1, max_seq_len, 4])
|
||||
return encoded_inputs
|
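||||
# Example: input_ids of length 1024 becomes a tensor of shape [2, 512] and
|
||||
# bbox becomes [2, 512, 4]; each row is one "page" treated as a fake batch
|
||||
# element downstream.
|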
||||
|
||||
|
||||
def preprocess(
|
||||
tokenizer,
|
||||
ori_img,
|
||||
ocr_info,
|
||||
img_size=(224, 224),
|
||||
pad_token_label_id=-100,
|
||||
max_seq_len=512,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True, ):
|
||||
ocr_info = deepcopy(ocr_info)
|
||||
height = ori_img.shape[0]
|
||||
width = ori_img.shape[1]
|
||||
|
||||
img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
|
||||
|
||||
segment_offset_id = []
|
||||
words_list = []
|
||||
bbox_list = []
|
||||
input_ids_list = []
|
||||
token_type_ids_list = []
|
||||
entities = []
|
||||
|
||||
for info in ocr_info:
|
||||
# x1, y1, x2, y2
|
||||
bbox = info["bbox"]
|
||||
bbox[0] = int(bbox[0] * 1000.0 / width)
|
||||
bbox[2] = int(bbox[2] * 1000.0 / width)
|
||||
bbox[1] = int(bbox[1] * 1000.0 / height)
|
||||
bbox[3] = int(bbox[3] * 1000.0 / height)
|
||||
|
||||
text = info["text"]
|
||||
encode_res = tokenizer.encode(
|
||||
text, pad_to_max_seq_len=False, return_attention_mask=True)
|
||||
|
||||
if not add_special_ids:
|
||||
# TODO: use tokenizer.all_special_ids to strip the special tokens
|
||||
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
|
||||
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
|
||||
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
|
||||
|
||||
# for re
|
||||
entities.append({
|
||||
"start": len(input_ids_list),
|
||||
"end": len(input_ids_list) + len(encode_res["input_ids"]),
|
||||
"label": "O",
|
||||
})
|
||||
|
||||
input_ids_list.extend(encode_res["input_ids"])
|
||||
token_type_ids_list.extend(encode_res["token_type_ids"])
|
||||
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
|
||||
words_list.append(text)
|
||||
segment_offset_id.append(len(input_ids_list))
|
||||
|
||||
encoded_inputs = {
|
||||
"input_ids": input_ids_list,
|
||||
"token_type_ids": token_type_ids_list,
|
||||
"bbox": bbox_list,
|
||||
"attention_mask": [1] * len(input_ids_list),
|
||||
"entities": entities
|
||||
}
|
||||
|
||||
encoded_inputs = pad_sentences(
|
||||
tokenizer,
|
||||
encoded_inputs,
|
||||
max_seq_len=max_seq_len,
|
||||
return_attention_mask=return_attention_mask)
|
||||
|
||||
encoded_inputs = split_page(encoded_inputs)
|
||||
|
||||
fake_bs = encoded_inputs["input_ids"].shape[0]
|
||||
|
||||
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
|
||||
[fake_bs] + list(img.shape))
|
||||
|
||||
encoded_inputs["segment_offset_id"] = segment_offset_id
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
def postprocess(attention_mask, preds, id2label_map):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds = np.argmax(preds, axis=2)
|
||||
|
||||
preds_list = [[] for _ in range(preds.shape[0])]
|
||||
|
||||
# keep batch info
|
||||
for i in range(preds.shape[0]):
|
||||
for j in range(preds.shape[1]):
|
||||
if attention_mask[i][j] == 1:
|
||||
preds_list[i].append(id2label_map[preds[i][j]])
|
||||
|
||||
return preds_list
|
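||||
# Sketch: argmax over the class axis turns logits into label ids, and
|
||||
# attention_mask filters out padding, so preds_list[i] holds only the labels
|
||||
# of the real tokens in page i.
|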
||||
|
||||
|
||||
def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
|
||||
label2id_map_for_draw):
|
||||
# must ensure the preds_list is generated from the same image
|
||||
preds = [p for pred in preds_list for p in pred]
|
||||
|
||||
id2label_map = dict()
|
||||
for key in label2id_map_for_draw:
|
||||
val = label2id_map_for_draw[key]
|
||||
if key == "O":
|
||||
id2label_map[val] = key
|
||||
if key.startswith("B-") or key.startswith("I-"):
|
||||
id2label_map[val] = key[2:]
|
||||
else:
|
||||
id2label_map[val] = key
|
||||
|
||||
for idx in range(len(segment_offset_id)):
|
||||
if idx == 0:
|
||||
start_id = 0
|
||||
else:
|
||||
start_id = segment_offset_id[idx - 1]
|
||||
|
||||
end_id = segment_offset_id[idx]
|
||||
|
||||
curr_pred = preds[start_id:end_id]
|
||||
curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
|
||||
|
||||
if len(curr_pred) <= 0:
|
||||
pred_id = 0
|
||||
else:
|
||||
counts = np.bincount(curr_pred)
|
||||
pred_id = np.argmax(counts)
|
||||
ocr_info[idx]["pred_id"] = int(pred_id)
|
||||
ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
|
||||
return ocr_info
|
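||||
# Sketch: token-level predictions are pooled back onto each OCR segment via
|
||||
# segment_offset_id, and np.bincount + np.argmax takes the majority label,
|
||||
# e.g. curr_pred == [1, 1, 2] yields pred_id == 1.
|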
||||
|
||||
|
||||
def print_arguments(args, logger=None):
|
||||
"""print arguments"""
|
||||
print_func = logger.info if logger is not None else print
|
||||
print_func('----------- Configuration Arguments -----------')
|
||||
for arg, value in sorted(vars(args).items()):
|
||||
print_func('%s: %s' % (arg, value))
|
||||
print_func('------------------------------------------------')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
# yapf: disable
|
||||
parser.add_argument("--model_name_or_path",
|
||||
default=None, type=str, required=True,)
|
||||
parser.add_argument("--ser_model_type",
|
||||
default='LayoutXLM', type=str)
|
||||
parser.add_argument("--re_model_name_or_path",
|
||||
default=None, type=str, required=False,)
|
||||
parser.add_argument("--train_data_dir", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--train_label_path", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--eval_data_dir", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--eval_label_path", default=None,
|
||||
type=str, required=False,)
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,)
|
||||
parser.add_argument("--max_seq_length", default=512, type=int,)
|
||||
parser.add_argument("--evaluate_during_training", action="store_true",)
|
||||
parser.add_argument("--num_workers", default=8, type=int,)
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8,
|
||||
type=int, help="Batch size per GPU/CPU for training.",)
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8,
|
||||
type=int, help="Batch size per GPU/CPU for eval.",)
|
||||
parser.add_argument("--learning_rate", default=5e-5,
|
||||
type=float, help="The initial learning rate for Adam.",)
|
||||
parser.add_argument("--weight_decay", default=0.0,
|
||||
type=float, help="Weight decay if we apply some.",)
|
||||
parser.add_argument("--adam_epsilon", default=1e-8,
|
||||
type=float, help="Epsilon for Adam optimizer.",)
|
||||
parser.add_argument("--max_grad_norm", default=1.0,
|
||||
type=float, help="Max gradient norm.",)
|
||||
parser.add_argument("--num_train_epochs", default=3, type=int,
|
||||
help="Total number of training epochs to perform.",)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.",)
|
||||
parser.add_argument("--eval_steps", type=int, default=10,
|
||||
help="eval every X updates steps.",)
|
||||
parser.add_argument("--seed", type=int, default=2048,
|
||||
help="random seed for initialization",)
|
||||
|
||||
parser.add_argument("--rec_model_dir", default=None, type=str, )
|
||||
parser.add_argument("--det_model_dir", default=None, type=str, )
|
||||
parser.add_argument(
|
||||
"--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
|
||||
parser.add_argument("--infer_imgs", default=None, type=str, required=False)
|
||||
parser.add_argument("--resume", action='store_true')
|
||||
parser.add_argument("--ocr_json_path", default=None,
|
||||
type=str, required=False, help="ocr prediction results")
|
||||
# yapf: enable
|
||||
args = parser.parse_args()
|
||||
return args
|
|
@ -1,464 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import paddle
|
||||
import copy
|
||||
from paddle.io import Dataset
|
||||
|
||||
__all__ = ["XFUNDataset"]
|
||||
|
||||
|
||||
class XFUNDataset(Dataset):
|
||||
"""
|
||||
Example:
|
||||
print("=====begin to build dataset=====")
|
||||
from paddlenlp.transformers import LayoutXLMTokenizer
|
||||
tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
|
||||
tok_res = tokenizer.tokenize("Maribyrnong")
|
||||
# res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
|
||||
dataset = XFUNDataset(
|
||||
tokenizer,
|
||||
data_dir="./zh.val/",
|
||||
label_path="zh.val/xfun_normalize_val.json",
|
||||
img_size=(224,224))
|
||||
print(len(dataset))
|
||||
|
||||
data = dataset[0]
|
||||
print(data.keys())
|
||||
print("input_ids: ", data["input_ids"])
|
||||
print("labels: ", data["labels"])
|
||||
print("token_type_ids: ", data["token_type_ids"])
|
||||
print("words_list: ", data["words_list"])
|
||||
print("image shape: ", data["image"].shape)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
label_path,
|
||||
contains_re=False,
|
||||
label2id_map=None,
|
||||
img_size=(224, 224),
|
||||
pad_token_label_id=None,
|
||||
add_special_ids=False,
|
||||
return_attention_mask=True,
|
||||
load_mode='all',
|
||||
max_seq_len=512):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.data_dir = data_dir
|
||||
self.label_path = label_path
|
||||
self.contains_re = contains_re
|
||||
self.label2id_map = label2id_map
|
||||
self.img_size = img_size
|
||||
self.pad_token_label_id = pad_token_label_id
|
||||
self.add_special_ids = add_special_ids
|
||||
self.return_attention_mask = return_attention_mask
|
||||
self.load_mode = load_mode
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
if self.pad_token_label_id is None:
|
||||
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
|
||||
|
||||
self.all_lines = self.read_all_lines()
|
||||
|
||||
self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
|
||||
self.return_keys = {
|
||||
'bbox': {
|
||||
'type': 'np',
|
||||
'dtype': 'int64'
|
||||
},
|
||||
'input_ids': {
|
||||
'type': 'np',
|
||||
'dtype': 'int64'
|
||||
},
|
||||
'labels': {
|
||||
'type': 'np',
|
||||
'dtype': 'int64'
|
||||
},
|
||||
'attention_mask': {
|
||||
'type': 'np',
|
||||
'dtype': 'int64'
|
||||
},
|
||||
'image': {
|
||||
'type': 'np',
|
||||
'dtype': 'float32'
|
||||
},
|
||||
'token_type_ids': {
|
||||
'type': 'np',
|
||||
'dtype': 'int64'
|
||||
},
|
||||
'entities': {
|
||||
'type': 'dict'
|
||||
},
|
||||
'relations': {
|
||||
'type': 'dict'
|
||||
}
|
||||
}
|
||||
|
||||
if load_mode == "all":
|
||||
self.encoded_inputs_all = self._parse_label_file_all()
|
||||
|
||||
def pad_sentences(self,
|
||||
encoded_inputs,
|
||||
max_seq_len=512,
|
||||
pad_to_max_seq_len=True,
|
||||
return_attention_mask=True,
|
||||
return_token_type_ids=True,
|
||||
truncation_strategy="longest_first",
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False):
|
||||
# Padding
|
||||
needs_to_be_padded = pad_to_max_seq_len and \
|
||||
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_seq_len - len(encoded_inputs["input_ids"])
|
||||
if self.tokenizer.padding_side == 'right':
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
|
||||
"input_ids"]) + [0] * difference
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
encoded_inputs["token_type_ids"] +
|
||||
[self.tokenizer.pad_token_type_id] * difference)
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs[
|
||||
"special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs["input_ids"] = encoded_inputs[
|
||||
"input_ids"] + [self.tokenizer.pad_token_id] * difference
|
||||
encoded_inputs["labels"] = encoded_inputs[
|
||||
"labels"] + [self.pad_token_label_id] * difference
|
||||
encoded_inputs["bbox"] = encoded_inputs[
|
||||
"bbox"] + [[0, 0, 0, 0]] * difference
|
||||
elif self.tokenizer.padding_side == 'left':
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [
|
||||
1
|
||||
] * len(encoded_inputs["input_ids"])
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
[self.tokenizer.pad_token_type_id] * difference +
|
||||
encoded_inputs["token_type_ids"])
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = [
|
||||
1
|
||||
] * difference + encoded_inputs["special_tokens_mask"]
|
||||
encoded_inputs["input_ids"] = [
|
||||
self.tokenizer.pad_token_id
|
||||
] * difference + encoded_inputs["input_ids"]
|
||||
encoded_inputs["labels"] = [
|
||||
self.pad_token_label_id
|
||||
] * difference + encoded_inputs["labels"]
|
||||
encoded_inputs["bbox"] = [
|
||||
[0, 0, 0, 0]
|
||||
] * difference + encoded_inputs["bbox"]
|
||||
else:
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
|
||||
"input_ids"])
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def truncate_inputs(self, encoded_inputs, max_seq_len=512):
|
||||
for key in encoded_inputs:
|
||||
if key == "sample_id":
|
||||
continue
|
||||
length = min(len(encoded_inputs[key]), max_seq_len)
|
||||
encoded_inputs[key] = encoded_inputs[key][:length]
|
||||
return encoded_inputs
|
||||
|
||||
def read_all_lines(self, ):
|
||||
with open(self.label_path, "r", encoding='utf-8') as fin:
|
||||
lines = fin.readlines()
|
||||
return lines
|
||||
|
||||
def _parse_label_file_all(self):
|
||||
"""
|
||||
parse all samples
|
||||
"""
|
||||
encoded_inputs_all = []
|
||||
for line in self.all_lines:
|
||||
encoded_inputs_all.extend(self._parse_label_file(line))
|
||||
return encoded_inputs_all
|
||||
|
||||
def _parse_label_file(self, line):
|
||||
"""
|
||||
parse single sample
|
||||
"""
|
||||
|
||||
image_name, info_str = line.split("\t")
|
||||
image_path = os.path.join(self.data_dir, image_name)
|
||||
|
||||
def add_image_path(x):
|
||||
x['image_path'] = image_path
|
||||
return x
|
||||
|
||||
encoded_inputs = self._read_encoded_inputs_sample(info_str)
|
||||
if self.contains_re:
|
||||
encoded_inputs = self._chunk_re(encoded_inputs)
|
||||
else:
|
||||
encoded_inputs = self._chunk_ser(encoded_inputs)
|
||||
encoded_inputs = list(map(add_image_path, encoded_inputs))
|
||||
return encoded_inputs
|

    def _read_encoded_inputs_sample(self, info_str):
        """
        parse label info
        """
        # read text info
        info_dict = json.loads(info_str)
        height = info_dict["height"]
        width = info_dict["width"]

        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        gt_label_list = []

        if self.contains_re:
            # for re
            entities = []
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()
        for info in info_dict["ocr_info"]:
            if self.contains_re:
                # for re
                if len(info["text"]) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                relations.extend([tuple(sorted(l)) for l in info["linking"]])

            # x1, y1, x2, y2
            bbox = info["bbox"]
            label = info["label"]
            # rescale pixel coordinates onto the 0-1000 layout grid
            bbox[0] = int(bbox[0] * 1000.0 / width)
            bbox[2] = int(bbox[2] * 1000.0 / width)
            bbox[1] = int(bbox[1] * 1000.0 / height)
            bbox[3] = int(bbox[3] * 1000.0 / height)

            text = info["text"]
            encode_res = self.tokenizer.encode(
                text, pad_to_max_seq_len=False, return_attention_mask=True)

            gt_label = []
            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
            if label.lower() == "other":
                gt_label.extend([0] * len(encode_res["input_ids"]))
            else:
                gt_label.append(self.label2id_map[("b-" + label).upper()])
                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                                (len(encode_res["input_ids"]) - 1))
            if self.contains_re:
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    entities.append({
                        "start": len(input_ids_list),
                        "end":
                        len(input_ids_list) + len(encode_res["input_ids"]),
                        "label": label.upper(),
                    })
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
            gt_label_list.extend(gt_label)
            words_list.append(text)

        encoded_inputs = {
            "input_ids": input_ids_list,
            "labels": gt_label_list,
            "token_type_ids": token_type_ids_list,
            "bbox": bbox_list,
            "attention_mask": [1] * len(input_ids_list),
            # "words_list": words_list,
        }
        encoded_inputs = self.pad_sentences(
            encoded_inputs,
            max_seq_len=self.max_seq_len,
            return_attention_mask=self.return_attention_mask)
        encoded_inputs = self.truncate_inputs(encoded_inputs)

        if self.contains_re:
            relations = self._relations(entities, relations, id2label,
                                        empty_entity, entity_id_to_index_map)
            encoded_inputs['relations'] = relations
            encoded_inputs['entities'] = entities
        return encoded_inputs
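
The bbox rescaling above maps pixel coordinates onto the 0-1000 grid that LayoutLM-style models expect; a quick worked example with illustrative page and box sizes:

width, height = 2480, 3508
bbox = [124, 80, 360, 120]  # x1, y1, x2, y2 in pixels
norm = [
    int(bbox[0] * 1000.0 / width),
    int(bbox[1] * 1000.0 / height),
    int(bbox[2] * 1000.0 / width),
    int(bbox[3] * 1000.0 / height),
]
print(norm)  # [50, 22, 145, 34]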

    def _chunk_ser(self, encoded_inputs):
        encoded_inputs_all = []
        seq_len = len(encoded_inputs['input_ids'])
        chunk_size = 512
        for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
            chunk_beg = index
            chunk_end = min(index + chunk_size, seq_len)
            encoded_inputs_example = {}
            for key in encoded_inputs:
                encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
                                                                  chunk_end]
            encoded_inputs_all.append(encoded_inputs_example)
        return encoded_inputs_all
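
_chunk_ser walks the token sequence in fixed 512-token windows; for example, a 1200-token document yields three chunks:

seq_len, chunk_size = 1200, 512
windows = [(i, min(i + chunk_size, seq_len))
           for i in range(0, seq_len, chunk_size)]
print(windows)  # [(0, 512), (512, 1024), (1024, 1200)]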

    def _chunk_re(self, encoded_inputs):
        # prepare data
        entities = encoded_inputs.pop('entities')
        relations = encoded_inputs.pop('relations')
        encoded_inputs_all = []
        chunk_size = 512
        for chunk_id, index in enumerate(
                range(0, len(encoded_inputs["input_ids"]), chunk_size)):
            item = {}
            for k in encoded_inputs:
                item[k] = encoded_inputs[k][index:index + chunk_size]

            # select entities in the current chunk
            entities_in_this_span = []
            global_to_local_map = {}
            for entity_id, entity in enumerate(entities):
                if (index <= entity["start"] < index + chunk_size and
                        index <= entity["end"] < index + chunk_size):
                    entity["start"] = entity["start"] - index
                    entity["end"] = entity["end"] - index
                    global_to_local_map[entity_id] = len(entities_in_this_span)
                    entities_in_this_span.append(entity)

            # select relations in the current chunk
            relations_in_this_span = []
            for relation in relations:
                if (index <= relation["start_index"] < index + chunk_size and
                        index <= relation["end_index"] < index + chunk_size):
                    relations_in_this_span.append({
                        "head": global_to_local_map[relation["head"]],
                        "tail": global_to_local_map[relation["tail"]],
                        "start_index": relation["start_index"] - index,
                        "end_index": relation["end_index"] - index,
                    })
            item.update({
                "entities": reformat(entities_in_this_span),
                "relations": reformat(relations_in_this_span),
            })
            item['entities']['label'] = [
                self.entities_labels[x] for x in item['entities']['label']
            ]
            encoded_inputs_all.append(item)
        return encoded_inputs_all

    def _relations(self, entities, relations, id2label, empty_entity,
                   entity_id_to_index_map):
        """
        build relations
        """
        relations = list(set(relations))
        relations = [
            rel for rel in relations
            if rel[0] not in empty_entity and rel[1] not in empty_entity
        ]
        kv_relations = []
        for rel in relations:
            pair = [id2label[rel[0]], id2label[rel[1]]]
            if pair == ["question", "answer"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[0]],
                    "tail": entity_id_to_index_map[rel[1]]
                })
            elif pair == ["answer", "question"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[1]],
                    "tail": entity_id_to_index_map[rel[0]]
                })
            else:
                continue
        relations = sorted(
            [{
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": get_relation_span(rel, entities)[0],
                "end_index": get_relation_span(rel, entities)[1],
            } for rel in kv_relations],
            key=lambda x: x["head"], )
        return relations
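
Only question-answer links survive _relations, and the direction is normalized so that the head is always the question; a toy run of the same filtering with illustrative ids and labels:

id2label = {1: "question", 2: "answer", 3: "header"}
entity_id_to_index_map = {1: 0, 2: 1, 3: 2}
links = [[1, 2], [2, 1], [1, 3]]  # duplicates collapse after sorting

kv_relations = []
for a, b in set(tuple(sorted(l)) for l in links):
    pair = [id2label[a], id2label[b]]
    if pair == ["question", "answer"]:
        kv_relations.append({"head": entity_id_to_index_map[a],
                             "tail": entity_id_to_index_map[b]})
print(kv_relations)  # [{'head': 0, 'tail': 1}]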

    def load_img(self, image_path):
        # read, resize and normalize the image (ImageNet mean/std, CHW layout)
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        resize_h, resize_w = self.img_size
        im_shape = img.shape[0:2]
        im_scale_y = resize_h / im_shape[0]
        im_scale_x = resize_w / im_shape[1]
        img_new = cv2.resize(
            img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
        mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
        std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
        img_new = img_new / 255.0
        img_new -= mean
        img_new /= std
        img = img_new.transpose((2, 0, 1))
        return img

    def __getitem__(self, idx):
        if self.load_mode == "all":
            data = copy.deepcopy(self.encoded_inputs_all[idx])
        else:
            data = self._parse_label_file(self.all_lines[idx])[0]

        image_path = data.pop('image_path')
        data["image"] = self.load_img(image_path)

        return_data = {}
        for k, v in data.items():
            if k in self.return_keys:
                if self.return_keys[k]['type'] == 'np':
                    v = np.array(v, dtype=self.return_keys[k]['dtype'])
                return_data[k] = v
        return return_data

    def __len__(self):
        if self.load_mode == "all":
            return len(self.encoded_inputs_all)
        else:
            return len(self.all_lines)


def get_relation_span(rel, entities):
    bound = []
    for entity_index in [rel["head"], rel["tail"]]:
        bound.append(entities[entity_index]["start"])
        bound.append(entities[entity_index]["end"])
    return min(bound), max(bound)


def reformat(data):
    new_data = {}
    for item in data:
        for k, v in item.items():
            if k not in new_data:
                new_data[k] = []
            new_data[k].append(v)
    return new_data
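
reformat simply transposes a list of per-entity dicts into the dict-of-lists layout that _chunk_re stores; for example, using the helper above:

entities = [{"start": 0, "end": 4, "label": "QUESTION"},
            {"start": 4, "end": 9, "label": "ANSWER"}]
print(reformat(entities))
# {'start': [0, 4], 'end': [4, 9], 'label': ['QUESTION', 'ANSWER']}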


@@ -13,4 +13,3 @@ lxml
 premailer
 openpyxl
-fasttext==0.9.1
+paddlenlp>=2.2.1

@@ -61,7 +61,8 @@ def main():
     else:
         model_type = None

-    best_model_dict = load_model(config, model)
+    best_model_dict = load_model(
+        config, model, model_type=config['Architecture']["model_type"])
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():

@@ -0,0 +1,135 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
import tools.program as program


def to_tensor(data):
    import numbers
    from collections import defaultdict
    data_dict = defaultdict(list)
    to_tensor_idxs = []
    for idx, v in enumerate(data):
        if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
            if idx not in to_tensor_idxs:
                to_tensor_idxs.append(idx)
        data_dict[idx].append(v)
    for idx in to_tensor_idxs:
        data_dict[idx] = paddle.to_tensor(data_dict[idx])
    return list(data_dict.values())
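
to_tensor groups each batch field into a one-sample list and converts numeric fields to paddle tensors, while non-numeric fields (such as the ocr_info dicts) pass through as plain lists; a small usage sketch of the helper above:

import numpy as np
import paddle

batch = to_tensor([np.zeros((4, ), dtype="int64"), {"ocr_info": []}])
print(type(batch[0]))  # a paddle Tensor of shape [1, 4]
print(type(batch[1]))  # a plain list holding the dict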


class SerPredictor(object):
    def __init__(self, config):
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)

        # create data ops
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                # inject the OCR engine so label ops can run OCR at infer time
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = [
                    'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
                    'token_type_ids', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            transforms.append(op)
        global_config['infer_mode'] = True
        self.ops = create_operators(config['Eval']['dataset']['transforms'],
                                    global_config)
        self.model.eval()

    def __call__(self, img_path):
        with open(img_path, 'rb') as f:
            img = f.read()
        data = {'image': img}
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)
        # batch follows the KeepKeys order above:
        # 4 = attention_mask, 6 = segment_offset_id, 7 = ocr_info
        post_result = self.post_process_class(
            preds,
            attention_masks=batch[4],
            segment_offset_ids=batch[6],
            ocr_infos=batch[7])
        return post_result, batch


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_engine = SerPredictor(config)

    infer_imgs = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            result, _ = ser_engine(img_path)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

@@ -0,0 +1,199 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
import paddle.distributed as dist

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config, check_gpu
from tools.infer_vqa_token_ser import SerPredictor


class ReArgsParser(ArgsParser):
    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options")

    def parse_args(self, argv=None):
        args = super(ReArgsParser, self).parse_args(argv)
        assert args.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        args.opt_ser = self._parse_opt(args.opt_ser)
        return args


def make_input(ser_inputs, ser_results):
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    # ser_inputs follows the SER KeepKeys order; index 8 is 'entities'
    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # entities
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])
    entities = dict(start=start, end=end, label=label)

    # relations: every QUESTION is paired with every ANSWER as a candidate
    head = []
    tail = []
    for i in range(len(entities["label"])):
        for j in range(len(entities["label"])):
            if entities["label"][i] == 1 and entities["label"][j] == 2:
                head.append(i)
                tail.append(j)

    relations = dict(head=head, tail=tail)

    batch_size = ser_inputs[0].shape[0]
    entities_batch = []
    relations_batch = []
    entity_idx_dict_batch = []
    for b in range(batch_size):
        entities_batch.append(entities)
        relations_batch.append(relations)
        entity_idx_dict_batch.append(entity_idx_dict)

    ser_inputs[8] = entities_batch
    ser_inputs.append(relations_batch)

    # drop ocr_info (7), segment_offset_id (6) and labels (1),
    # which the RE model does not consume
    ser_inputs.pop(7)
    ser_inputs.pop(6)
    ser_inputs.pop(1)
    return ser_inputs, entity_idx_dict_batch


class SerRePredictor(object):
    def __init__(self, config, ser_config):
        self.ser_engine = SerPredictor(ser_config)

        # init re model
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, img_path):
        ser_results, ser_inputs = self.ser_engine(img_path)
        # dump intermediate SER outputs (debugging aid)
        paddle.save(ser_inputs, 'ser_inputs.npy')
        paddle.save(ser_results, 'ser_results.npy')
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result


def preprocess():
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger(name='root')

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, ser_config, device, logger


if __name__ == '__main__':
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    infer_imgs = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            result = ser_re_engine(img_path)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_result": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

tools/program.py

@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
         return config


-class AttrDict(dict):
-    """Single level attribute dict, NOT recursive"""
-
-    def __init__(self, **kwargs):
-        super(AttrDict, self).__init__()
-        super(AttrDict, self).update(kwargs)
-
-    def __getattr__(self, key):
-        if key in self:
-            return self[key]
-        raise AttributeError("object has no attribute '{}'".format(key))
-
-
-global_config = AttrDict()
-
-default_config = {'Global': {'debug': False, }}
-
-
 def load_config(file_path):
     """
     Load config from yml/yaml file.

@@ -94,38 +76,38 @@ def load_config(file_path):
         file_path (str): Path of the config file to be loaded.
     Returns: global config
     """
-    merge_config(default_config)
     _, ext = os.path.splitext(file_path)
     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
-    merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
-    return global_config
+    config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+    return config


-def merge_config(config):
+def merge_config(config, opts):
     """
     Merge config into global config.
     Args:
         config (dict): Config to be merged.
     Returns: global config
     """
-    for key, value in config.items():
+    for key, value in opts.items():
         if "." not in key:
-            if isinstance(value, dict) and key in global_config:
-                global_config[key].update(value)
+            if isinstance(value, dict) and key in config:
+                config[key].update(value)
             else:
-                global_config[key] = value
+                config[key] = value
         else:
             sub_keys = key.split('.')
             assert (
-                sub_keys[0] in global_config
+                sub_keys[0] in config
             ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
-                global_config.keys(), sub_keys[0])
-            cur = global_config[sub_keys[0]]
+                config.keys(), sub_keys[0])
+            cur = config[sub_keys[0]]
             for idx, sub_key in enumerate(sub_keys[1:]):
                 if idx == len(sub_keys) - 2:
                     cur[sub_key] = value
                 else:
                     cur = cur[sub_key]
+    return config


 def check_gpu(use_gpu):

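With the new signature, merge_config no longer mutates a module-level global_config: it takes the loaded config dict and an options dict and returns the merged result, with dotted keys addressing nested fields. A usage sketch against the updated function (values are illustrative):

from tools.program import merge_config

config = {'Global': {'use_gpu': True, 'epoch_num': 200}}
config = merge_config(config, {'Global.epoch_num': 300})
print(config['Global']['epoch_num'])  # 300
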
@@ -204,20 +186,24 @@ def train(config,
         model_type = None
     algorithm = config['Architecture']['algorithm']

-    if 'start_epoch' in best_model_dict:
-        start_epoch = best_model_dict['start_epoch']
-    else:
-        start_epoch = 1
+    start_epoch = best_model_dict[
+        'start_epoch'] if 'start_epoch' in best_model_dict else 1

+    train_reader_cost = 0.0
+    train_run_cost = 0.0
+    total_samples = 0
+    reader_start = time.time()
+
+    max_iter = len(train_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(train_dataloader)
+
     for epoch in range(start_epoch, epoch_num + 1):
-        train_dataloader = build_dataloader(
-            config, 'Train', device, logger, seed=epoch)
-        train_reader_cost = 0.0
-        train_run_cost = 0.0
-        total_samples = 0
-        reader_start = time.time()
-        max_iter = len(train_dataloader) - 1 if platform.system(
-        ) == "Windows" else len(train_dataloader)
+        if train_dataloader.dataset.need_reset:
+            train_dataloader = build_dataloader(
+                config, 'Train', device, logger, seed=epoch)
+            max_iter = len(train_dataloader) - 1 if platform.system(
+            ) == "Windows" else len(train_dataloader)

         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start

@@ -239,10 +225,11 @@ def train(config,
             else:
                 if model_type == 'table' or extra_input:
                     preds = model(images, data=batch[1:])
-                elif model_type == "kie":
+                elif model_type in ["kie", 'vqa']:
                     preds = model(batch)
                 else:
                     preds = model(images)

             loss = loss_class(preds, batch)
             avg_loss = loss['loss']

@@ -256,6 +243,7 @@ def train(config,
             optimizer.clear_grad()

             train_run_cost += time.time() - train_start
+            global_step += 1
             total_samples += len(images)

             if not isinstance(lr_scheduler, float):

@@ -285,12 +273,13 @@ def train(config,
                 (global_step > 0 and global_step % print_batch_step == 0) or
                 (idx >= len(train_dataloader) - 1)):
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
+                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
                     epoch, epoch_num, global_step, logs, train_reader_cost /
                     print_batch_step, (train_reader_cost + train_run_cost) /
-                    print_batch_step, total_samples,
+                    print_batch_step, total_samples / print_batch_step,
                     total_samples / (train_reader_cost + train_run_cost))
                 logger.info(strs)

                 train_reader_cost = 0.0
                 train_run_cost = 0.0
                 total_samples = 0

@@ -330,6 +319,7 @@ def train(config,
                     optimizer,
                     save_model_dir,
                     logger,
+                    config,
                     is_best=True,
                     prefix='best_accuracy',
                     best_model_dict=best_model_dict,

@@ -344,8 +334,7 @@ def train(config,
                 vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                       best_model_dict[main_indicator],
                                       global_step)
-            global_step += 1
-            optimizer.clear_grad()

+            reader_start = time.time()
         if dist.get_rank() == 0:
             save_model(

@@ -353,6 +342,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,

@@ -364,6 +354,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,

@@ -401,19 +392,28 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
-            elif model_type == "kie":
+            elif model_type in ["kie", 'vqa']:
                 preds = model(batch)
             else:
                 preds = model(images)
-            batch = [item.numpy() for item in batch]
+
+            batch_numpy = []
+            for item in batch:
+                if isinstance(item, paddle.Tensor):
+                    batch_numpy.append(item.numpy())
+                else:
+                    batch_numpy.append(item)
             # Obtain usable results from post-processing methods
             total_time += time.time() - start
             # Evaluate the results of the current batch
             if model_type in ['table', 'kie']:
-                eval_class(preds, batch)
+                eval_class(preds, batch_numpy)
+            elif model_type in ['vqa']:
+                post_result = post_process_class(preds, batch_numpy)
+                eval_class(post_result, batch_numpy)
             else:
-                post_result = post_process_class(preds, batch[1])
-                eval_class(post_result, batch)
+                post_result = post_process_class(preds, batch_numpy[1])
+                eval_class(post_result, batch_numpy)

             pbar.update(1)
             total_frame += len(images)

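The eval path now converts batch items to numpy one element at a time because vqa batches mix paddle tensors with plain Python structures (entities, relations); a minimal sketch of that guard:

import paddle

batch = [paddle.to_tensor([1, 2, 3]), {"entities": [0, 1]}]
batch_numpy = [
    item.numpy() if isinstance(item, paddle.Tensor) else item
    for item in batch
]
print(type(batch_numpy[0]), type(batch_numpy[1]))  # numpy.ndarray, dict
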
@@ -479,9 +479,9 @@ def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
     profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
+    config = merge_config(config, FLAGS.opt)
     profile_dic = {"profiler_options": FLAGS.profiler_options}
-    merge_config(profile_dic)
+    config = merge_config(config, profile_dic)

     if is_train:
         # save_config

@@ -503,13 +503,8 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
     ]
-    windows_not_support_list = ['PSE']
-    if platform.system() == "Windows" and alg in windows_not_support_list:
-        logger.warning('{} is not support in Windows now'.format(
-            windows_not_support_list))
-        sys.exit()

     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)

@@ -97,7 +97,8 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_model(config, model, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer,
+                                     config['Architecture']["model_type"])
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(