update printing in triplet engine (taking into account weight_t & weight_x)
parent
de42729f41
commit
7c2ee1e6c6
|
@ -78,6 +78,8 @@ class ImageTripletEngine(Engine):
|
|||
self.scheduler = scheduler
|
||||
self.register_model('model', model, optimizer, scheduler)
|
||||
|
||||
assert weight_t >= 0 and weight_x >= 0
|
||||
assert weight_t + weight_x > 0
|
||||
self.weight_t = weight_t
|
||||
self.weight_x = weight_x
|
||||
|
||||
|
@ -96,18 +98,25 @@ class ImageTripletEngine(Engine):
|
|||
pids = pids.cuda()
|
||||
|
||||
outputs, features = self.model(imgs)
|
||||
loss_t = self.compute_loss(self.criterion_t, features, pids)
|
||||
loss_x = self.compute_loss(self.criterion_x, outputs, pids)
|
||||
loss = self.weight_t * loss_t + self.weight_x * loss_x
|
||||
|
||||
loss = 0
|
||||
loss_summary = {}
|
||||
|
||||
if self.weight_t > 0:
|
||||
loss_t = self.compute_loss(self.criterion_t, features, pids)
|
||||
loss += self.weight_t * loss_t
|
||||
loss_summary['loss_t'] = loss_t.item()
|
||||
|
||||
if self.weight_x > 0:
|
||||
loss_x = self.compute_loss(self.criterion_x, outputs, pids)
|
||||
loss += self.weight_x * loss_x
|
||||
loss_summary['loss_x'] = loss_x.item()
|
||||
loss_summary['acc'] = metrics.accuracy(outputs, pids)[0].item()
|
||||
|
||||
assert loss_summary
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
loss_summary = {
|
||||
'loss_t': loss_t.item(),
|
||||
'loss_x': loss_x.item(),
|
||||
'acc': metrics.accuracy(outputs, pids)[0].item()
|
||||
}
|
||||
|
||||
return loss_summary
|
||||
|
|
Loading…
Reference in New Issue