add --no-pretrained

pull/119/head
kaiyangzhou 2019-02-19 11:16:48 +00:00
parent 859b03cf19
commit fac7afade4
5 changed files with 6 additions and 4 deletions

View File

@ -133,6 +133,8 @@ def argument_parser():
# Architecture
# ************************************************************
parser.add_argument('-a', '--arch', type=str, default='resnet50')
parser.add_argument('--no-pretrained', action='store_true',
help='do not load pretrained weights')
# ************************************************************
# Test settings

View File

@ -54,7 +54,7 @@ def main():
trainloader, testloader_dict = dm.return_dataloaders()
print('Initializing model: {}'.format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}, use_gpu=use_gpu)
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}, pretrained=not args.no_pretrained, use_gpu=use_gpu)
print('Model size: {:.3f} M'.format(count_num_param(model)))
if args.load_weights and check_isfile(args.load_weights):

View File

@ -55,7 +55,7 @@ def main():
trainloader, testloader_dict = dm.return_dataloaders()
print('Initializing model: {}'.format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}, pretrained=not args.no_pretrained, use_gpu=use_gpu)
print('Model size: {:.3f} M'.format(count_num_param(model)))
if args.load_weights and check_isfile(args.load_weights):

View File

@ -55,7 +55,7 @@ def main():
trainloader, testloader_dict = dm.return_dataloaders()
print('Initializing model: {}'.format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'})
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}, pretrained=not args.no_pretrained, use_gpu=use_gpu)
print('Model size: {:.3f} M'.format(count_num_param(model)))
if args.load_weights and check_isfile(args.load_weights):

View File

@ -56,7 +56,7 @@ def main():
trainloader, testloader_dict = dm.return_dataloaders()
print('Initializing model: {}'.format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}, pretrained=not args.no_pretrained, use_gpu=use_gpu)
print('Model size: {:.3f} M'.format(count_num_param(model)))
if args.load_weights and check_isfile(args.load_weights):