update deep-supervision

This commit is contained in:
KaiyangZhou 2018-04-26 18:00:03 +01:00
parent ac38a1c141
commit 368b700f12
2 changed files with 14 additions and 14 deletions

View File

@ -17,7 +17,7 @@ import data_manager
from dataset_loader import ImageDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth
from losses import CrossEntropyLabelSmooth, DeepSupervision
from utils import AverageMeter, Logger, save_checkpoint
from eval_metrics import evaluate
@ -129,7 +129,7 @@ def main():
)
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'})
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'}, use_gpu=use_gpu)
print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))
criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
@ -190,7 +190,10 @@ def train(model, criterion, optimizer, trainloader, use_gpu):
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
outputs = model(imgs)
loss = criterion(outputs, pids)
if isinstance(outputs, tuple):
loss = DeepSupervision(criterion, outputs, pids)
else:
loss = criterion(outputs, pids)
optimizer.zero_grad()
loss.backward()
optimizer.step()

View File

@ -17,7 +17,7 @@ import data_manager
from dataset_loader import ImageDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth, TripletLoss
from losses import CrossEntropyLabelSmooth, TripletLoss, DeepSupervision
from utils import AverageMeter, Logger, save_checkpoint
from eval_metrics import evaluate
from samplers import RandomIdentitySampler
@ -195,26 +195,23 @@ def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu
model.train()
losses = AverageMeter()
def _deep_supervision(criterion, xs, y):
loss = 0.
for x in xs:
loss += criterion(x, y)
return loss
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
outputs, features = model(imgs)
if args.htri_only:
if isinstance(features, tuple):
loss = _deep_supervision(criterion_htri, features, pids)
loss = DeepSupervision(criterion_htri, features, pids)
else:
loss = criterion_htri(features, pids)
else:
xent_loss = criterion_xent(outputs, pids)
if isinstance(outputs, tuple):
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
else:
xent_loss = criterion_xent(outputs, pids)
if isinstance(features, tuple):
htri_loss = _deep_supervision(criterion_htri, features, pids)
htri_loss = DeepSupervision(criterion_htri, features, pids)
else:
htri_loss = criterion_htri(features, pids)