Add pos_weight and support for summing over classes to BCE impl in train scripts
parent
f2fdd97e9f
commit
b5a4fa9c3b
|
@ -2,7 +2,7 @@
|
|||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -14,18 +14,30 @@ class BinaryCrossEntropy(nn.Module):
|
|||
NOTE for experiments comparing CE to BCE w/ label smoothing, may remove
|
||||
"""
|
||||
def __init__(
|
||||
self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
|
||||
reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
|
||||
self,
|
||||
smoothing=0.1,
|
||||
target_threshold: Optional[float] = None,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
reduction: str = 'mean',
|
||||
sum_classes: bool = False,
|
||||
pos_weight: Optional[Union[torch.Tensor, float]] = None,
|
||||
):
|
||||
super(BinaryCrossEntropy, self).__init__()
|
||||
assert 0. <= smoothing < 1.0
|
||||
if pos_weight is not None:
|
||||
if not isinstance(pos_weight, torch.Tensor):
|
||||
pos_weight = torch.tensor(pos_weight)
|
||||
self.smoothing = smoothing
|
||||
self.target_threshold = target_threshold
|
||||
self.reduction = reduction
|
||||
self.reduction = 'none' if sum_classes else reduction
|
||||
self.sum_classes = sum_classes
|
||||
self.register_buffer('weight', weight)
|
||||
self.register_buffer('pos_weight', pos_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
assert x.shape[0] == target.shape[0]
|
||||
batch_size = x.shape[0]
|
||||
assert batch_size == target.shape[0]
|
||||
|
||||
if target.shape != x.shape:
|
||||
# NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
|
||||
num_classes = x.shape[-1]
|
||||
|
@ -34,14 +46,20 @@ class BinaryCrossEntropy(nn.Module):
|
|||
on_value = 1. - self.smoothing + off_value
|
||||
target = target.long().view(-1, 1)
|
||||
target = torch.full(
|
||||
(target.size()[0], num_classes),
|
||||
(batch_size, num_classes),
|
||||
off_value,
|
||||
device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
|
||||
|
||||
if self.target_threshold is not None:
|
||||
# Make target 0, or 1 if threshold set
|
||||
target = target.gt(self.target_threshold).to(dtype=target.dtype)
|
||||
return F.binary_cross_entropy_with_logits(
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
x, target,
|
||||
self.weight,
|
||||
pos_weight=self.pos_weight,
|
||||
reduction=self.reduction)
|
||||
reduction=self.reduction,
|
||||
)
|
||||
if self.sum_classes:
|
||||
loss = loss.sum(-1).mean()
|
||||
return loss
|
||||
|
|
19
train.py
19
train.py
|
@ -255,8 +255,12 @@ group.add_argument('--jsd-loss', action='store_true', default=False,
|
|||
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
|
||||
group.add_argument('--bce-loss', action='store_true', default=False,
|
||||
help='Enable BCE loss w/ Mixup/CutMix use.')
|
||||
group.add_argument('--bce-sum', action='store_true', default=False,
|
||||
help='Sum over classes when using BCE loss.')
|
||||
group.add_argument('--bce-target-thresh', type=float, default=None,
|
||||
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
|
||||
help='Threshold for binarizing softened BCE targets (default: None, disabled).')
|
||||
group.add_argument('--bce-pos-weight', type=float, default=None,
|
||||
help='Positive weighting for BCE loss.')
|
||||
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
||||
help='Random erase prob (default: 0.)')
|
||||
group.add_argument('--remode', type=str, default='pixel',
|
||||
|
@ -699,12 +703,21 @@ def main():
|
|||
elif mixup_active:
|
||||
# smoothing is handled with mixup target transform which outputs sparse, soft targets
|
||||
if args.bce_loss:
|
||||
train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
|
||||
train_loss_fn = BinaryCrossEntropy(
|
||||
target_threshold=args.bce_target_thresh,
|
||||
sum_classes=args.bce_sum,
|
||||
pos_weight=args.bce_pos_weight,
|
||||
)
|
||||
else:
|
||||
train_loss_fn = SoftTargetCrossEntropy()
|
||||
elif args.smoothing:
|
||||
if args.bce_loss:
|
||||
train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
|
||||
train_loss_fn = BinaryCrossEntropy(
|
||||
smoothing=args.smoothing,
|
||||
target_threshold=args.bce_target_thresh,
|
||||
sum_classes=args.bce_sum,
|
||||
pos_weight=args.bce_pos_weight,
|
||||
)
|
||||
else:
|
||||
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue