diff --git a/train_vidreid_xent.py b/train_vidreid_xent.py index 73b6632..8bdd853 100755 --- a/train_vidreid_xent.py +++ b/train_vidreid_xent.py @@ -69,10 +69,10 @@ def main(): testloader_dict = dm.testloader_dict print("Initializing model: {}".format(args.arch)) - model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'}) + model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}) print("Model size: {:.3f} M".format(count_num_param(model))) - criterion = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) + criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay) scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma) diff --git a/train_vidreid_xent_htri.py b/train_vidreid_xent_htri.py index 43c8178..2fcbcbd 100755 --- a/train_vidreid_xent_htri.py +++ b/train_vidreid_xent_htri.py @@ -70,10 +70,10 @@ def main(): testloader_dict = dm.testloader_dict print("Initializing model: {}".format(args.arch)) - model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'}) + model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}) print("Model size: {:.3f} M".format(count_num_param(model))) - criterion = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) + criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) criterion_htri = TripletLoss(margin=args.margin) optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)