fix export_model to support dygraph
parent
c5cf3c154a
commit
fa471435a7
|
@ -38,7 +38,7 @@ class Net(paddle.nn.Layer):
|
|||
self.pre_net = net(class_dim=class_dim)
|
||||
self.to_static = to_static
|
||||
|
||||
# 请根据实际需求修改shape
|
||||
# Please modify the 'shape' according to actual needs
|
||||
@to_static(input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 224, 224], dtype='float32')
|
||||
|
@ -56,8 +56,10 @@ def main():
|
|||
net = architectures.__dict__[args.model]
|
||||
|
||||
model = Net(net, to_static, args.class_dim)
|
||||
para_state_dict = paddle.io.load_program_state(args.pretrained_model)
|
||||
load_dygraph_pretrain(model, args.pretrained_model, True)
|
||||
|
||||
# Please set 'load_static_weights' to 'True' or 'False' according to the 'pretrained_model'
|
||||
load_dygraph_pretrain(
|
||||
model, path=args.pretrained_model, load_static_weights=True)
|
||||
paddle.jit.save(model, args.output_path)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue