mmclassification/mmcls/models/losses/cross_entropy_loss.py

49 lines
1.4 KiB
Python
Raw Normal View History

import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
# element-wise losses
loss = F.cross_entropy(pred, label, reduction='none')
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0):
super(CrossEntropyLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.cls_criterion = cross_entropy
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_cls