support latexocr static train (#14297)
parent
8fdc409edf
commit
0018cbd2b6
|
@ -119,6 +119,14 @@ def apply_to_static(model, config, logger):
|
|||
InputSpec([None], dtype="int64"),
|
||||
]
|
||||
)
|
||||
elif algo == "LaTeXOCR":
|
||||
specs = [
|
||||
[
|
||||
InputSpec(shape=[None, 1, None, None], dtype="float32"),
|
||||
InputSpec(shape=[None, None], dtype="float32"),
|
||||
InputSpec(shape=[None, None], dtype="float32"),
|
||||
]
|
||||
]
|
||||
model = to_static(model, input_spec=specs)
|
||||
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
|
||||
return model
|
||||
|
|
|
@ -101,7 +101,9 @@ class StdConv2dSame(nn.Conv2D):
|
|||
if self.export:
|
||||
weight = paddle.reshape(
|
||||
F.batch_norm(
|
||||
self.weight.reshape([1, self._out_channels, -1]),
|
||||
self.weight.reshape([1, self._out_channels, -1]).cast(
|
||||
paddle.float32
|
||||
),
|
||||
running_mean,
|
||||
running_variance,
|
||||
momentum=0.0,
|
||||
|
@ -113,7 +115,9 @@ class StdConv2dSame(nn.Conv2D):
|
|||
else:
|
||||
weight = paddle.reshape(
|
||||
F.batch_norm(
|
||||
self.weight.reshape([1, self._out_channels, -1]),
|
||||
self.weight.reshape([1, self._out_channels, -1]).cast(
|
||||
paddle.float32
|
||||
),
|
||||
running_mean,
|
||||
running_variance,
|
||||
training=True,
|
||||
|
|
Loading…
Reference in New Issue