Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Update binary cross ent impl to use thresholding as an option (convert soft targets from mixup/cutmix to 0, 1)
This commit is contained in:
parent 5d6983c462
commit 0387e6057e
timm/loss/__init__.py
@@ -1,4 +1,4 @@
 from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
-from .binary_cross_entropy import DenseBinaryCrossEntropy
+from .binary_cross_entropy import BinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
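Since the class is renamed rather than aliased, code importing the old name breaks; a one-line sketch of the required change in hypothetical downstream code:

# old import, removed by this commit:
# from timm.loss import DenseBinaryCrossEntropy
from timm.loss import BinaryCrossEntropy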
timm/loss/binary_cross_entropy.py
@@ -1,23 +1,47 @@
 """ Binary Cross Entropy w/ a few extras
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class DenseBinaryCrossEntropy(nn.Module):
-    """ BCE using one-hot from dense targets w/ label smoothing
+class BinaryCrossEntropy(nn.Module):
+    """ BCE with optional one-hot from dense targets, label smoothing, thresholding
     NOTE for experiments comparing CE to BCE /w label smoothing, may remove
     """
-    def __init__(self, smoothing=0.1):
-        super(DenseBinaryCrossEntropy, self).__init__()
+    def __init__(
+            self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
+            reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
+        super(BinaryCrossEntropy, self).__init__()
         assert 0. <= smoothing < 1.0
         self.smoothing = smoothing
-        self.bce = nn.BCEWithLogitsLoss()
+        self.target_threshold = target_threshold
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)
 
-    def forward(self, x, target):
-        num_classes = x.shape[-1]
-        off_value = self.smoothing / num_classes
-        on_value = 1. - self.smoothing + off_value
-        target = target.long().view(-1, 1)
-        target = torch.full(
-            (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
-        return self.bce(x, target)
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        assert x.shape[0] == target.shape[0]
+        if target.shape != x.shape:
+            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
+            num_classes = x.shape[-1]
+            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
+            off_value = self.smoothing / num_classes
+            on_value = 1. - self.smoothing + off_value
+            target = target.long().view(-1, 1)
+            target = torch.full(
+                (target.size()[0], num_classes),
+                off_value,
+                device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
+        if self.target_threshold is not None:
+            # Make target 0, or 1 if threshold set
+            target = target.gt(self.target_threshold).to(dtype=target.dtype)
+        return F.binary_cross_entropy_with_logits(
+            x, target,
+            self.weight,
+            pos_weight=self.pos_weight,
+            reduction=self.reduction)
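A minimal usage sketch of the new class, assuming a timm checkout with this commit applied; the tensor values are illustrative only. It exercises both branches of forward(): dense integer targets expanded to a smoothed one-hot, and mixup-style soft targets re-binarized by target_threshold:

import torch
from timm.loss import BinaryCrossEntropy  # exported via timm/loss/__init__.py above

logits = torch.randn(4, 10)  # batch of 4 samples, 10 classes

# Branch 1: dense integer targets -> smoothed one-hot built inside forward()
dense_targets = torch.tensor([1, 3, 0, 7])
loss_fn = BinaryCrossEntropy(smoothing=0.1)
loss = loss_fn(logits, dense_targets)  # off_value=0.01, on_value=0.91 per the formulas above

# Branch 2: soft targets already matching the logit shape (as mixup/cutmix produce);
# the one-hot/smoothing branch is skipped, and target_threshold binarizes to {0, 1}
soft_targets = torch.zeros(4, 10)
soft_targets[torch.arange(4), dense_targets] = 0.6            # mixup-style blend, illustrative
soft_targets[torch.arange(4), (dense_targets + 1) % 10] = 0.4
thresh_fn = BinaryCrossEntropy(target_threshold=0.2)
loss = thresh_fn(logits, soft_targets)  # both blended classes exceed 0.2 -> treated as 1.0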
timm/loss/cross_entropy.py
@@ -1,23 +1,23 @@
 """ Cross Entropy w/ smoothing or soft targets
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 class LabelSmoothingCrossEntropy(nn.Module):
-    """
-    NLL loss with label smoothing.
+    """ NLL loss with label smoothing.
     """
     def __init__(self, smoothing=0.1):
-        """
-        Constructor for the LabelSmoothing module.
-        :param smoothing: label smoothing factor
-        """
         super(LabelSmoothingCrossEntropy, self).__init__()
         assert smoothing < 1.0
         self.smoothing = smoothing
         self.confidence = 1. - smoothing
 
-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         logprobs = F.log_softmax(x, dim=-1)
         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
         nll_loss = nll_loss.squeeze(1)
@@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module):
     def __init__(self):
         super(SoftTargetCrossEntropy, self).__init__()
 
-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
         return loss.mean()
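For reference, the smoothing arithmetic these losses share, on concrete numbers; this is an illustrative check derived from the off_value/on_value formulas in the BCE diff above, not code from the commit:

smoothing, num_classes = 0.1, 10
off_value = smoothing / num_classes        # 0.01 assigned to each non-target class
on_value = 1. - smoothing + off_value      # 0.91 assigned to the target class
assert abs(on_value + (num_classes - 1) * off_value - 1.0) < 1e-9  # row still sums to 1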
train.py (8 changed lines)
@@ -190,6 +190,8 @@ parser.add_argument('--jsd-loss', action='store_true', default=False,
                     help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 parser.add_argument('--bce-loss', action='store_true', default=False,
                     help='Enable BCE loss w/ Mixup/CutMix use.')
+parser.add_argument('--bce-target-thresh', type=float, default=None,
+                    help='Threshold for binarizing softened BCE targets (default: None, disabled)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='pixel',
@@ -459,7 +461,7 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
+            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.dist_bn)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
@@ -558,12 +560,12 @@ def main():
     elif mixup_active:
         # smoothing is handled with mixup target transform which outputs sparse, soft targets
         if args.bce_loss:
-            train_loss_fn = nn.BCEWithLogitsLoss()
+            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
         else:
             train_loss_fn = SoftTargetCrossEntropy()
     elif args.smoothing:
         if args.bce_loss:
-            train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
+            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
         else:
             train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
     else:
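How the new flag reaches the loss, mirroring the selection logic above; the flag names come from the argparse hunk, the rest is an illustrative sketch:

from timm.loss import BinaryCrossEntropy

# `--bce-loss --bce-target-thresh 0.2` with mixup/cutmix active: mixup targets already
# match the logit shape, so forward() skips its smoothing/one-hot branch and the
# threshold re-binarizes the blended targets to {0, 1}
mixup_loss_fn = BinaryCrossEntropy(target_threshold=0.2)

# `--bce-loss --bce-target-thresh 0.2 --smoothing 0.1` without mixup: dense targets are
# expanded to a smoothed one-hot (0.01/0.91) first, then thresholded back to hard 0/1
smooth_loss_fn = BinaryCrossEntropy(smoothing=0.1, target_threshold=0.2)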