rm --htri-only and print Xent and Htri losses separately
parent 4ccdc324e4
commit 32e4d36293
args.py (2 changed lines)

@@ -118,8 +118,6 @@ def argument_parser():
                         help='margin for triplet loss')
     parser.add_argument('--num-instances', type=int, default=4,
                         help='number of instances per identity')
-    parser.add_argument('--htri-only', action='store_true',
-                        help='only use hard triplet loss')
     parser.add_argument('--lambda-xent', type=float, default=1,
                         help='weight to balance cross entropy loss')
     parser.add_argument('--lambda-htri', type=float, default=1,
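With --htri-only removed, the cross-entropy and triplet terms are always combined; the two remaining flags set their weights. A minimal sketch of how the weighting works and how --lambda-xent 0 approximates the removed flag (flag names come from the diff; the --lambda-htri help string and the scalar loss values are placeholders):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lambda-xent', type=float, default=1,
                    help='weight to balance cross entropy loss')
parser.add_argument('--lambda-htri', type=float, default=1,
                    help='weight to balance hard triplet loss')

# Emulate the removed --htri-only by weighting the xent term to zero
# (xent is still computed, just multiplied by 0).
args = parser.parse_args(['--lambda-xent', '0'])

xent_loss, htri_loss = 2.31, 0.87  # placeholder scalar losses
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
print(loss)  # 0.87 -> effectively triplet-only training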
@@ -140,7 +140,8 @@ def main():
 
 
 def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
-    losses = AverageMeter()
+    xent_losses = AverageMeter()
+    htri_losses = AverageMeter()
     accs = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
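The diff swaps the single losses meter for two separate ones. AverageMeter itself is defined elsewhere in the repo; a minimal sketch consistent with the .val/.avg fields and update(value, n) calls used below (assumed implementation, not taken from this commit):

class AverageMeter(object):
    """Tracks the most recent value and a running average."""
    def __init__(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # n is the number of samples the value was computed over,
        # e.g. pids.size(0) in the training loop below.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count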
@@ -160,42 +161,39 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
             imgs, pids = imgs.cuda(), pids.cuda()
 
         outputs, features = model(imgs)
-        if args.htri_only:
-            if isinstance(features, (tuple, list)):
-                loss = DeepSupervision(criterion_htri, features, pids)
-            else:
-                loss = criterion_htri(features, pids)
+        if isinstance(outputs, (tuple, list)):
+            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
         else:
-            if isinstance(outputs, (tuple, list)):
-                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
-            else:
-                xent_loss = criterion_xent(outputs, pids)
+            xent_loss = criterion_xent(outputs, pids)
 
         if isinstance(features, (tuple, list)):
             htri_loss = DeepSupervision(criterion_htri, features, pids)
         else:
             htri_loss = criterion_htri(features, pids)
 
         loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
         batch_time.update(time.time() - end)
 
-        losses.update(loss.item(), pids.size(0))
+        xent_losses.update(xent_loss.item(), pids.size(0))
+        htri_losses.update(htri_loss.item(), pids.size(0))
         accs.update(accuracy(outputs, pids)[0])
 
         if (batch_idx + 1) % args.print_freq == 0:
             print('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
+                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                   'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                   epoch + 1, batch_idx + 1, len(trainloader),
                   batch_time=batch_time,
                   data_time=data_time,
-                  loss=losses,
+                  xent=xent_losses,
+                  htri=htri_losses,
                   acc=accs
             ))
 
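DeepSupervision is referenced but not defined in this commit; it handles models that return several outputs (e.g. auxiliary classifiers or multi-branch features). A plausible sketch, assuming it simply accumulates the criterion over each output — the repo's real helper may differ, for instance by averaging:

def DeepSupervision(criterion, xs, y):
    """Apply criterion to every element of xs against target y.

    Sketch of the helper used in the diff: xs is a tuple/list of
    model outputs, y the shared ground truth; per-output losses
    are summed (assumed behavior, not taken from this commit).
    """
    loss = 0.
    for x in xs:
        loss += criterion(x, y)
    return loss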
The same change is applied in the commit's second training script:

@@ -142,7 +142,8 @@ def main():
 
 
 def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
-    losses = AverageMeter()
+    xent_losses = AverageMeter()
+    htri_losses = AverageMeter()
     accs = AverageMeter()
     batch_time = AverageMeter()
     data_time = AverageMeter()
@@ -162,42 +163,39 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
             imgs, pids = imgs.cuda(), pids.cuda()
 
         outputs, features = model(imgs)
-        if args.htri_only:
-            if isinstance(features, (tuple, list)):
-                loss = DeepSupervision(criterion_htri, features, pids)
-            else:
-                loss = criterion_htri(features, pids)
+        if isinstance(outputs, (tuple, list)):
+            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
         else:
-            if isinstance(outputs, (tuple, list)):
-                xent_loss = DeepSupervision(criterion_xent, outputs, pids)
-            else:
-                xent_loss = criterion_xent(outputs, pids)
+            xent_loss = criterion_xent(outputs, pids)
 
         if isinstance(features, (tuple, list)):
             htri_loss = DeepSupervision(criterion_htri, features, pids)
         else:
             htri_loss = criterion_htri(features, pids)
 
         loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
         batch_time.update(time.time() - end)
 
-        losses.update(loss.item(), pids.size(0))
+        xent_losses.update(xent_loss.item(), pids.size(0))
+        htri_losses.update(htri_loss.item(), pids.size(0))
         accs.update(accuracy(outputs, pids)[0])
 
         if (batch_idx + 1) % args.print_freq == 0:
             print('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
+                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                   'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                   epoch + 1, batch_idx + 1, len(trainloader),
                   batch_time=batch_time,
                   data_time=data_time,
-                  loss=losses,
+                  xent=xent_losses,
+                  htri=htri_losses,
                   acc=accs
             ))
 
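Assembled from the '+' lines above, the post-commit training step has the following shape. Everything below except the control flow is a stand-in (toy model, placeholder triplet criterion, hard-coded lambdas), so this is a condensed sketch, not the repo's code:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for the re-id model: returns (logits, features) like the diff expects."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.classifier = nn.Linear(16, 4)

    def forward(self, x):
        f = self.backbone(x)           # feature embeddings
        return self.classifier(f), f   # (outputs, features)

model = ToyModel()
criterion_xent = nn.CrossEntropyLoss()
criterion_htri = lambda feats, pids: feats.norm(dim=1).mean()  # placeholder, not a real triplet loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lambda_xent, lambda_htri = 1.0, 1.0  # mirrors --lambda-xent / --lambda-htri

imgs, pids = torch.randn(4, 8), torch.randint(0, 4, (4,))
outputs, features = model(imgs)

# Per the diff, DeepSupervision would wrap these calls when outputs or
# features are tuples/lists; plain tensors take the direct path.
xent_loss = criterion_xent(outputs, pids)
htri_loss = criterion_htri(features, pids)

loss = lambda_xent * xent_loss + lambda_htri * htri_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Xent {:.4f}  Htri {:.4f}'.format(xent_loss.item(), htri_loss.item()))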