rm --htri-only and print Xent and Htri losses separately
parent
4ccdc324e4
commit
32e4d36293
2
args.py
2
args.py
|
@ -118,8 +118,6 @@ def argument_parser():
|
|||
help='margin for triplet loss')
|
||||
parser.add_argument('--num-instances', type=int, default=4,
|
||||
help='number of instances per identity')
|
||||
parser.add_argument('--htri-only', action='store_true',
|
||||
help='only use hard triplet loss')
|
||||
parser.add_argument('--lambda-xent', type=float, default=1,
|
||||
help='weight to balance cross entropy loss')
|
||||
parser.add_argument('--lambda-htri', type=float, default=1,
|
||||
|
|
|
@ -140,7 +140,8 @@ def main():
|
|||
|
||||
|
||||
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
|
||||
losses = AverageMeter()
|
||||
xent_losses = AverageMeter()
|
||||
htri_losses = AverageMeter()
|
||||
accs = AverageMeter()
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
|
@ -160,42 +161,39 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
|||
imgs, pids = imgs.cuda(), pids.cuda()
|
||||
|
||||
outputs, features = model(imgs)
|
||||
if args.htri_only:
|
||||
if isinstance(features, (tuple, list)):
|
||||
loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
loss = criterion_htri(features, pids)
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||
else:
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||
else:
|
||||
xent_loss = criterion_xent(outputs, pids)
|
||||
|
||||
if isinstance(features, (tuple, list)):
|
||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
htri_loss = criterion_htri(features, pids)
|
||||
|
||||
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
|
||||
xent_loss = criterion_xent(outputs, pids)
|
||||
|
||||
if isinstance(features, (tuple, list)):
|
||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
htri_loss = criterion_htri(features, pids)
|
||||
|
||||
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
batch_time.update(time.time() - end)
|
||||
|
||||
losses.update(loss.item(), pids.size(0))
|
||||
xent_losses.update(xent_loss.item(), pids.size(0))
|
||||
htri_losses.update(htri_loss.item(), pids.size(0))
|
||||
accs.update(accuracy(outputs, pids)[0])
|
||||
|
||||
if (batch_idx + 1) % args.print_freq == 0:
|
||||
print('Epoch: [{0}][{1}/{2}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||
'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
|
||||
'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
|
||||
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
|
||||
epoch + 1, batch_idx + 1, len(trainloader),
|
||||
batch_time=batch_time,
|
||||
data_time=data_time,
|
||||
loss=losses,
|
||||
xent=xent_losses,
|
||||
htri=htri_losses,
|
||||
acc=accs
|
||||
))
|
||||
|
||||
|
|
|
@ -142,7 +142,8 @@ def main():
|
|||
|
||||
|
||||
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
|
||||
losses = AverageMeter()
|
||||
xent_losses = AverageMeter()
|
||||
htri_losses = AverageMeter()
|
||||
accs = AverageMeter()
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
|
@ -162,42 +163,39 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
|||
imgs, pids = imgs.cuda(), pids.cuda()
|
||||
|
||||
outputs, features = model(imgs)
|
||||
if args.htri_only:
|
||||
if isinstance(features, (tuple, list)):
|
||||
loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
loss = criterion_htri(features, pids)
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||
else:
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||
else:
|
||||
xent_loss = criterion_xent(outputs, pids)
|
||||
xent_loss = criterion_xent(outputs, pids)
|
||||
|
||||
if isinstance(features, (tuple, list)):
|
||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
htri_loss = criterion_htri(features, pids)
|
||||
|
||||
if isinstance(features, (tuple, list)):
|
||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
htri_loss = criterion_htri(features, pids)
|
||||
|
||||
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
|
||||
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
batch_time.update(time.time() - end)
|
||||
|
||||
losses.update(loss.item(), pids.size(0))
|
||||
xent_losses.update(xent_loss.item(), pids.size(0))
|
||||
htri_losses.update(htri_loss.item(), pids.size(0))
|
||||
accs.update(accuracy(outputs, pids)[0])
|
||||
|
||||
if (batch_idx + 1) % args.print_freq == 0:
|
||||
print('Epoch: [{0}][{1}/{2}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||
'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
|
||||
'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
|
||||
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
|
||||
epoch + 1, batch_idx + 1, len(trainloader),
|
||||
batch_time=batch_time,
|
||||
data_time=data_time,
|
||||
loss=losses,
|
||||
xent=xent_losses,
|
||||
htri=htri_losses,
|
||||
acc=accs
|
||||
))
|
||||
|
||||
|
|
Loading…
Reference in New Issue