Merge pull request #2678 from WenmuZhou/fix_srn_post_process
add max_text_length to export modelpull/2681/head
commit
3b19311dc1
|
@ -53,17 +53,19 @@ def main():
|
||||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||||
|
|
||||||
if config['Architecture']['algorithm'] == "SRN":
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
max_text_length = config['Architecture']['Head']['max_text_length']
|
||||||
other_shape = [
|
other_shape = [
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
shape=[None, 1, 64, 256], dtype='float32'), [
|
shape=[None, 1, 64, 256], dtype='float32'), [
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
shape=[None, 256, 1],
|
shape=[None, 256, 1],
|
||||||
dtype="int64"), paddle.static.InputSpec(
|
dtype="int64"), paddle.static.InputSpec(
|
||||||
shape=[None, 25, 1],
|
shape=[None, max_text_length, 1], dtype="int64"),
|
||||||
dtype="int64"), paddle.static.InputSpec(
|
|
||||||
shape=[None, 8, 25, 25], dtype="int64"),
|
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
shape=[None, 8, 25, 25], dtype="int64")
|
shape=[None, 8, max_text_length, max_text_length],
|
||||||
|
dtype="int64"), paddle.static.InputSpec(
|
||||||
|
shape=[None, 8, max_text_length, max_text_length],
|
||||||
|
dtype="int64")
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
model = to_static(model, input_spec=other_shape)
|
model = to_static(model, input_spec=other_shape)
|
||||||
|
|
Loading…
Reference in New Issue