rm --htri-only and print Xent and Htri losses separately

pull/119/head
kaiyangzhou 2019-02-05 16:52:29 +00:00
parent 4ccdc324e4
commit 32e4d36293
3 changed files with 35 additions and 41 deletions
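
Summary of the change: the --htri-only flag is removed, so both train scripts now always compute the cross-entropy (Xent) and hard-triplet (Htri) terms, combine them as loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss, and track and print each term separately. Below is a minimal sketch of the new per-batch loss logic, assuming criterion_xent and criterion_htri behave like standard PyTorch loss modules; the helper names deep_supervision and combined_loss are illustrative only, not the repo's API, and the aggregation inside deep_supervision is shown as a plain sum for illustration.

def deep_supervision(criterion, xs, y):
    # Illustrative stand-in for the repo's DeepSupervision helper:
    # apply the criterion to each output in a tuple/list and aggregate.
    return sum(criterion(x, y) for x in xs)

def combined_loss(outputs, features, pids, criterion_xent, criterion_htri,
                  lambda_xent=1.0, lambda_htri=1.0):
    # Cross-entropy term on the classifier outputs.
    if isinstance(outputs, (tuple, list)):
        xent_loss = deep_supervision(criterion_xent, outputs, pids)
    else:
        xent_loss = criterion_xent(outputs, pids)

    # Hard-triplet term on the embedding features.
    if isinstance(features, (tuple, list)):
        htri_loss = deep_supervision(criterion_htri, features, pids)
    else:
        htri_loss = criterion_htri(features, pids)

    # Weighted sum used for backprop; the two terms are returned separately
    # so they can be logged as 'Xent' and 'Htri'.
    loss = lambda_xent * xent_loss + lambda_htri * htri_loss
    return loss, xent_loss, htri_loss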

View File

@ -118,8 +118,6 @@ def argument_parser():
                         help='margin for triplet loss')
     parser.add_argument('--num-instances', type=int, default=4,
                         help='number of instances per identity')
-    parser.add_argument('--htri-only', action='store_true',
-                        help='only use hard triplet loss')
     parser.add_argument('--lambda-xent', type=float, default=1,
                         help='weight to balance cross entropy loss')
     parser.add_argument('--lambda-htri', type=float, default=1,

View File

@ -140,7 +140,8 @@ def main():
 def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
     losses = AverageMeter()
+    xent_losses = AverageMeter()
+    htri_losses = AverageMeter()
     accs = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
@ -160,42 +161,39 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
             imgs, pids = imgs.cuda(), pids.cuda()
         outputs, features = model(imgs)
-        if args.htri_only:
-            if isinstance(features, (tuple, list)):
-                loss = DeepSupervision(criterion_htri, features, pids)
-            else:
-                loss = criterion_htri(features, pids)
+        if isinstance(outputs, (tuple, list)):
+            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
         else:
-            if isinstance(outputs, (tuple, list)):
-                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
-            else:
-                xent_loss = criterion_xent(outputs, pids)
-            if isinstance(features, (tuple, list)):
-                htri_loss = DeepSupervision(criterion_htri, features, pids)
-            else:
-                htri_loss = criterion_htri(features, pids)
-            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
+            xent_loss = criterion_xent(outputs, pids)
+        if isinstance(features, (tuple, list)):
+            htri_loss = DeepSupervision(criterion_htri, features, pids)
+        else:
+            htri_loss = criterion_htri(features, pids)
+        loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         batch_time.update(time.time() - end)
         losses.update(loss.item(), pids.size(0))
+        xent_losses.update(xent_loss.item(), pids.size(0))
+        htri_losses.update(htri_loss.item(), pids.size(0))
         accs.update(accuracy(outputs, pids)[0])
         if (batch_idx + 1) % args.print_freq == 0:
             print('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
+                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                   'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                    epoch + 1, batch_idx + 1, len(trainloader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
+                   xent=xent_losses,
+                   htri=htri_losses,
                    acc=accs
                   ))

View File

@ -142,7 +142,8 @@ def main():
 def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
     losses = AverageMeter()
+    xent_losses = AverageMeter()
+    htri_losses = AverageMeter()
     accs = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
@ -162,42 +163,39 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
             imgs, pids = imgs.cuda(), pids.cuda()
         outputs, features = model(imgs)
-        if args.htri_only:
-            if isinstance(features, (tuple, list)):
-                loss = DeepSupervision(criterion_htri, features, pids)
-            else:
-                loss = criterion_htri(features, pids)
+        if isinstance(outputs, (tuple, list)):
+            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
         else:
-            if isinstance(outputs, (tuple, list)):
-                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
-            else:
-                xent_loss = criterion_xent(outputs, pids)
-            if isinstance(features, (tuple, list)):
-                htri_loss = DeepSupervision(criterion_htri, features, pids)
-            else:
-                htri_loss = criterion_htri(features, pids)
-            loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
+            xent_loss = criterion_xent(outputs, pids)
+        if isinstance(features, (tuple, list)):
+            htri_loss = DeepSupervision(criterion_htri, features, pids)
+        else:
+            htri_loss = criterion_htri(features, pids)
+        loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         batch_time.update(time.time() - end)
         losses.update(loss.item(), pids.size(0))
+        xent_losses.update(xent_loss.item(), pids.size(0))
+        htri_losses.update(htri_loss.item(), pids.size(0))
         accs.update(accuracy(outputs, pids)[0])
         if (batch_idx + 1) % args.print_freq == 0:
             print('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
+                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                   'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                    epoch + 1, batch_idx + 1, len(trainloader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
+                   xent=xent_losses,
+                   htri=htri_losses,
                    acc=accs
                   ))