add layoutlmv2

parent 6fe387ce03
commit 1bcbd31815
@@ -0,0 +1,125 @@
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_layoutlmv2/
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2048
  infer_img: doc/vqa/input/zh_val_21.jpg
  save_res_path: ./output/re/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForRe
    pretrained: True
    checkpoints:

Loss:
  name: LossFromOutput
  key: loss
  reduction: mean

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    name: Piecewise
    values: [0.000005, 0.00005]
    decay_epochs: [10]
    warmup_epoch: 0
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQAReTokenLayoutLMPostProcess

Metric:
  name: VQAReTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
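The new RE config above keeps its Train and Eval pipelines in sync through YAML anchors (&algorithm, &max_seq_len, &class_path) instead of duplicated literals. A minimal sketch of how those anchors resolve, assuming standard YAML semantics (the snippet is illustrative and not part of the commit):

import yaml  # PaddleOCR configs are plain YAML, so PyYAML shows the same resolution

snippet = """
Architecture:
  algorithm: &algorithm "LayoutLMv2"
Train:
  encode:
    algorithm: *algorithm
    max_seq_len: &max_seq_len 512
Eval:
  encode:
    max_seq_len: *max_seq_len
"""
cfg = yaml.safe_load(snippet)
assert cfg['Train']['encode']['algorithm'] == 'LayoutLMv2'
assert cfg['Eval']['encode']['max_seq_len'] == 512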
@@ -21,7 +21,7 @@ Architecture:
   Backbone:
     name: LayoutXLMForRe
     pretrained: True
-    checkpoints:
+    checkpoints:

 Loss:
   name: LossFromOutput

@@ -34,7 +34,10 @@ Optimizer:
   beta2: 0.999
   clip_norm: 10
   lr:
-    learning_rate: 0.00005
+    name: Piecewise
+    values: [0.000005, 0.00005]
+    decay_epochs: [10]
+    warmup_epoch: 0
   regularizer:
     name: L2
     factor: 0.00000
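The flat learning_rate: 0.00005 is replaced by a Piecewise schedule. Read together with decay_epochs: [10], the values list means the optimizer runs at 5e-6 for the first 10 epochs and 5e-5 afterwards. A pure-Python sketch of that lookup, assuming the usual piecewise-decay semantics (values[i] holds while epoch < decay_epochs[i]); this is my illustration, not code from the commit:

from bisect import bisect_right

def piecewise_lr(epoch, decay_epochs=(10,), values=(0.000005, 0.00005)):
    # values[i] applies while epoch < decay_epochs[i]; the last value holds afterwards
    return values[bisect_right(decay_epochs, epoch)]

assert piecewise_lr(0) == 0.000005   # epochs 0-9 run at 5e-6
assert piecewise_lr(10) == 0.00005   # epoch 10 onwards runs at 5e-5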
@@ -81,7 +84,7 @@ Train:
     shuffle: True
     drop_last: False
     batch_size_per_card: 8
-    num_workers: 4
+    num_workers: 8
     collate_fn: ListCollator

 Eval:

@@ -118,5 +121,5 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 8
-    num_workers: 4
+    num_workers: 8
     collate_fn: ListCollator
@@ -0,0 +1,121 @@
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlmv2/
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: doc/vqa/input/zh_val_0.jpg
  save_res_path: ./output/ser/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ppstructure/vqa/labels/labels_ser.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/xfun_normalize_train.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/xfun_normalize_val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
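One detail worth flagging: the SER config normalizes with scale: 1 and 0-255 channel statistics, while the RE config uses scale: 1./255. with 0-1 statistics. These are the same ImageNet normalization expressed in different units, as this quick check (mine, not the commit's) confirms:

mean_255, std_255 = [123.675, 116.28, 103.53], [58.395, 57.12, 57.375]
mean_01, std_01 = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# (pixel - mean_255) / std_255 == (pixel/255 - mean_01) / std_01, because both
# sets of statistics differ by exactly the factor 255
assert all(abs(a / 255 - b) < 1e-9
           for a, b in zip(mean_255 + std_255, mean_01 + std_01))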
@@ -799,7 +799,7 @@ class VQATokenLabelEncode(object):
                  ocr_engine=None,
                  **kwargs):
         super(VQATokenLabelEncode, self).__init__()
-        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
+        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
         from ppocr.utils.utility import load_vqa_bio_label_maps
         tokenizer_dict = {
             'LayoutXLM': {

@@ -809,6 +809,10 @@ class VQATokenLabelEncode(object):
             'LayoutLM': {
                 'class': LayoutLMTokenizer,
                 'pretrained_model': 'layoutlm-base-uncased'
             },
+            'LayoutLMv2': {
+                'class': LayoutLMv2Tokenizer,
+                'pretrained_model': 'layoutlmv2-base-uncased'
+            }
         }
         self.contains_re = contains_re
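tokenizer_dict maps the algorithm string from the config to a tokenizer class plus the default weights to load. The consuming code sits below this hunk and is not shown, so the following is an illustrative sketch of how the new entry is used, assuming the standard paddlenlp from_pretrained API:

from paddlenlp.transformers import LayoutLMv2Tokenizer

tokenizer_dict = {
    'LayoutLMv2': {
        'class': LayoutLMv2Tokenizer,
        'pretrained_model': 'layoutlmv2-base-uncased'
    }
}
entry = tokenizer_dict['LayoutLMv2']  # the algorithm name comes from the YAML config
tokenizer = entry['class'].from_pretrained(entry['pretrained_model'])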
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
+

 class VQASerTokenChunk(object):
     def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):

@@ -39,6 +41,8 @@ class VQASerTokenChunk(object):
                 encoded_inputs_example[key] = data[key]

             encoded_inputs_all.append(encoded_inputs_example)
+        if len(encoded_inputs_all) == 0:
+            return None
         return encoded_inputs_all[0]

@@ -101,17 +105,18 @@ class VQAReTokenChunk(object):
                 "entities": self.reformat(entities_in_this_span),
                 "relations": self.reformat(relations_in_this_span),
             })
-            item['entities']['label'] = [
-                self.entities_labels[x] for x in item['entities']['label']
-            ]
-            encoded_inputs_all.append(item)
+            if len(item['entities']) > 0:
+                item['entities']['label'] = [
+                    self.entities_labels[x] for x in item['entities']['label']
+                ]
+                encoded_inputs_all.append(item)
+        if len(encoded_inputs_all) == 0:
+            return None
         return encoded_inputs_all[0]

     def reformat(self, data):
-        new_data = {}
+        new_data = defaultdict(list)
         for item in data:
             for k, v in item.items():
-                if k not in new_data:
-                    new_data[k] = []
                 new_data[k].append(v)
         return new_data
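The added len(...) guards prevent a lookup on a missing 'label' key when a 512-token chunk contains no entities, and the early return None covers the case where every chunk is empty. The reformat rewrite itself swaps the manual key check for collections.defaultdict with identical behavior: a list of per-entity dicts becomes one dict of per-key lists, which is the shape the 'entities'/'relations' fields expect. A standalone usage example (the field values here are hypothetical):

from collections import defaultdict

def reformat(data):
    # list-of-dicts -> dict-of-lists, exactly as in the rewritten method
    new_data = defaultdict(list)
    for item in data:
        for k, v in item.items():
            new_data[k].append(v)
    return new_data

spans = [{'start': 0, 'end': 4, 'label': 'QUESTION'},
         {'start': 5, 'end': 9, 'label': 'ANSWER'}]
out = reformat(spans)
assert out['start'] == [0, 5] and out['label'] == ['QUESTION', 'ANSWER']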
@@ -45,8 +45,11 @@ def build_backbone(config, model_type):
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
     elif model_type == 'vqa':
-        from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
-        support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
+        from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
+        support_dict = [
+            "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
+            "LayoutXLMForSer", 'LayoutXLMForRe'
+        ]
     else:
         raise NotImplementedError
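Registering the new backbones only requires importing them and appending their names to support_dict because the builder dispatches on the config's name string. A minimal sketch of that registry pattern (the shape is assumed from the hunk above, not copied from the file):

def build_backbone_sketch(config, module_classes):
    # module_classes: dict of name -> class, standing in for the imported symbols
    module_name = config.pop('name')              # e.g. 'LayoutLMv2ForRe'
    assert module_name in module_classes, '{} is not supported'.format(module_name)
    return module_classes[module_name](**config)  # remaining config keys become kwargs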
@@ -21,12 +21,14 @@ from paddle import nn

 from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
 from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
+from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction

 __all__ = ["LayoutXLMForSer", 'LayoutLMForSer']

 pretrained_model_dict = {
     LayoutXLMModel: 'layoutxlm-base-uncased',
-    LayoutLMModel: 'layoutlm-base-uncased'
+    LayoutLMModel: 'layoutlm-base-uncased',
+    LayoutLMv2Model: 'layoutlmv2-base-uncased'
 }

@@ -58,6 +60,52 @@ class NLPBaseModel(nn.Layer):
         self.out_channels = 1


+class LayoutLMForSer(NLPBaseModel):
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
+                 **kwargs):
+        super(LayoutLMForSer, self).__init__(
+            LayoutLMModel,
+            LayoutLMForTokenClassification,
+            'ser',
+            pretrained,
+            checkpoints,
+            num_classes=num_classes)
+
+    def forward(self, x):
+        x = self.model(
+            input_ids=x[0],
+            bbox=x[2],
+            attention_mask=x[4],
+            token_type_ids=x[5],
+            position_ids=None,
+            output_hidden_states=False)
+        return x
+
+
+class LayoutLMv2ForSer(NLPBaseModel):
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
+                 **kwargs):
+        super(LayoutLMv2ForSer, self).__init__(
+            LayoutLMv2Model,
+            LayoutLMv2ForTokenClassification,
+            'ser',
+            pretrained,
+            checkpoints,
+            num_classes=num_classes)
+
+    def forward(self, x):
+        x = self.model(
+            input_ids=x[0],
+            bbox=x[2],
+            image=x[3],
+            attention_mask=x[4],
+            token_type_ids=x[5],
+            position_ids=None,
+            head_mask=None,
+            labels=None)
+        return x[0]
+
+
 class LayoutXLMForSer(NLPBaseModel):
     def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):

@@ -82,25 +130,24 @@ class LayoutXLMForSer(NLPBaseModel):
         return x[0]


-class LayoutLMForSer(NLPBaseModel):
-    def __init__(self, num_classes, pretrained=True, checkpoints=None,
-                 **kwargs):
-        super(LayoutLMForSer, self).__init__(
-            LayoutLMModel,
-            LayoutLMForTokenClassification,
-            'ser',
-            pretrained,
-            checkpoints,
-            num_classes=num_classes)
+class LayoutLMv2ForRe(NLPBaseModel):
+    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+        super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
+                                              LayoutLMv2ForRelationExtraction,
+                                              're', pretrained, checkpoints)

     def forward(self, x):
         x = self.model(
             input_ids=x[0],
-            bbox=x[2],
-            attention_mask=x[4],
-            token_type_ids=x[5],
+            bbox=x[1],
+            labels=None,
+            image=x[2],
+            attention_mask=x[3],
+            token_type_ids=x[4],
             position_ids=None,
-            output_hidden_states=False)
+            head_mask=None,
+            entities=x[5],
+            relations=x[6])
         return x
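The hard-coded x[...] indices in the new forwards are only safe because KeepKeys in the configs fixes the order of the list the dataloader returns. Cross-checking the two orderings (keys copied from the configs above):

SER_KEYS = ['input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids']
RE_KEYS = ['input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids',
           'entities', 'relations']

# LayoutLMv2ForSer reads bbox=x[2], image=x[3], attention_mask=x[4], token_type_ids=x[5]
assert [SER_KEYS.index(k)
        for k in ('bbox', 'image', 'attention_mask', 'token_type_ids')] == [2, 3, 4, 5]
# LayoutLMv2ForRe reads bbox=x[1], image=x[2], ..., entities=x[5], relations=x[6]
assert [RE_KEYS.index(k)
        for k in ('bbox', 'image', 'entities', 'relations')] == [1, 2, 5, 6]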
@@ -24,6 +24,8 @@ The DOC-VQA algorithms in PP-Structure are developed on top of the PaddleNLP natural language processing library
 |:---:|:---:|:---:|:---:|
 | LayoutXLM | RE | 0.7483 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
 | LayoutXLM | SER | 0.9038 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+| LayoutLMv2 | RE | 0.6777 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
+| LayoutLMv2 | SER | 0.8544 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
 | LayoutLM | SER | 0.7731 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |