add learning rate scheduler

pull/17/head
KaiyangZhou 2018-03-11 22:52:33 +00:00
parent 2f2180df47
commit 037bfe3fac
1 changed file with 12 additions and 1 deletion

@@ -12,6 +12,7 @@ import torch.nn as nn
 import torch.backends.cudnn as cudnn
 from torch.utils.data import DataLoader
 from torch.autograd import Variable
+from torch.optim import lr_scheduler
 import data_manager
 from dataset_loader import ImageDataset
@@ -32,7 +33,7 @@ parser.add_argument('--height', type=int, default=256,
 parser.add_argument('--width', type=int, default=128,
                     help="width of an image (default: 128)")
 # Optimization options
-parser.add_argument('--max-epoch', default=10, type=int,
+parser.add_argument('--max-epoch', default=60, type=int,
                     help="maximum epochs to run")
 parser.add_argument('--start-epoch', default=0, type=int,
                     help="manual epoch number (useful on restarts)")
@@ -41,6 +42,10 @@ parser.add_argument('--train-batch', default=32, type=int,
 parser.add_argument('--test-batch', default=100, type=int, help="test batch size")
 parser.add_argument('--lr', '--learning-rate', default=3e-04, type=float,
                     help="initial learning rate")
+parser.add_argument('--stepsize', default=0, type=int,
+                    help="stepsize to decay learning rate (>0 means this is enabled)")
+parser.add_argument('--gamma', default=0.1, type=float,
+                    help="learning rate decay")
 parser.add_argument('--weight-decay', '--wd', default=5e-04, type=float,
                     help="weight decay (default: 5e-04)")
 # Architecture
@@ -116,6 +121,8 @@ def main():
     criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    if args.stepsize > 0:
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
     start_epoch = args.start_epoch
     if args.resume:
@@ -137,7 +144,11 @@
     for epoch in range(start_epoch, args.max_epoch):
         print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
         train(model, criterion, optimizer, trainloader, use_gpu)
+        if args.stepsize > 0: scheduler.step()
         if (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
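
Note: the new flags wire up PyTorch's built-in StepLR schedule: once every --stepsize epochs the optimizer's learning rate is multiplied by --gamma, and --stepsize 0 keeps the schedule disabled. A minimal standalone sketch of the same pattern follows; the nn.Linear toy model and the --stepsize 20 value are illustrative stand-ins, not from this commit.

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

model = nn.Linear(10, 2)  # hypothetical stand-in for the actual re-id model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-04, weight_decay=5e-04)
# e.g. running with --stepsize 20 --gamma 0.1
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

for epoch in range(60):  # --max-epoch 60
    print("epoch {}: lr = {:.0e}".format(epoch + 1, optimizer.param_groups[0]['lr']))
    optimizer.step()   # stands in for one full training epoch
    scheduler.step()   # multiplies lr by gamma once every step_size epochs
# prints 3e-04 for epochs 1-20, 3e-05 for epochs 21-40, 3e-06 for epochs 41-60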