From bd18a9489eea97db763007ab7b520a88f548a4c9 Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Tue, 14 Aug 2018 17:32:06 +0100 Subject: [PATCH] add lambda-xent and lambda-htri to balance training --- train_imgreid_xent_htri.py | 6 +++++- train_vidreid_xent_htri.py | 27 ++++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/train_imgreid_xent_htri.py b/train_imgreid_xent_htri.py index fa06fe8..12e10e8 100755 --- a/train_imgreid_xent_htri.py +++ b/train_imgreid_xent_htri.py @@ -76,6 +76,10 @@ parser.add_argument('--num-instances', type=int, default=4, help="number of instances per identity") parser.add_argument('--htri-only', action='store_true', default=False, help="if this is True, only htri loss is used in training") +parser.add_argument('--lambda-xent', type=float, default=1, + help="weight to balance cross entropy loss") +parser.add_argument('--lambda-htri', type=float, default=1, + help="weight to balance hard triplet loss") # Architecture parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) # Miscs @@ -278,7 +282,7 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, else: htri_loss = criterion_htri(features, pids) - loss = xent_loss + htri_loss + loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss optimizer.zero_grad() loss.backward() optimizer.step() diff --git a/train_vidreid_xent_htri.py b/train_vidreid_xent_htri.py index a1120a7..61e0f27 100755 --- a/train_vidreid_xent_htri.py +++ b/train_vidreid_xent_htri.py @@ -19,7 +19,7 @@ import data_manager from dataset_loader import ImageDataset, VideoDataset import transforms as T import models -from losses import CrossEntropyLabelSmooth, TripletLoss +from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision from utils.iotools import save_checkpoint, check_isfile from utils.avgmeter import AverageMeter from utils.logger import Logger @@ -69,6 +69,10 @@ parser.add_argument('--num-instances', type=int, default=4, help="number of instances per identity") parser.add_argument('--htri-only', action='store_true', default=False, help="if this is True, only htri loss is used in training") +parser.add_argument('--lambda-xent', type=float, default=1, + help="weight to balance cross entropy loss") +parser.add_argument('--lambda-htri', type=float, default=1, + help="weight to balance hard triplet loss") # Architecture parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max']) @@ -260,13 +264,22 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, outputs, features = model(imgs) if args.htri_only: - # only use hard triplet loss to train the network - loss = criterion_htri(features, pids) + if isinstance(features, tuple): + loss = DeepSupervision(criterion_htri, features, pids) + else: + loss = criterion_htri(features, pids) else: - # combine hard triplet loss with cross entropy loss - xent_loss = criterion_xent(outputs, pids) - htri_loss = criterion_htri(features, pids) - loss = xent_loss + htri_loss + if isinstance(outputs, tuple): + xent_loss = DeepSupervision(criterion_xent, outputs, pids) + else: + xent_loss = criterion_xent(outputs, pids) + + if isinstance(features, tuple): + htri_loss = DeepSupervision(criterion_htri, features, pids) + else: + htri_loss = criterion_htri(features, pids) + + loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss optimizer.zero_grad() loss.backward() optimizer.step()