add lambda-xent and lambda-htri to balance training
parent
28ecca9397
commit
bd18a9489e
|
@ -76,6 +76,10 @@ parser.add_argument('--num-instances', type=int, default=4,
|
||||||
help="number of instances per identity")
|
help="number of instances per identity")
|
||||||
parser.add_argument('--htri-only', action='store_true', default=False,
|
parser.add_argument('--htri-only', action='store_true', default=False,
|
||||||
help="if this is True, only htri loss is used in training")
|
help="if this is True, only htri loss is used in training")
|
||||||
|
parser.add_argument('--lambda-xent', type=float, default=1,
|
||||||
|
help="weight to balance cross entropy loss")
|
||||||
|
parser.add_argument('--lambda-htri', type=float, default=1,
|
||||||
|
help="weight to balance hard triplet loss")
|
||||||
# Architecture
|
# Architecture
|
||||||
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
|
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
|
||||||
# Miscs
|
# Miscs
|
||||||
|
@ -278,7 +282,7 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
||||||
else:
|
else:
|
||||||
htri_loss = criterion_htri(features, pids)
|
htri_loss = criterion_htri(features, pids)
|
||||||
|
|
||||||
loss = xent_loss + htri_loss
|
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
|
@ -19,7 +19,7 @@ import data_manager
|
||||||
from dataset_loader import ImageDataset, VideoDataset
|
from dataset_loader import ImageDataset, VideoDataset
|
||||||
import transforms as T
|
import transforms as T
|
||||||
import models
|
import models
|
||||||
from losses import CrossEntropyLabelSmooth, TripletLoss
|
from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision
|
||||||
from utils.iotools import save_checkpoint, check_isfile
|
from utils.iotools import save_checkpoint, check_isfile
|
||||||
from utils.avgmeter import AverageMeter
|
from utils.avgmeter import AverageMeter
|
||||||
from utils.logger import Logger
|
from utils.logger import Logger
|
||||||
|
@ -69,6 +69,10 @@ parser.add_argument('--num-instances', type=int, default=4,
|
||||||
help="number of instances per identity")
|
help="number of instances per identity")
|
||||||
parser.add_argument('--htri-only', action='store_true', default=False,
|
parser.add_argument('--htri-only', action='store_true', default=False,
|
||||||
help="if this is True, only htri loss is used in training")
|
help="if this is True, only htri loss is used in training")
|
||||||
|
parser.add_argument('--lambda-xent', type=float, default=1,
|
||||||
|
help="weight to balance cross entropy loss")
|
||||||
|
parser.add_argument('--lambda-htri', type=float, default=1,
|
||||||
|
help="weight to balance hard triplet loss")
|
||||||
# Architecture
|
# Architecture
|
||||||
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
|
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
|
||||||
parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max'])
|
parser.add_argument('--pool', type=str, default='avg', choices=['avg', 'max'])
|
||||||
|
@ -260,13 +264,22 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
||||||
|
|
||||||
outputs, features = model(imgs)
|
outputs, features = model(imgs)
|
||||||
if args.htri_only:
|
if args.htri_only:
|
||||||
# only use hard triplet loss to train the network
|
if isinstance(features, tuple):
|
||||||
|
loss = DeepSupervision(criterion_htri, features, pids)
|
||||||
|
else:
|
||||||
loss = criterion_htri(features, pids)
|
loss = criterion_htri(features, pids)
|
||||||
else:
|
else:
|
||||||
# combine hard triplet loss with cross entropy loss
|
if isinstance(outputs, tuple):
|
||||||
|
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||||
|
else:
|
||||||
xent_loss = criterion_xent(outputs, pids)
|
xent_loss = criterion_xent(outputs, pids)
|
||||||
|
|
||||||
|
if isinstance(features, tuple):
|
||||||
|
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||||
|
else:
|
||||||
htri_loss = criterion_htri(features, pids)
|
htri_loss = criterion_htri(features, pids)
|
||||||
loss = xent_loss + htri_loss
|
|
||||||
|
loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
Loading…
Reference in New Issue