modified head

pull/6562/head
smilelite 2022-07-10 12:31:27 +08:00
parent bc1c19c5d2
commit 6aa35c18ae
3 changed files with 4 additions and 4 deletions

View File

@ -33,8 +33,8 @@ def build_head(config):
from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
# cls head
from .cls_head import ClsHead

View File

@ -1,4 +1,4 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -79,7 +79,7 @@ def export_single_model(model,
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_seq_len = arch_config["Head"]["max_seq_len"]
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
@ -89,7 +89,7 @@ def export_single_model(model,
shape=[None, ],
dtype="float32"),
paddle.static.InputSpec(
shape=[None, max_seq_len],
shape=[None, max_text_length],
dtype="int64")
]
]