update combined loss for accelerate classification training speed
parent
8174645941
commit
ee3c643cc1
|
@ -44,12 +44,18 @@ class CombinedLoss(nn.Layer):
|
|||
|
||||
def __call__(self, input, batch):
|
||||
loss_dict = {}
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch)
|
||||
weight = self.loss_weight[idx]
|
||||
loss = {key: loss[key] * weight for key in loss}
|
||||
# just for accelerate classification traing speed
|
||||
if len(self.loss_func) == 1:
|
||||
loss = self.loss_func[0](input, batch)
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
||||
loss_dict["loss"] = list(loss.values())[0]
|
||||
else:
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch)
|
||||
weight = self.loss_weight[idx]
|
||||
loss = {key: loss[key] * weight for key in loss}
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
||||
return loss_dict
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue