[Bug] Fix label smooth bug (#203)

* add convert_to_one_hot

* add test_label_smooth_loss

* add my label_smooth_loss

* fix CELoss bug

* test new label smooth loss

* make LabelSmoothLoss backward compatible

* add some comments

* remove the old version of LabelSmoothLoss

* add some comments

* add some comments

* add some comments

* add label smooth to config
whcao 2021-04-13 13:53:56 +08:00 committed by GitHub
parent dcf61173f6
commit af83e981ac
8 changed files with 182 additions and 64 deletions

View File

@@ -0,0 +1,18 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, loss_weight=1.0),
topk=(1, 5),
))
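
A downstream config can now inherit this base model and override only the smoothing strength. A minimal sketch (the derived filename and the 0.2 value are hypothetical; only the base path comes from this commit):

# hypothetical_derived_config.py
_base_ = ['../_base_/models/resnet50_label_smooth.py']
model = dict(
    head=dict(
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.2, loss_weight=1.0)))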

View File

@@ -1,12 +1,5 @@
_base_ = ['./resnet50_imagenet_bs256.py']
model = dict(
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(
type='LabelSmoothLoss',
loss_weight=1.0,
label_smooth_val=0.1,
num_classes=1000),
))
_base_ = [
'../_base_/models/resnet50_label_smooth.py',
'../_base_/datasets/imagenet_bs32.py',
'../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]

View File

@@ -3,12 +3,13 @@ from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy)
from .focal_loss import FocalLoss, sigmoid_focal_loss
from .label_smooth_loss import LabelSmoothLoss, label_smooth
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
from .label_smooth_loss import LabelSmoothLoss
from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
weighted_loss)
__all__ = [
'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
'FocalLoss', 'sigmoid_focal_loss'
'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
'sigmoid_focal_loss', 'convert_to_one_hot'
]

View File

@@ -54,6 +54,7 @@ def soft_cross_entropy(pred,
"""
# element-wise losses
loss = -label * F.log_softmax(pred, dim=-1)
loss = loss.sum(dim=-1)
# apply weights and do the reduction
if weight is not None:
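
The added loss.sum(dim=-1) changes what the 'mean' reduction averages over: summing the element-wise losses across the class dimension first makes the reduction average over N samples instead of N * C elements. A standalone sketch of the effect (the input values are my own, chosen to mirror the 50 -> 100 change in the updated test below):

import torch
import torch.nn.functional as F

pred = torch.tensor([[100., -100.]])
soft_label = torch.tensor([[0.5, 0.5]])  # assumed soft target

elementwise = -soft_label * F.log_softmax(pred, dim=-1)  # shape (1, 2)
print(elementwise.sum(dim=-1).mean())  # tensor(100.) -- fixed behaviour
print(elementwise.mean())              # tensor(50.)  -- old behaviour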

View File

@@ -1,50 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..builder import LOSSES
from .utils import weight_reduce_loss
def label_smooth(pred,
label,
label_smooth_val,
avg_smooth_val,
weight=None,
reduction='mean',
avg_factor=None):
# # element-wise losses
one_hot = torch.zeros_like(pred)
one_hot.fill_(avg_smooth_val)
label = label.view(-1, 1)
one_hot.scatter_(1, label, 1 - label_smooth_val + avg_smooth_val)
loss = -torch.sum(F.log_softmax(pred, 1) * (one_hot.detach()))
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
from .cross_entropy_loss import CrossEntropyLoss
from .utils import convert_to_one_hot
@LOSSES.register_module()
class LabelSmoothLoss(nn.Module):
class LabelSmoothLoss(CrossEntropyLoss):
"""Intializer for the label smoothed cross entropy loss.
This decreases gap between output scores and encourages generalization.
Labels provided to forward can be one-hot like vectors (NxC) or class
indices (Nx1).
This normalizes the labels to a sum of 1 based on the total count of
positive targets for a given sample before applying label smoothing.
Args:
label_smooth_val (float): Value to be added to each target entry
num_classes (int, optional): Number of classes. Defaults to None.
reduction (str): The method used to reduce the loss.
Options are "none", "mean" and "sum". Defaults to 'mean'.
loss_weight (float): Weight of the loss. Defaults to 1.0.
"""
def __init__(self,
label_smooth_val,
num_classes,
num_classes=None,
reduction='mean',
loss_weight=1.0):
super(LabelSmoothLoss, self).__init__()
self.label_smooth_val = label_smooth_val
self.avg_smooth_val = self.label_smooth_val / num_classes
self.reduction = reduction
self.loss_weight = loss_weight
super(LabelSmoothLoss, self).__init__(
use_sigmoid=False,
use_soft=True,
reduction=reduction,
loss_weight=loss_weight)
self._label_smooth_val = label_smooth_val
self.num_classes = num_classes
self._eps = np.finfo(np.float32).eps
self.cls_criterion = label_smooth
def generate_one_hot_like_label(self, label):
"""
This function takes one-hot or index label vectors and computes
one-hot like label vectors (float)
"""
label_shape_list = list(label.size())
# check if targets are inputted as class integers
if len(label_shape_list) == 1 or (len(label_shape_list) == 2
and label_shape_list[1] == 1):
label = convert_to_one_hot(label.view(-1, 1), self.num_classes)
return label.float()
def smooth_label(self, one_hot_like_label):
"""
This function takes one-hot like target vectors and
computes smoothed target vectors (normalized)
according to the loss's smoothing parameter
"""
assert self.num_classes > 0
one_hot_like_label /= self._eps + one_hot_like_label.sum(
dim=1, keepdim=True)
smoothed_targets = one_hot_like_label + (
self._label_smooth_val / self.num_classes)
smoothed_targets /= self._eps + smoothed_targets.sum(
dim=1, keepdim=True)
return smoothed_targets
def forward(self,
cls_score,
@@ -53,16 +72,25 @@ class LabelSmoothLoss(nn.Module):
avg_factor=None,
reduction_override=None,
**kwargs):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_cls = self.loss_weight * self.cls_criterion(
if self.num_classes is not None:
assert self.num_classes == cls_score.shape[1], \
f'num_classes should equal cls_score.shape[1], ' \
f'but got num_classes: {self.num_classes} and ' \
f'cls_score.shape[1]: {cls_score.shape[1]}'
else:
self.num_classes = cls_score.shape[1]
one_hot_like_label = self.generate_one_hot_like_label(label=label)
assert (
one_hot_like_label.shape == cls_score.shape
), f'LabelSmoothingCrossEntropyLoss requires output and target ' \
f'to be the same shape, but got output.shape: {cls_score.shape} ' \
f'and target.shape: {one_hot_like_label.shape}'
smoothed_label = self.smooth_label(
one_hot_like_label=one_hot_like_label)
return super(LabelSmoothLoss, self).forward(
cls_score,
label,
self.label_smooth_val,
self.avg_smooth_val,
weight,
reduction=reduction,
smoothed_label,
weight=weight,
avg_factor=avg_factor,
reduction_override=reduction_override,
**kwargs)
return loss_cls
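
Taken together, for an index label whose one-hot row y already sums to 1, smooth_label computes a renormalized variant of label smoothing, (y + eps / C) / (1 + eps), rather than the textbook (1 - eps) * y + eps / C. A standalone sketch of the arithmetic (values chosen to match the new unit test at the end of this diff):

import torch

eps, C = 0.1, 2
one_hot = torch.tensor([[1., 0.]])  # index label 0, already sums to 1
smoothed = (one_hot + eps / C) / (1 + eps)
print(smoothed)  # tensor([[0.9545, 0.0455]]); the textbook form gives [[0.95, 0.05]]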

View File

@@ -1,5 +1,6 @@
import functools
import torch
import torch.nn.functional as F
@@ -96,3 +97,24 @@ def weighted_loss(loss_func):
return loss
return wrapper
def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
"""This function converts target class indices to one-hot vectors,
given the number of classes.
Args:
targets (Tensor): The ground truth label of the prediction
with shape (N, 1)
classes (int): the number of classes.
Returns:
Tensor: One-hot encoded targets with shape (N, classes).
"""
assert (torch.max(targets).item() <
classes), 'Class Index must be less than number of classes'
one_hot_targets = torch.zeros((targets.shape[0], classes),
dtype=torch.long,
device=targets.device)
one_hot_targets.scatter_(1, targets.long(), 1)
return one_hot_targets
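
A quick usage sketch of the new helper, importable through the losses package export added above (the example values are mine):

import torch
from mmcls.models.losses import convert_to_one_hot

targets = torch.tensor([[0], [2], [1]])  # class indices, shape (N, 1)
print(convert_to_one_hot(targets, 3))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])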

View File

@@ -28,3 +28,29 @@ def test_image_classifier():
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
def test_image_classifier_with_label_smooth_loss():
# Test LabelSmoothLoss in ImageClassifier (with mixup-generated labels)
model_cfg = dict(
backbone=dict(
type='ResNet_CIFAR',
depth=50,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='MultiLabelLinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1)),
train_cfg=dict(mixup=dict(alpha=1.0, num_classes=10)))
img_classifier = ImageClassifier(**model_cfg)
img_classifier.init_weights()
imgs = torch.randn(16, 3, 32, 32)
label = torch.randint(0, 10, (16, ))
losses = img_classifier.forward_train(imgs, label)
assert losses['loss'].item() > 0
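
Beyond this classifier-level test, both label layouts accepted by generate_one_hot_like_label can be exercised directly. A small sketch (shapes assumed from the docstring; build_loss as used in the loss tests below):

import torch
from mmcls.models import build_loss

loss = build_loss(dict(type='LabelSmoothLoss', label_smooth_val=0.1))
score = torch.randn(4, 10)
index_labels = torch.randint(0, 10, (4, ))    # class-index path, shape (N,)
one_hot_labels = torch.eye(10)[index_labels]  # one-hot path, shape (N, C)
print(loss(score, index_labels))
print(loss(score, one_hot_labels))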

View File

@@ -82,10 +82,10 @@ def test_cross_entropy_loss():
reduction='mean',
loss_weight=1.0)
loss = build_loss(loss_cfg)
assert torch.allclose(loss(cls_score, label), torch.tensor(50.))
assert torch.allclose(loss(cls_score, label), torch.tensor(100.))
# test soft_ce_loss with weight
assert torch.allclose(
loss(cls_score, label, weight=weight), torch.tensor(25.))
loss(cls_score, label, weight=weight), torch.tensor(50.))
def test_focal_loss():
@@ -105,3 +105,32 @@ def test_focal_loss():
# test focal_loss with weight
assert torch.allclose(
loss(cls_score, label, weight=weight), torch.tensor(0.8522 / 2))
def test_label_smooth_loss():
# test label smooth loss
cls_score = torch.tensor([[1., -1.]])
label = torch.tensor([0])
loss_cfg = dict(
type='LabelSmoothLoss',
reduction='mean',
label_smooth_val=0.1,
loss_weight=1.0)
loss = build_loss(loss_cfg)
assert torch.allclose(loss(cls_score, label), torch.tensor(0.2179), atol=1e-4)
# test label smooth loss with weight
cls_score = torch.tensor([[1., -1.], [1., -1.]])
label = torch.tensor([0, 1])
weight = torch.tensor([0.5, 0.5])
loss_cfg = dict(
type='LabelSmoothLoss',
reduction='mean',
label_smooth_val=0.1,
loss_weight=1.0)
loss = build_loss(loss_cfg)
assert torch.allclose(
loss(cls_score, label, weight=weight),
loss(cls_score, label) / 2)
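
For completeness, the 0.2179 expectation can be reproduced by hand from the renormalized smoothing formula (my own arithmetic, not part of the PR):

import torch
import torch.nn.functional as F

cls_score = torch.tensor([[1., -1.]])
eps, C = 0.1, 2

smoothed = (torch.tensor([[1., 0.]]) + eps / C) / (1 + eps)  # label 0
loss = -(smoothed * F.log_softmax(cls_score, dim=-1)).sum(dim=-1).mean()
print(loss)  # tensor(0.2178), within the 1e-4 tolerance used above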