More appropriate/correct loss name
parent
99ab1b1276
commit
e6c14427c0
|
@ -1 +1 @@
|
|||
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
||||
from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
|
@ -26,10 +26,10 @@ class LabelSmoothingCrossEntropy(nn.Module):
|
|||
return loss.mean()
|
||||
|
||||
|
||||
class SparseLabelCrossEntropy(nn.Module):
|
||||
class SoftTargetCrossEntropy(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SparseLabelCrossEntropy, self).__init__()
|
||||
super(SoftTargetCrossEntropy, self).__init__()
|
||||
|
||||
def forward(self, x, target):
|
||||
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
|
||||
|
|
4
train.py
4
train.py
|
@ -13,7 +13,7 @@ except ImportError:
|
|||
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
||||
from models import create_model, resume_checkpoint
|
||||
from utils import *
|
||||
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
||||
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||
from optim import create_optimizer
|
||||
from scheduler import create_scheduler
|
||||
|
||||
|
@ -261,7 +261,7 @@ def main():
|
|||
|
||||
if args.mixup > 0.:
|
||||
# smoothing is handled with mixup label transform
|
||||
train_loss_fn = SparseLabelCrossEntropy().cuda()
|
||||
train_loss_fn = SoftTargetCrossEntropy().cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
elif args.smoothing:
|
||||
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
|
||||
|
|
Loading…
Reference in New Issue