add export method
parent
932e0eace1
commit
d5637367a9
|
@ -57,18 +57,6 @@ def build_gallery_layer(configs, feature_extractor):
|
|||
return gallery_layer
|
||||
|
||||
|
||||
def export_fuse_model(model, config):
|
||||
model.eval()
|
||||
model.quanter.save_quantized_model(
|
||||
model.base_model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + config["Global"]["image_shape"],
|
||||
dtype='float32')
|
||||
])
|
||||
|
||||
|
||||
class FuseModel(paddle.nn.Layer):
|
||||
def __init__(self, configs):
|
||||
super().__init__()
|
||||
|
@ -85,12 +73,25 @@ class FuseModel(paddle.nn.Layer):
|
|||
return x
|
||||
|
||||
|
||||
def export_fuse_model(configs):
|
||||
fuse_model = FuseModel(configs)
|
||||
fuse_model.eval()
|
||||
save_path = configs["Global"]["save_inference_dir"]
|
||||
fuse_model.quanter.save_quantized_model(
|
||||
fuse_model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + configs["Global"]["image_shape"],
|
||||
dtype='float32')
|
||||
])
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
configs = parse_config(args.config)
|
||||
init_logger(name='gallery2fc')
|
||||
fuse_model = FuseModel(configs)
|
||||
# save_fuse_model(fuse_model)
|
||||
export_fuse_model(configs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue