modify export with pir (#14441)
parent
0d41ffc91d
commit
2f0a29ed3a
|
@ -318,42 +318,31 @@ def dynamic_to_static(model, arch_config, logger, input_shape=None):
|
|||
|
||||
|
||||
def export_single_model(
|
||||
model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
|
||||
model,
|
||||
arch_config,
|
||||
save_path,
|
||||
logger,
|
||||
yaml_path,
|
||||
config,
|
||||
input_shape=None,
|
||||
quanter=None,
|
||||
):
|
||||
|
||||
model = dynamic_to_static(model, arch_config, logger, input_shape)
|
||||
|
||||
if quanter is None:
|
||||
paddle_version = version.parse(paddle.__version__)
|
||||
if (
|
||||
paddle_version >= version.parse("3.0.0b2")
|
||||
or paddle_version == version.parse("0.0.0")
|
||||
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
|
||||
save_path = os.path.dirname(save_path)
|
||||
for enable_pir in [True, False]:
|
||||
if not enable_pir:
|
||||
save_path_no_pir = os.path.join(save_path, "inference")
|
||||
model.forward.rollback()
|
||||
with paddle.pir_utils.OldIrGuard():
|
||||
model = dynamic_to_static(
|
||||
model, arch_config, logger, input_shape
|
||||
)
|
||||
paddle.jit.save(model, save_path_no_pir)
|
||||
else:
|
||||
save_path_pir = os.path.join(
|
||||
os.path.dirname(save_path),
|
||||
f"{os.path.basename(save_path)}_pir",
|
||||
"inference",
|
||||
)
|
||||
paddle.jit.save(model, save_path_pir)
|
||||
shutil.copy(
|
||||
yaml_path,
|
||||
os.path.join(
|
||||
os.path.dirname(save_path_pir), os.path.basename(yaml_path)
|
||||
),
|
||||
)
|
||||
else:
|
||||
if config["Global"].get("export_with_pir", False):
|
||||
paddle_version = version.parse(paddle.__version__)
|
||||
assert (
|
||||
paddle_version >= version.parse("3.0.0b2")
|
||||
or paddle_version == version.parse("0.0.0")
|
||||
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
|
||||
paddle.jit.save(model, save_path)
|
||||
else:
|
||||
model.forward.rollback()
|
||||
with paddle.pir_utils.OldIrGuard():
|
||||
model = dynamic_to_static(model, arch_config, logger, input_shape)
|
||||
paddle.jit.save(model, save_path)
|
||||
else:
|
||||
quanter.save_quantized_model(model, save_path)
|
||||
logger.info("inference model is saved to {}".format(save_path))
|
||||
|
@ -472,9 +461,16 @@ def export(config, base_model=None, save_path=None):
|
|||
sub_model_save_path,
|
||||
logger,
|
||||
yaml_path,
|
||||
config,
|
||||
)
|
||||
else:
|
||||
save_path = os.path.join(save_path, "inference")
|
||||
export_single_model(
|
||||
model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
|
||||
model,
|
||||
arch_config,
|
||||
save_path,
|
||||
logger,
|
||||
yaml_path,
|
||||
config,
|
||||
input_shape=input_shape,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue