add export method

pull/1599/head
weishengyu 2021-12-23 20:35:10 +08:00
parent 932e0eace1
commit d5637367a9
1 changed files with 15 additions and 14 deletions

View File

@ -57,18 +57,6 @@ def build_gallery_layer(configs, feature_extractor):
return gallery_layer
def export_fuse_model(model, config):
model.eval()
model.quanter.save_quantized_model(
model.base_model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + config["Global"]["image_shape"],
dtype='float32')
])
class FuseModel(paddle.nn.Layer):
def __init__(self, configs):
super().__init__()
@ -85,12 +73,25 @@ class FuseModel(paddle.nn.Layer):
return x
def export_fuse_model(configs):
fuse_model = FuseModel(configs)
fuse_model.eval()
save_path = configs["Global"]["save_inference_dir"]
fuse_model.quanter.save_quantized_model(
fuse_model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + configs["Global"]["image_shape"],
dtype='float32')
])
def main():
args = parse_args()
configs = parse_config(args.config)
init_logger(name='gallery2fc')
fuse_model = FuseModel(configs)
# save_fuse_model(fuse_model)
export_fuse_model(configs)
if __name__ == '__main__':