PaddleClas/ppcls/loss/celoss.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F

__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']


class Loss(object):
    """
    Loss
    """
    def __init__(self, class_dim=1000, epsilon=None):
        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
        self._class_dim = class_dim
        if epsilon is not None and 0.0 <= epsilon <= 1.0:
            self._epsilon = epsilon
            self._label_smoothing = True  # use label smoothing (soft labels)
        else:
            self._epsilon = None
            self._label_smoothing = False

    def _labelsmoothing(self, target):
        # apply label smoothing to the target labels
        if target.shape[-1] != self._class_dim:
            # convert integer labels, e.g. (23, 34, 46), into a
            # [batch_size, class_dim] one-hot matrix
            one_hot_target = F.one_hot(target, self._class_dim)
        else:
            one_hot_target = target
        # smooth the labels: (1 - epsilon) * input + epsilon / K
        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
        return soft_target
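
    # Worked example of the smoothing above (values chosen for illustration):
    # with class_dim=5 and epsilon=0.1, the one-hot row [0, 1, 0, 0, 0] becomes
    # [0.02, 0.92, 0.02, 0.02, 0.02], since the hot entry maps to
    # 1 * (1 - 0.1) + 0.1 / 5 = 0.92 and each other entry to 0.1 / 5 = 0.02.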

    def _crossentropy(self, input, target, use_pure_fp16=False):
        if self._label_smoothing:
            target = self._labelsmoothing(target)
            input = -F.log_softmax(input, axis=-1)  # negative log-softmax
            cost = paddle.sum(target * input, axis=-1)  # per-sample soft-label CE
        else:
            cost = F.cross_entropy(input=input, label=target)
        if use_pure_fp16:
            avg_cost = paddle.sum(cost)
        else:
            avg_cost = paddle.mean(cost)
        return avg_cost
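
    # The label-smoothing branch above computes the soft-label cross entropy
    # H(q, p) = -sum_k q_k * log(softmax(input)_k), which reduces to the
    # standard cross entropy when q is one-hot.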

    def _kldiv(self, input, target, name=None):
        eps = 1.0e-10
        cost = target * paddle.log(
            (target + eps) / (input + eps)) * self._class_dim
        return cost
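
    # Note on the class_dim factor above: _jsdiv averages the returned
    # [batch, class_dim] terms with paddle.mean, i.e. over batch * class_dim
    # elements; multiplying by class_dim turns that mean into a per-sample
    # sum over classes, matching the usual definition of KL divergence.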

    def _jsdiv(self, input, target):
        # input and target are raw FC outputs (logits); softmax is applied here
        input = F.softmax(input)
        target = F.softmax(target)
        # symmetrized KL between the two distributions
        cost = self._kldiv(input, target) + self._kldiv(target, input)
        cost = cost / 2
        avg_cost = paddle.mean(cost)
        return avg_cost

    def __call__(self, input, target):
        pass


class CELoss(Loss):
    """
    Cross entropy loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(CELoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target, use_pure_fp16=False):
        if isinstance(input, dict):
            logits = input["logits"]
        else:
            logits = input
        cost = self._crossentropy(logits, target, use_pure_fp16)
        return {"CELoss": cost}
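
# Usage sketch (shapes assumed for illustration): CELoss accepts either raw
# logits or a dict holding them under the "logits" key, e.g.
#   loss_fn = CELoss(class_dim=1000, epsilon=0.1)
#   out = loss_fn(paddle.randn([8, 1000]), paddle.randint(0, 1000, [8]))
#   out["CELoss"]  # scalar loss tensor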


class JSDivLoss(Loss):
    """
    JSDiv loss
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(JSDivLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        cost = self._jsdiv(input, target)
        return cost


class KLDivLoss(paddle.nn.Layer):
    def __init__(self):
        super(KLDivLoss, self).__init__()

    def __call__(self, p, q, is_logit=True):
        if is_logit:
            p = F.softmax(p)
            q = F.softmax(q)
        # this computes the cross entropy H(p, q) = -sum_k p_k * log(q_k + 1e-8);
        # it differs from KL(p || q) only by the entropy of p, which is
        # constant with respect to q
        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
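

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch; shapes and hyperparameters
    # are assumptions, not part of the PaddleClas training pipeline)
    paddle.seed(0)
    logits = paddle.randn([4, 10])
    labels = paddle.randint(0, 10, [4])

    ce = CELoss(class_dim=10, epsilon=0.1)
    print(ce(logits, labels))  # {'CELoss': <scalar Tensor>}

    teacher_logits = paddle.randn([4, 10])
    print(JSDivLoss(class_dim=10)(logits, teacher_logits))
    print(KLDivLoss()(logits, teacher_logits))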