update combined loss for accelerate classification training speed

pull/1460/head
dongshuilong 2021-11-19 08:35:47 +00:00
parent 8174645941
commit ee3c643cc1
1 changed files with 11 additions and 5 deletions

View File

@ -44,12 +44,18 @@ class CombinedLoss(nn.Layer):
def __call__(self, input, batch):
loss_dict = {}
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch)
weight = self.loss_weight[idx]
loss = {key: loss[key] * weight for key in loss}
# just for accelerate classification traing speed
if len(self.loss_func) == 1:
loss = self.loss_func[0](input, batch)
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
loss_dict["loss"] = list(loss.values())[0]
else:
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch)
weight = self.loss_weight[idx]
loss = {key: loss[key] * weight for key in loss}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
return loss_dict