diff --git a/tools/export_model.py b/tools/export_model.py index f587b2bb3..bdff89f75 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -53,17 +53,19 @@ def main(): save_path = '{}/inference'.format(config['Global']['save_inference_dir']) if config['Architecture']['algorithm'] == "SRN": + max_text_length = config['Architecture']['Head']['max_text_length'] other_shape = [ paddle.static.InputSpec( shape=[None, 1, 64, 256], dtype='float32'), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 25, 1], - dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64"), + shape=[None, max_text_length, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64") + shape=[None, 8, max_text_length, max_text_length], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, max_text_length, max_text_length], + dtype="int64") ] ] model = to_static(model, input_spec=other_shape)