update printing in triplet engine (taking into account weight_t & weight_x)

pull/405/head
KaiyangZhou 2020-08-12 11:03:24 +01:00
parent de42729f41
commit 7c2ee1e6c6
1 changed files with 18 additions and 9 deletions

View File

@ -78,6 +78,8 @@ class ImageTripletEngine(Engine):
self.scheduler = scheduler
self.register_model('model', model, optimizer, scheduler)
assert weight_t >= 0 and weight_x >= 0
assert weight_t + weight_x > 0
self.weight_t = weight_t
self.weight_x = weight_x
@ -96,18 +98,25 @@ class ImageTripletEngine(Engine):
pids = pids.cuda()
outputs, features = self.model(imgs)
loss_t = self.compute_loss(self.criterion_t, features, pids)
loss_x = self.compute_loss(self.criterion_x, outputs, pids)
loss = self.weight_t * loss_t + self.weight_x * loss_x
loss = 0
loss_summary = {}
if self.weight_t > 0:
loss_t = self.compute_loss(self.criterion_t, features, pids)
loss += self.weight_t * loss_t
loss_summary['loss_t'] = loss_t.item()
if self.weight_x > 0:
loss_x = self.compute_loss(self.criterion_x, outputs, pids)
loss += self.weight_x * loss_x
loss_summary['loss_x'] = loss_x.item()
loss_summary['acc'] = metrics.accuracy(outputs, pids)[0].item()
assert loss_summary
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
loss_summary = {
'loss_t': loss_t.item(),
'loss_x': loss_x.item(),
'acc': metrics.accuracy(outputs, pids)[0].item()
}
return loss_summary