update dataset to dm

pull/119/head
KaiyangZhou 2018-11-07 15:40:02 +00:00
parent 2ecd8e5704
commit 1b255992bb
2 changed files with 4 additions and 4 deletions

View File

@ -69,10 +69,10 @@ def main():
testloader_dict = dm.testloader_dict
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'})
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'})
print("Model size: {:.3f} M".format(count_num_param(model)))
criterion = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)

View File

@ -70,10 +70,10 @@ def main():
testloader_dict = dm.testloader_dict
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'})
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
print("Model size: {:.3f} M".format(count_num_param(model)))
criterion = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
criterion_htri = TripletLoss(margin=args.margin)
optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)