add --htri-only

pull/17/head
KaiyangZhou 2018-03-27 10:51:10 +01:00
parent 338e36344e
commit 952e99e4fb
2 changed files with 20 additions and 6 deletions

View File

@ -52,6 +52,8 @@ parser.add_argument('--weight-decay', default=5e-04, type=float,
parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
parser.add_argument('--num-instances', type=int, default=4, parser.add_argument('--num-instances', type=int, default=4,
help="number of instances per identity") help="number of instances per identity")
parser.add_argument('--htri-only', action='store_true', default=False,
help="if this is True, only htri loss is used in training")
# Architecture # Architecture
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
# Miscs # Miscs
@ -184,9 +186,14 @@ def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu
imgs, pids = imgs.cuda(), pids.cuda() imgs, pids = imgs.cuda(), pids.cuda()
imgs, pids = Variable(imgs), Variable(pids) imgs, pids = Variable(imgs), Variable(pids)
outputs, features = model(imgs) outputs, features = model(imgs)
xent_loss = criterion_xent(outputs, pids) if args.htri_only:
htri_loss = criterion_htri(features, pids) # only use hard triplet loss to train the network
loss = xent_loss + htri_loss loss = criterion_htri(features, pids)
else:
# combine hard triplet loss with cross entropy loss
xent_loss = criterion_xent(outputs, pids)
htri_loss = criterion_htri(features, pids)
loss = xent_loss + htri_loss
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()

View File

@ -53,6 +53,8 @@ parser.add_argument('--weight-decay', default=5e-04, type=float,
parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
parser.add_argument('--num-instances', type=int, default=4, parser.add_argument('--num-instances', type=int, default=4,
help="number of instances per identity") help="number of instances per identity")
parser.add_argument('--htri-only', action='store_true', default=False,
help="if this is True, only htri loss is used in training")
# Architecture # Architecture
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max']) parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max'])
@ -192,9 +194,14 @@ def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu
imgs, pids = imgs.cuda(), pids.cuda() imgs, pids = imgs.cuda(), pids.cuda()
imgs, pids = Variable(imgs), Variable(pids) imgs, pids = Variable(imgs), Variable(pids)
outputs, features = model(imgs) outputs, features = model(imgs)
xent_loss = criterion_xent(outputs, pids) if args.htri_only:
htri_loss = criterion_htri(features, pids) # only use hard triplet loss to train the network
loss = xent_loss + htri_loss loss = criterion_htri(features, pids)
else:
# combine hard triplet loss with cross entropy loss
xent_loss = criterion_xent(outputs, pids)
htri_loss = criterion_htri(features, pids)
loss = xent_loss + htri_loss
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()