modified head
parent
bc1c19c5d2
commit
6aa35c18ae
ppocr/modeling/heads
tools
|
@ -33,8 +33,8 @@ def build_head(config):
|
|||
from .rec_aster_head import AsterHead
|
||||
from .rec_pren_head import PRENHead
|
||||
from .rec_multi_head import MultiHead
|
||||
from .rec_robustscanner_head import RobustScannerHead
|
||||
from .rec_abinet_head import ABINetHead
|
||||
from .rec_robustscanner_head import RobustScannerHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -79,7 +79,7 @@ def export_single_model(model,
|
|||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "RobustScanner":
|
||||
max_seq_len = arch_config["Head"]["max_seq_len"]
|
||||
max_text_length = arch_config["Head"]["max_text_length"]
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 48, 160], dtype="float32"),
|
||||
|
@ -89,7 +89,7 @@ def export_single_model(model,
|
|||
shape=[None, ],
|
||||
dtype="float32"),
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, max_seq_len],
|
||||
shape=[None, max_text_length],
|
||||
dtype="int64")
|
||||
]
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue