fix export_model to support dygraph

pull/291/head
gaotingquan 2020-09-23 22:51:52 +08:00
parent c5cf3c154a
commit fa471435a7
1 changed files with 5 additions and 3 deletions

View File

@ -38,7 +38,7 @@ class Net(paddle.nn.Layer):
self.pre_net = net(class_dim=class_dim)
self.to_static = to_static
# 请根据实际需求修改shape
# Please modify the 'shape' according to actual needs
@to_static(input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
@ -56,8 +56,10 @@ def main():
net = architectures.__dict__[args.model]
model = Net(net, to_static, args.class_dim)
para_state_dict = paddle.io.load_program_state(args.pretrained_model)
load_dygraph_pretrain(model, args.pretrained_model, True)
# Please set 'load_static_weights' to 'True' or 'False' according to the 'pretrained_model'
load_dygraph_pretrain(
model, path=args.pretrained_model, load_static_weights=True)
paddle.jit.save(model, args.output_path)