Merge pull request #1021 from cuicheng01/release/2.0-beta
fix infer class_num bugs in release/2.0-beta branchrelease/2.0-beta
commit
772cccd63f
|
@ -39,6 +39,7 @@ def parse_args():
|
|||
parser.add_argument("-m", "--model", type=str)
|
||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--class_num", type=int, default=1000)
|
||||
parser.add_argument(
|
||||
"--load_static_weights",
|
||||
type=str2bool,
|
||||
|
@ -122,7 +123,7 @@ def main():
|
|||
|
||||
paddle.disable_static(place)
|
||||
|
||||
net = architectures.__dict__[args.model]()
|
||||
net = architectures.__dict__[args.model](class_dim=args.class_num)
|
||||
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
|
||||
image_list = get_image_list(args.image_file)
|
||||
for idx, filename in enumerate(image_list):
|
||||
|
|
Loading…
Reference in New Issue