PaddleClas/ppcls/loss/__init__.py

106 lines
4.1 KiB
Python

import copy
import paddle
import paddle.nn as nn
from ppcls.utils import logger
from .celoss import CELoss, MixCELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .contrasiveloss import ContrastiveLoss
from .contrasiveloss import ContrastiveLoss_XBM
from .emlloss import EmlLoss
from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .tripletangularmarginloss import TripletAngularMarginLoss, TripletAngularMarginLoss_XBM
from .supconloss import SupConLoss
from .softsuploss import SoftSupConLoss
from .ccssl_loss import CCSSLCELoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
from .softtargetceloss import SoftTargetCrossEntropy
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationWSLLoss
from .distillationloss import DistillationSKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .distillationloss import DistillationPairLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss
from .deephashloss import DCHLoss
from .metabinloss import CELossForMetaBIN
from .metabinloss import TripletLossForMetaBIN
from .metabinloss import InterDomainShuffleLoss
from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer):
    """Weighted combination of the loss layers described by a config list.

    Each entry of ``config_list`` is a single-key dict ``{LossName: params}``.
    ``params`` must contain a ``"weight"``; the remaining params are passed as
    keyword arguments to ``LossName``, which is resolved by name in this
    module's namespace.
    """

    def __init__(self, config_list):
        super().__init__()
        loss_func = []
        self.loss_weight = []
        assert isinstance(config_list, list), (
            'operator config should be a list')
        for config in config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name, param = next(iter(config.items()))
            # Work on a shallow copy so popping "weight" below does not
            # mutate the caller's config; an empty YAML entry (None) is
            # treated as "no params" so the weight assertion reports it.
            param = dict(param or {})
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            # NOTE(security): `eval` turns the loss name taken from the config
            # file into a class. Only load configs from trusted sources.
            loss_func.append(eval(name)(**param))
        self.loss_func = nn.LayerList(loss_func)
        logger.debug("build loss {} success.".format(loss_func))

    def __call__(self, input, batch):
        """Evaluate every configured loss and combine the weighted terms.

        Args:
            input: model output, forwarded unchanged to each loss layer.
            batch: batch data, forwarded unchanged to each loss layer.

        Returns:
            dict: every term returned by the loss layers, each scaled by its
            loss's configured weight, plus ``"loss"`` holding the sum of all
            those terms.
        """
        loss_dict = {}
        for loss_func, weight in zip(self.loss_func, self.loss_weight):
            for key, value in loss_func(input, batch).items():
                # Skip the multiply for unit weights so the common
                # single-loss classification setup stays as cheap as before.
                if weight != 1:
                    value = value * weight
                if key in loss_dict:
                    # Two losses reported the same key: accumulate instead of
                    # overwriting so no term silently drops out of the total.
                    loss_dict[key] = loss_dict[key] + value
                else:
                    loss_dict[key] = value
        terms = list(loss_dict.values())
        # add_n is only needed when there is more than one term to sum.
        loss_dict["loss"] = terms[0] if len(terms) == 1 else paddle.add_n(
            terms)
        return loss_dict
def build_loss(config, mode="train"):
    """Build the loss layers requested by ``config`` for the given ``mode``.

    Args:
        config (dict): parsed configuration. In train mode it reads
            ``config["Loss"]["Train"]`` and the optional
            ``config["UnLabelLoss"]["Train"]``. ``config["Loss"]["Eval"]`` is
            read in eval mode, and in train mode when the optional
            ``config["Global"]["eval_during_train"]`` is truthy.
        mode (str): "train" or "eval"; any other value builds nothing.

    Returns:
        tuple: ``(train_loss_func, unlabel_train_loss_func, eval_loss_func)``.
        Each item is a ``CombinedLoss``, or ``None`` when its config section
        is missing or empty.

    Raises:
        KeyError: in train mode, when ``config`` has no ``"Loss"`` section or
            that section has no ``"Train"`` key.
    """
    train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
    if mode == "train":
        label_loss_info = config["Loss"]["Train"]
        if label_loss_info:
            # Pass a deep copy so CombinedLoss never alters the caller's config.
            train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
        # `or {}` also covers a section present in the YAML but left empty,
        # which parses to None rather than a dict.
        unlabel_loss_info = (config.get("UnLabelLoss") or {}).get("Train")
        if unlabel_loss_info:
            unlabel_train_loss_func = CombinedLoss(
                copy.deepcopy(unlabel_loss_info))
    # A missing Global section or eval_during_train flag means "no eval".
    eval_during_train = mode == "train" and bool(
        (config.get("Global") or {}).get("eval_during_train", False))
    if mode == "eval" or eval_during_train:
        eval_loss_info = (config.get("Loss") or {}).get("Eval")
        if eval_loss_info is not None:
            eval_loss_func = CombinedLoss(copy.deepcopy(eval_loss_info))
    return train_loss_func, unlabel_train_loss_func, eval_loss_func