commit
a518e4531e
|
@ -37,15 +37,17 @@ def parse_args():
|
|||
parser.add_argument("--class_dim", type=int, default=1000)
|
||||
parser.add_argument("--load_static_weights", type=str2bool, default=False)
|
||||
parser.add_argument("--img_size", type=int, default=224)
|
||||
parser.add_argument("--multilabel", type=str2bool, default=False)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class Net(paddle.nn.Layer):
|
||||
def __init__(self, net, class_dim, model):
|
||||
def __init__(self, net, class_dim, model, multilabel):
|
||||
super(Net, self).__init__()
|
||||
self.pre_net = net(class_dim=class_dim)
|
||||
self.model = model
|
||||
self.multilabel = multilabel
|
||||
|
||||
def eval(self):
|
||||
self.training = False
|
||||
|
@ -57,7 +59,7 @@ class Net(paddle.nn.Layer):
|
|||
x = self.pre_net(inputs)
|
||||
if self.model == "GoogLeNet":
|
||||
x = x[0]
|
||||
x = F.softmax(x)
|
||||
x = F.softmax(x) if not self.multilabel else F.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -65,7 +67,7 @@ def main():
|
|||
args = parse_args()
|
||||
|
||||
net = architectures.__dict__[args.model]
|
||||
model = Net(net, args.class_dim, args.model)
|
||||
model = Net(net, args.class_dim, args.model, args.multilabel)
|
||||
load_dygraph_pretrain(
|
||||
model.pre_net,
|
||||
path=args.pretrained_model,
|
||||
|
|
Loading…
Reference in New Issue