modify export with pir
parent
cd18221ed3
commit
1e4058c04b
|
@ -606,37 +606,25 @@ class Engine(object):
|
|||
model.base_model.quanter.save_quantized_model(model,
|
||||
save_path + "_int8")
|
||||
else:
|
||||
paddle_version = version.parse(paddle.__version__)
|
||||
if (paddle_version >= version.parse('3.0.0b2') or paddle_version ==
|
||||
version.parse('0.0.0')) and os.environ.get(
|
||||
"FLAGS_enable_pir_api", None) not in ["0", "False"]:
|
||||
save_path = os.path.dirname(save_path)
|
||||
for enable_pir in [True, False]:
|
||||
if not enable_pir:
|
||||
save_path_no_pir = os.path.join(save_path, "inference")
|
||||
model.forward.rollback()
|
||||
with paddle.pir_utils.OldIrGuard():
|
||||
model = paddle.jit.to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] +
|
||||
self.config["Global"]["image_shape"],
|
||||
dtype='float32')
|
||||
])
|
||||
paddle.jit.save(model, save_path_no_pir)
|
||||
else:
|
||||
save_path_pir = os.path.join(
|
||||
os.path.dirname(save_path),
|
||||
f"{os.path.basename(save_path)}_pir", "inference")
|
||||
paddle.jit.save(model, save_path_pir)
|
||||
shutil.copy(
|
||||
dst_path,
|
||||
os.path.join(
|
||||
os.path.dirname(save_path_pir),
|
||||
os.path.basename(dst_path)), )
|
||||
else:
|
||||
if self.config["Global"].get("export_with_pir", False):
|
||||
paddle_version = version.parse(paddle.__version__)
|
||||
assert (paddle_version >= version.parse('3.0.0b2') or
|
||||
paddle_version == version.parse('0.0.0')
|
||||
) and os.environ.get("FLAGS_enable_pir_api",
|
||||
None) not in ["0", "False"]
|
||||
paddle.jit.save(model, save_path)
|
||||
else:
|
||||
model.forward.rollback()
|
||||
with paddle.pir_utils.OldIrGuard():
|
||||
model = paddle.jit.to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] +
|
||||
self.config["Global"]["image_shape"],
|
||||
dtype='float32')
|
||||
])
|
||||
paddle.jit.save(model, save_path)
|
||||
logger.info(
|
||||
f"Export succeeded! The inference model exported has been saved in \"{save_path}\"."
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue