fix (#7587)
parent
0c8ac23851
commit
4589f51b5c
|
@ -151,17 +151,24 @@ def main():
|
||||||
|
|
||||||
arch_config = config["Architecture"]
|
arch_config = config["Architecture"]
|
||||||
|
|
||||||
arch_config = config["Architecture"]
|
if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
|
||||||
|
"name"] != 'MultiHead':
|
||||||
|
input_shape = config["Eval"]["dataset"]["transforms"][-2][
|
||||||
|
'SVTRRecResizeImg']['image_shape']
|
||||||
|
else:
|
||||||
|
input_shape = None
|
||||||
|
|
||||||
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
|
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
|
||||||
archs = list(arch_config["Models"].values())
|
archs = list(arch_config["Models"].values())
|
||||||
for idx, name in enumerate(model.model_name_list):
|
for idx, name in enumerate(model.model_name_list):
|
||||||
sub_model_save_path = os.path.join(save_path, name, "inference")
|
sub_model_save_path = os.path.join(save_path, name, "inference")
|
||||||
export_single_model(model.model_list[idx], archs[idx],
|
export_single_model(model.model_list[idx], archs[idx],
|
||||||
sub_model_save_path, logger, quanter)
|
sub_model_save_path, logger, input_shape,
|
||||||
|
quanter)
|
||||||
else:
|
else:
|
||||||
save_path = os.path.join(save_path, "inference")
|
save_path = os.path.join(save_path, "inference")
|
||||||
export_single_model(model, arch_config, save_path, logger, quanter)
|
export_single_model(model, arch_config, save_path, logger, input_shape,
|
||||||
|
quanter)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue