Add label smoothing cross-entropy loss

This commit is contained in:
Ross Wightman 2019-04-05 20:50:26 -07:00
parent b0158a593e
commit f2029dfb65
2 changed files with 27 additions and 0 deletions

1
loss/__init__.py Normal file
View File

@ -0,0 +1 @@
from loss.cross_entropy import LabelSmoothingCrossEntropy

26
loss/cross_entropy.py Normal file
View File

@ -0,0 +1,26 @@
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    Blends the one-hot NLL term with a uniform-over-classes term:
    ``loss = (1 - smoothing) * nll + smoothing * mean(-log_probs)``,
    averaged over the batch.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor, must be in [0, 1).
        :raises ValueError: if smoothing is outside [0, 1).
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        # Validate with an exception rather than `assert`: asserts are
        # stripped under -O, and negative smoothing would invert the mix.
        if not 0.0 <= smoothing < 1.0:
            raise ValueError(
                'smoothing must be in [0, 1), got {}'.format(smoothing))
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        """
        :param x: raw (unnormalized) logits, shape (N, C).
        :param target: integer class indices, shape (N,).
        :return: scalar mean loss over the batch.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # One-hot term: negative log-prob of the target class per sample.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        # Uniform term: average negative log-prob over all classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()