fix loss
parent 024ddbd5b9
commit 21fedfe5cd
@@ -0,0 +1,43 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss


class DistanceLoss(nn.Layer):
    """
    DistanceLoss:
        mode: loss mode
    """

    def __init__(self, mode="l2", **kargs):
        super().__init__()
        assert mode in ["l1", "l2", "smooth_l1"]
        if mode == "l1":
            self.loss_func = nn.L1Loss(**kargs)
        elif mode == "l2":
            self.loss_func = nn.MSELoss(**kargs)
        elif mode == "smooth_l1":
            self.loss_func = nn.SmoothL1Loss(**kargs)
        self.mode = mode

    def forward(self, x, y):
        loss = self.loss_func(x, y)
        return {"loss_{}".format(self.mode): loss}
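A quick way to sanity-check DistanceLoss is to call it directly on two random tensors. The sketch below is not part of the commit; it assumes the file above is saved as distanceloss.py (the second file imports it as .distanceloss) and needs nothing beyond paddle itself. forward returns a one-entry dict keyed "loss_<mode>".

# Usage sketch (not part of the commit); assumes the file above is
# importable as distanceloss.py.
import paddle
from distanceloss import DistanceLoss

x = paddle.rand([8, 128])  # e.g. student features
y = paddle.rand([8, 128])  # e.g. teacher features

for mode in ["l1", "l2", "smooth_l1"]:
    loss_fn = DistanceLoss(mode=mode)
    out = loss_fn(x, y)
    # each call yields {"loss_<mode>": scalar tensor}
    print(mode, list(out.keys()), float(out["loss_" + mode]))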
@@ -0,0 +1,141 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn

from .celoss import CELoss
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss


class DistillationCELoss(CELoss):
    """
    DistillationCELoss
    """

    def __init__(self,
                 model_name_pairs=[],
                 epsilon=None,
                 key=None,
                 name="loss_ce"):
        super().__init__(epsilon=epsilon)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            for key in loss:
                loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
        return loss_dict
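DistillationCELoss (and the classes below) all follow the same calling convention: predicts is a dict keyed by sub-model name, model_name_pairs lists which two outputs to compare, and the optional key pulls a tensor out of each sub-model's own output dict before the parent loss is applied; the returned dict keys encode the pair, so several pairs can coexist in one total loss. A rough sketch of that convention follows. It is not part of the commit, assumes the two files live in a PaddleClas-style package (e.g. ppcls/loss/) so that the relative imports .celoss, .dmlloss and .distanceloss resolve, and assumes CELoss accepts same-shaped teacher logits as a soft target.

# Sketch (not part of the commit); "Student"/"Teacher" and the import path
# are placeholders for whatever the surrounding package actually provides.
import paddle
from ppcls.loss.distillationloss import DistillationCELoss

predicts = {
    "Student": paddle.rand([8, 1000]),  # student logits
    "Teacher": paddle.rand([8, 1000]),  # teacher logits
}
labels = paddle.randint(0, 1000, [8, 1])  # ignored by this particular loss

loss_fn = DistillationCELoss(model_name_pairs=[["Student", "Teacher"]])
loss_dict = loss_fn(predicts, labels)
# one entry per pair, keyed "<parent-loss-key>_Student_Teacher"
print(loss_dict.keys())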

class DistillationGTCELoss(CELoss):
    """
    DistillationGTCELoss
    """

    def __init__(self,
                 model_names=[],
                 epsilon=None,
                 key=None,
                 name="loss_gt_ce"):
        super().__init__(epsilon=epsilon)
        assert isinstance(model_names, list)
        self.key = key
        self.model_names = model_names
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, name in enumerate(self.model_names):
            out = predicts[name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)
            for key in loss:
                loss_dict["{}_{}".format(key, name)] = loss[key]
        return loss_dict


class DistillationDMLLoss(DMLLoss):
    """
    DistillationDMLLoss
    """

    def __init__(self,
                 model_name_pairs=[],
                 act=None,
                 key=None,
                 name="loss_dml"):
        super().__init__(act=act)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss
        return loss_dict


class DistillationDistanceLoss(DistanceLoss):
    """
    DistillationDistanceLoss
    """

    def __init__(self,
                 mode="l2",
                 model_name_pairs=[],
                 key=None,
                 name="loss_",
                 **kargs):
        super().__init__(mode=mode, **kargs)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name + "_l2"

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            for key in loss:
                loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[key]
        return loss_dict
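DistillationDistanceLoss wraps DistanceLoss over pairs of intermediate outputs. Note that, as written, __init__ always appends "_l2" to name regardless of mode, so with the default name="loss_" the entries come out keyed "loss__l2_<parent-key>_<idx>" even for the "l1" and "smooth_l1" modes. A minimal end-to-end sketch, not part of the commit and under the same package-layout assumption as above:

# Sketch (not part of the commit); matches intermediate features of two
# sub-models with an L2 penalty, pulling them out of each sub-model's
# output dict via key="features" (a placeholder name).
import paddle
from ppcls.loss.distillationloss import DistillationDistanceLoss

predicts = {
    "Student": {"features": paddle.rand([8, 256])},
    "Teacher": {"features": paddle.rand([8, 256])},
}

loss_fn = DistillationDistanceLoss(
    mode="l2",
    model_name_pairs=[["Student", "Teacher"]],
    key="features")
loss_dict = loss_fn(predicts, batch=None)  # batch is ignored by this loss
# with the defaults above the single entry is keyed "loss__l2_loss_l2_0"
print(loss_dict)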