From 952e99e4fb2489349fcf5a1fa4053abe7d63cb15 Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Tue, 27 Mar 2018 10:51:10 +0100 Subject: [PATCH] add --htri-only --- train_img_model_xent_htri.py | 13 ++++++++++--- train_vid_model_xent_htri.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/train_img_model_xent_htri.py b/train_img_model_xent_htri.py index e1599d4..e6a7411 100755 --- a/train_img_model_xent_htri.py +++ b/train_img_model_xent_htri.py @@ -52,6 +52,8 @@ parser.add_argument('--weight-decay', default=5e-04, type=float, parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") parser.add_argument('--num-instances', type=int, default=4, help="number of instances per identity") +parser.add_argument('--htri-only', action='store_true', default=False, + help="if this is True, only htri loss is used in training") # Architecture parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) # Miscs @@ -184,9 +186,14 @@ def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu imgs, pids = imgs.cuda(), pids.cuda() imgs, pids = Variable(imgs), Variable(pids) outputs, features = model(imgs) - xent_loss = criterion_xent(outputs, pids) - htri_loss = criterion_htri(features, pids) - loss = xent_loss + htri_loss + if args.htri_only: + # only use hard triplet loss to train the network + loss = criterion_htri(features, pids) + else: + # combine hard triplet loss with cross entropy loss + xent_loss = criterion_xent(outputs, pids) + htri_loss = criterion_htri(features, pids) + loss = xent_loss + htri_loss optimizer.zero_grad() loss.backward() optimizer.step() diff --git a/train_vid_model_xent_htri.py b/train_vid_model_xent_htri.py index 9ce9878..47dfe6b 100755 --- a/train_vid_model_xent_htri.py +++ b/train_vid_model_xent_htri.py @@ -53,6 +53,8 @@ parser.add_argument('--weight-decay', default=5e-04, type=float, parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") parser.add_argument('--num-instances', type=int, default=4, help="number of instances per identity") +parser.add_argument('--htri-only', action='store_true', default=False, + help="if this is True, only htri loss is used in training") # Architecture parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max']) @@ -192,9 +194,14 @@ def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu imgs, pids = imgs.cuda(), pids.cuda() imgs, pids = Variable(imgs), Variable(pids) outputs, features = model(imgs) - xent_loss = criterion_xent(outputs, pids) - htri_loss = criterion_htri(features, pids) - loss = xent_loss + htri_loss + if args.htri_only: + # only use hard triplet loss to train the network + loss = criterion_htri(features, pids) + else: + # combine hard triplet loss with cross entropy loss + xent_loss = criterion_xent(outputs, pids) + htri_loss = criterion_htri(features, pids) + loss = xent_loss + htri_loss optimizer.zero_grad() loss.backward() optimizer.step()