support latexocr static train (#14297)

pull/14306/head
liuhongen1234567 2024-11-29 17:44:53 +08:00 committed by GitHub
parent 8fdc409edf
commit 0018cbd2b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View File

@ -119,6 +119,14 @@ def apply_to_static(model, config, logger):
InputSpec([None], dtype="int64"),
]
)
elif algo == "LaTeXOCR":
specs = [
[
InputSpec(shape=[None, 1, None, None], dtype="float32"),
InputSpec(shape=[None, None], dtype="float32"),
InputSpec(shape=[None, None], dtype="float32"),
]
]
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model

View File

@ -101,7 +101,9 @@ class StdConv2dSame(nn.Conv2D):
if self.export:
weight = paddle.reshape(
F.batch_norm(
self.weight.reshape([1, self._out_channels, -1]),
self.weight.reshape([1, self._out_channels, -1]).cast(
paddle.float32
),
running_mean,
running_variance,
momentum=0.0,
@ -113,7 +115,9 @@ class StdConv2dSame(nn.Conv2D):
else:
weight = paddle.reshape(
F.batch_norm(
self.weight.reshape([1, self._out_channels, -1]),
self.weight.reshape([1, self._out_channels, -1]).cast(
paddle.float32
),
running_mean,
running_variance,
training=True,