mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
add pretrained params to backbone
This commit is contained in:
parent
9ecfc34809
commit
cd7b2ea923
@ -8,7 +8,6 @@ Global:
|
||||
# evaluation is run every 19 iterations after the 0th iteration
|
||||
eval_batch_step: [ 0, 19 ]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/vqa/input/zh_val_21.jpg
|
||||
@ -20,7 +19,7 @@ Architecture:
|
||||
Transform:
|
||||
Backbone:
|
||||
name: LayoutXLMForRe
|
||||
pretrained_model: *pretrained_model
|
||||
pretrained: True
|
||||
checkpoints:
|
||||
|
||||
Loss:
|
||||
|
@ -8,7 +8,6 @@ Global:
|
||||
# evaluation is run every 19 iterations after the 0th iteration
|
||||
eval_batch_step: [ 0, 19 ]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: &pretrained_model layoutlm-base-uncased # This field can only be changed by modifying the configuration file
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/vqa/input/zh_val_0.jpg
|
||||
@ -20,7 +19,7 @@ Architecture:
|
||||
Transform:
|
||||
Backbone:
|
||||
name: LayoutLMForSer
|
||||
pretrained_model: *pretrained_model
|
||||
pretrained: True
|
||||
checkpoints:
|
||||
num_classes: &num_classes 7
|
||||
|
||||
|
@ -8,7 +8,6 @@ Global:
|
||||
# evaluation is run every 19 iterations after the 0th iteration
|
||||
eval_batch_step: [ 0, 19 ]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/vqa/input/zh_val_42.jpg
|
||||
@ -20,7 +19,7 @@ Architecture:
|
||||
Transform:
|
||||
Backbone:
|
||||
name: LayoutXLMForSer
|
||||
pretrained_model: *pretrained_model
|
||||
pretrained: True
|
||||
checkpoints:
|
||||
num_classes: &num_classes 7
|
||||
|
||||
|
@ -24,21 +24,32 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
|
||||
|
||||
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
|
||||
|
||||
pretrained_model_dict = {
|
||||
LayoutXLMModel: 'layoutxlm-base-uncased',
|
||||
LayoutLMModel: 'layoutlm-base-uncased'
|
||||
}
|
||||
|
||||
|
||||
class NLPBaseModel(nn.Layer):
|
||||
def __init__(self,
|
||||
base_model_class,
|
||||
model_class,
|
||||
type='ser',
|
||||
pretrained_model=None,
|
||||
pretrained=True,
|
||||
checkpoints=None,
|
||||
**kwargs):
|
||||
super(NLPBaseModel, self).__init__()
|
||||
assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None"
|
||||
if checkpoints is not None:
|
||||
self.model = model_class.from_pretrained(checkpoints)
|
||||
else:
|
||||
base_model = base_model_class.from_pretrained(pretrained_model)
|
||||
pretrained_model_name = pretrained_model_dict[base_model_class]
|
||||
if pretrained:
|
||||
base_model = base_model_class.from_pretrained(
|
||||
pretrained_model_name)
|
||||
else:
|
||||
base_model = base_model_class(
|
||||
**base_model_class.pretrained_init_configuration[
|
||||
pretrained_model_name])
|
||||
if type == 'ser':
|
||||
self.model = model_class(
|
||||
base_model, num_classes=kwargs['num_classes'], dropout=None)
|
||||
@ -48,16 +59,13 @@ class NLPBaseModel(nn.Layer):
|
||||
|
||||
|
||||
class LayoutXLMForSer(NLPBaseModel):
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
pretrained_model='layoutxlm-base-uncased',
|
||||
checkpoints=None,
|
||||
def __init__(self, num_classes, pretrained=True, checkpoints=None,
|
||||
**kwargs):
|
||||
super(LayoutXLMForSer, self).__init__(
|
||||
LayoutXLMModel,
|
||||
LayoutXLMForTokenClassification,
|
||||
'ser',
|
||||
pretrained_model,
|
||||
pretrained,
|
||||
checkpoints,
|
||||
num_classes=num_classes)
|
||||
|
||||
@ -75,16 +83,13 @@ class LayoutXLMForSer(NLPBaseModel):
|
||||
|
||||
|
||||
class LayoutLMForSer(NLPBaseModel):
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
pretrained_model='layoutxlm-base-uncased',
|
||||
checkpoints=None,
|
||||
def __init__(self, num_classes, pretrained=True, checkpoints=None,
|
||||
**kwargs):
|
||||
super(LayoutLMForSer, self).__init__(
|
||||
LayoutLMModel,
|
||||
LayoutLMForTokenClassification,
|
||||
'ser',
|
||||
pretrained_model,
|
||||
pretrained,
|
||||
checkpoints,
|
||||
num_classes=num_classes)
|
||||
|
||||
@ -100,13 +105,10 @@ class LayoutLMForSer(NLPBaseModel):
|
||||
|
||||
|
||||
class LayoutXLMForRe(NLPBaseModel):
|
||||
def __init__(self,
|
||||
pretrained_model='layoutxlm-base-uncased',
|
||||
checkpoints=None,
|
||||
**kwargs):
|
||||
super(LayoutXLMForRe, self).__init__(
|
||||
LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
|
||||
pretrained_model, checkpoints)
|
||||
def __init__(self, pretrained=True, checkpoints=None, **kwargs):
|
||||
super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
|
||||
LayoutXLMForRelationExtraction,
|
||||
're', pretrained, checkpoints)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.model(
|
||||
|
Loading…
x
Reference in New Issue
Block a user