Merge pull request #1132 from cuicheng01/release/2.1

fix ml export_model
release/2.1
cuicheng01 2021-08-12 11:12:15 +08:00 committed by GitHub
commit a518e4531e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 3 deletions

View File

@ -37,15 +37,17 @@ def parse_args():
parser.add_argument("--class_dim", type=int, default=1000)
parser.add_argument("--load_static_weights", type=str2bool, default=False)
parser.add_argument("--img_size", type=int, default=224)
parser.add_argument("--multilabel", type=str2bool, default=False)
return parser.parse_args()
class Net(paddle.nn.Layer):
def __init__(self, net, class_dim, model):
def __init__(self, net, class_dim, model, multilabel):
super(Net, self).__init__()
self.pre_net = net(class_dim=class_dim)
self.model = model
self.multilabel = multilabel
def eval(self):
self.training = False
@ -57,7 +59,7 @@ class Net(paddle.nn.Layer):
x = self.pre_net(inputs)
if self.model == "GoogLeNet":
x = x[0]
x = F.softmax(x)
x = F.softmax(x) if not self.multilabel else F.sigmoid(x)
return x
@ -65,7 +67,7 @@ def main():
args = parse_args()
net = architectures.__dict__[args.model]
model = Net(net, args.class_dim, args.model)
model = Net(net, args.class_dim, args.model, args.multilabel)
load_dygraph_pretrain(
model.pre_net,
path=args.pretrained_model,