diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py
index a74bcb88..ea7f15f2 100644
--- a/timm/loss/__init__.py
+++ b/timm/loss/__init__.py
@@ -1,4 +1,4 @@
 from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
-from .binary_cross_entropy import DenseBinaryCrossEntropy
+from .binary_cross_entropy import BinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
diff --git a/timm/loss/binary_cross_entropy.py b/timm/loss/binary_cross_entropy.py
index 6da04dba..ed76c1e8 100644
--- a/timm/loss/binary_cross_entropy.py
+++ b/timm/loss/binary_cross_entropy.py
@@ -1,23 +1,47 @@
+""" Binary Cross Entropy w/ a few extras
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class DenseBinaryCrossEntropy(nn.Module):
-    """ BCE using one-hot from dense targets w/ label smoothing
+class BinaryCrossEntropy(nn.Module):
+    """ BCE with optional one-hot from dense targets, label smoothing, thresholding
     NOTE for experiments comparing CE to BCE /w label smoothing, may remove
     """
-    def __init__(self, smoothing=0.1):
-        super(DenseBinaryCrossEntropy, self).__init__()
+    def __init__(
+            self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
+            reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
+        super(BinaryCrossEntropy, self).__init__()
         assert 0. <= smoothing < 1.0
         self.smoothing = smoothing
-        self.bce = nn.BCEWithLogitsLoss()
+        self.target_threshold = target_threshold
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)
 
-    def forward(self, x, target):
-        num_classes = x.shape[-1]
-        off_value = self.smoothing / num_classes
-        on_value = 1. - self.smoothing + off_value
-        target = target.long().view(-1, 1)
-        target = torch.full(
-            (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
-        return self.bce(x, target)
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        assert x.shape[0] == target.shape[0]
+        if target.shape != x.shape:
+            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
+            num_classes = x.shape[-1]
+            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
+            off_value = self.smoothing / num_classes
+            on_value = 1. - self.smoothing + off_value
+            target = target.long().view(-1, 1)
+            target = torch.full(
+                (target.size()[0], num_classes),
+                off_value,
+                device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
+        if self.target_threshold is not None:
+            # Make target 0, or 1 if threshold set
+            target = target.gt(self.target_threshold).to(dtype=target.dtype)
+        return F.binary_cross_entropy_with_logits(
+            x, target,
+            self.weight,
+            pos_weight=self.pos_weight,
+            reduction=self.reduction)
diff --git a/timm/loss/cross_entropy.py b/timm/loss/cross_entropy.py
index 60bef646..85198107 100644
--- a/timm/loss/cross_entropy.py
+++ b/timm/loss/cross_entropy.py
@@ -1,23 +1,23 @@
+""" Cross Entropy w/ smoothing or soft targets
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 class LabelSmoothingCrossEntropy(nn.Module):
-    """
-    NLL loss with label smoothing.
+    """ NLL loss with label smoothing.
     """
     def __init__(self, smoothing=0.1):
-        """
-        Constructor for the LabelSmoothing module.
-        :param smoothing: label smoothing factor
-        """
         super(LabelSmoothingCrossEntropy, self).__init__()
         assert smoothing < 1.0
         self.smoothing = smoothing
         self.confidence = 1. - smoothing
 
-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         logprobs = F.log_softmax(x, dim=-1)
         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
         nll_loss = nll_loss.squeeze(1)
@@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module):
     def __init__(self):
         super(SoftTargetCrossEntropy, self).__init__()
 
-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
         return loss.mean()
diff --git a/train.py b/train.py
index 3943c7d0..55aba416 100755
--- a/train.py
+++ b/train.py
@@ -190,6 +190,8 @@ parser.add_argument('--jsd-loss', action='store_true', default=False,
                     help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 parser.add_argument('--bce-loss', action='store_true', default=False,
                     help='Enable BCE loss w/ Mixup/CutMix use.')
+parser.add_argument('--bce-target-thresh', type=float, default=None,
+                    help='Threshold for binarizing softened BCE targets (default: None, disabled)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='pixel',
@@ -459,7 +461,7 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
+            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
@@ -558,12 +560,12 @@ def main():
     elif mixup_active:
         # smoothing is handled with mixup target transform which outputs sparse, soft targets
         if args.bce_loss:
-            train_loss_fn = nn.BCEWithLogitsLoss()
+            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
         else:
            train_loss_fn = SoftTargetCrossEntropy()
     elif args.smoothing:
         if args.bce_loss:
-            train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
+            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
         else:
             train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
     else: