[Bug] Fix label smooth bug (#203)
* add convert_to_one_hot
* add test_label_smooth_loss
* add my label_smooth_loss
* fix CELoss bug
* test new label smooth loss
* LabelSmoothLoss downward compatibility
* add some comments
* remove the old version of LabelSmoothLoss
* add some comments
* add some comments
* add some comments
* add label smooth to config
parent dcf61173f6
commit af83e981ac
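Before the per-file hunks, the arithmetic at the heart of the fix: a one-hot target is normalized to sum to 1, label_smooth_val / num_classes is added to every entry, and the result is renormalized before going through the soft cross-entropy path. A self-contained sketch in plain PyTorch (it mirrors the PR's math but imports nothing from mmcls):

import torch
import torch.nn.functional as F

# Smooth a one-hot target the way the reworked LabelSmoothLoss does:
# normalize, add label_smooth_val / num_classes, renormalize.
label_smooth_val, num_classes = 0.1, 2
one_hot = torch.tensor([[1., 0.]])
one_hot = one_hot / one_hot.sum(dim=1, keepdim=True)
smoothed = one_hot + label_smooth_val / num_classes      # [[1.05, 0.05]]
smoothed = smoothed / smoothed.sum(dim=1, keepdim=True)  # [[0.9545, 0.0455]]

# Soft cross-entropy against the smoothed target.
cls_score = torch.tensor([[1., -1.]])
loss = -(smoothed * F.log_softmax(cls_score, dim=-1)).sum(dim=-1).mean()
print(loss)  # ~0.2179, the value asserted by the new unit test below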
@@ -0,0 +1,18 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(3, ),
+        style='pytorch'),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, loss_weight=1.0),
+        topk=(1, 5),
+    ))
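This new base config wires the loss into the head through the registry. A hedged sketch of building just the loss from this fragment (assuming mmcls is installed; build_loss is the same registry helper the unit tests below use):

import torch
from mmcls.models import build_loss  # registry helper, as in the tests below

loss_cfg = dict(type='LabelSmoothLoss', label_smooth_val=0.1, loss_weight=1.0)
criterion = build_loss(loss_cfg)

cls_score = torch.randn(4, 1000)       # logits from the head
label = torch.randint(0, 1000, (4, ))  # index labels
print(criterion(cls_score, label))     # scalar loss; num_classes is inferred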
@@ -1,12 +1,5 @@
-_base_ = ['./resnet50_imagenet_bs256.py']
-model = dict(
-    head=dict(
-        type='LinearClsHead',
-        num_classes=1000,
-        in_channels=2048,
-        loss=dict(
-            type='LabelSmoothLoss',
-            loss_weight=1.0,
-            label_smooth_val=0.1,
-            num_classes=1000),
-    ))
+_base_ = [
+    '../_base_/models/resnet50_label_smooth.py',
+    '../_base_/datasets/imagenet_bs32.py',
+    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
+]
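The user-facing config now only composes _base_ files; the head definition lives in the shared model config above and is no longer restated. A sketch of how the merge resolves (assuming mmcv's Config; the config path below is illustrative, not taken from this diff):

from mmcv import Config

# _base_ files are merged in order; keys set later override earlier ones.
cfg = Config.fromfile('configs/resnet/resnet50_b32x8_label_smooth_imagenet.py')
print(cfg.model.head.loss)
# {'type': 'LabelSmoothLoss', 'label_smooth_val': 0.1, 'loss_weight': 1.0}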
@@ -3,12 +3,13 @@ from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy)
 from .focal_loss import FocalLoss, sigmoid_focal_loss
-from .label_smooth_loss import LabelSmoothLoss, label_smooth
-from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+from .label_smooth_loss import LabelSmoothLoss
+from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
+                    weighted_loss)
 
 __all__ = [
     'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
     'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
-    'FocalLoss', 'sigmoid_focal_loss'
+    'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
+    'sigmoid_focal_loss', 'convert_to_one_hot'
 ]
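Note the exported surface: the old functional label_smooth is gone and convert_to_one_hot is new, so any downstream import of the removed function must move to the class. The new import surface, for reference:

from mmcls.models.losses import LabelSmoothLoss, convert_to_one_hot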
@@ -54,6 +54,7 @@ def soft_cross_entropy(pred,
     """
     # element-wise losses
     loss = -label * F.log_softmax(pred, dim=-1)
+    loss = loss.sum(dim=-1)
 
     # apply weights and do the reduction
     if weight is not None:
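This one-line addition is the "fix CELoss bug" item. Without the sum over the class dimension, soft_cross_entropy handed an (N, C) tensor to the weight-and-reduction step, so 'mean' averaged over N*C entries and per-sample weights of shape (N,) did not line up with the loss. Reducing over classes first yields one value per sample, which is also why the expected values in test_cross_entropy_loss below double (there C = 2). A standalone sketch of the before/after reduction:

import torch
import torch.nn.functional as F

pred = torch.tensor([[100., -100.], [100., -100.]])
soft_label = torch.tensor([[1., 0.], [0., 1.]])

elementwise = -soft_label * F.log_softmax(pred, dim=-1)  # shape (2, 2)
per_sample = elementwise.sum(dim=-1)                     # shape (2,)

print(elementwise.mean())  # old behavior: mean over N*C entries -> 50.
print(per_sample.mean())   # fixed behavior: mean over N samples -> 100.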
@@ -1,50 +1,69 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import numpy as np
 
 from ..builder import LOSSES
-from .utils import weight_reduce_loss
-
-
-def label_smooth(pred,
-                 label,
-                 label_smooth_val,
-                 avg_smooth_val,
-                 weight=None,
-                 reduction='mean',
-                 avg_factor=None):
-    # element-wise losses
-    one_hot = torch.zeros_like(pred)
-    one_hot.fill_(avg_smooth_val)
-    label = label.view(-1, 1)
-    one_hot.scatter_(1, label, 1 - label_smooth_val + avg_smooth_val)
-
-    loss = -torch.sum(F.log_softmax(pred, 1) * (one_hot.detach()))
-
-    # apply weights and do the reduction
-    if weight is not None:
-        weight = weight.float()
-    loss = weight_reduce_loss(
-        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
-
-    return loss
+from .cross_entropy_loss import CrossEntropyLoss
+from .utils import convert_to_one_hot
 
 
 @LOSSES.register_module()
-class LabelSmoothLoss(nn.Module):
+class LabelSmoothLoss(CrossEntropyLoss):
+    """Initializer for the label smoothed cross entropy loss.
+
+    This decreases the gap between output scores and encourages generalization.
+    Labels provided to forward can be one-hot like vectors (NxC) or class
+    indices (Nx1).
+    This normalizes the labels to a sum of 1 based on the total count of
+    positive targets for a given sample before applying label smoothing.
+
+    Args:
+        label_smooth_val (float): Value to be added to each target entry.
+        num_classes (int, optional): Number of classes. Defaults to None.
+        reduction (str): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+        loss_weight (float): Weight of the loss. Defaults to 1.0.
+    """
 
     def __init__(self,
                  label_smooth_val,
-                 num_classes,
+                 num_classes=None,
                  reduction='mean',
                  loss_weight=1.0):
-        super(LabelSmoothLoss, self).__init__()
-        self.label_smooth_val = label_smooth_val
-        self.avg_smooth_val = self.label_smooth_val / num_classes
-        self.reduction = reduction
-        self.loss_weight = loss_weight
+        super(LabelSmoothLoss, self).__init__(
+            use_sigmoid=False,
+            use_soft=True,
+            reduction=reduction,
+            loss_weight=loss_weight)
+        self._label_smooth_val = label_smooth_val
+        self.num_classes = num_classes
+        self._eps = np.finfo(np.float32).eps
 
-        self.cls_criterion = label_smooth
+    def generate_one_hot_like_label(self, label):
+        """
+        This function takes one-hot or index label vectors and computes
+        one-hot like label vectors (float)
+        """
+        label_shape_list = list(label.size())
+        # check if targets are inputted as class integers
+        if len(label_shape_list) == 1 or (len(label_shape_list) == 2
+                                          and label_shape_list[1] == 1):
+            label = convert_to_one_hot(label.view(-1, 1), self.num_classes)
+        return label.float()
+
+    def smooth_label(self, one_hot_like_label):
+        """
+        This function takes one-hot like target vectors and
+        computes smoothed target vectors (normalized)
+        according to the loss's smoothing parameter
+        """
+        assert self.num_classes > 0
+        one_hot_like_label /= self._eps + one_hot_like_label.sum(
+            dim=1, keepdim=True)
+        smoothed_targets = one_hot_like_label + (
+            self._label_smooth_val / self.num_classes)
+        smoothed_targets /= self._eps + smoothed_targets.sum(
+            dim=1, keepdim=True)
+
+        return smoothed_targets
 
     def forward(self,
                 cls_score,
@@ -53,16 +72,25 @@ class LabelSmoothLoss(nn.Module):
                 avg_factor=None,
                 reduction_override=None,
                 **kwargs):
-        assert reduction_override in (None, 'none', 'mean', 'sum')
-        reduction = (
-            reduction_override if reduction_override else self.reduction)
-        loss_cls = self.loss_weight * self.cls_criterion(
+        if self.num_classes is not None:
+            assert self.num_classes == cls_score.shape[1], \
+                f'num_classes should equal to cls_score.shape[1], ' \
+                f'but got num_classes: {self.num_classes} and ' \
+                f'cls_score.shape[1]: {cls_score.shape[1]}'
+        else:
+            self.num_classes = cls_score.shape[1]
+        one_hot_like_label = self.generate_one_hot_like_label(label=label)
+        assert (
+            one_hot_like_label.shape == cls_score.shape
+        ), f'LabelSmoothingCrossEntropyLoss requires output and target ' \
+           f'to be same shape, but got output.shape: {cls_score.shape} ' \
+           f'and target.shape: {one_hot_like_label.shape}'
+        smoothed_label = self.smooth_label(
+            one_hot_like_label=one_hot_like_label)
+        return super(LabelSmoothLoss, self).forward(
             cls_score,
-            label,
-            self.label_smooth_val,
-            self.avg_smooth_val,
-            weight,
-            reduction=reduction,
+            smoothed_label,
+            weight=weight,
             avg_factor=avg_factor,
+            reduction_override=reduction_override,
             **kwargs)
-        return loss_cls
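With forward delegating to CrossEntropyLoss.forward (use_soft=True), the class stays downward compatible: index labels of shape (N,) or (N, 1) are converted with convert_to_one_hot, while (N, C) one-hot-like labels pass straight through to smoothing. A hedged usage sketch (assuming mmcls is importable):

import torch
from mmcls.models.losses import LabelSmoothLoss

loss = LabelSmoothLoss(label_smooth_val=0.1, num_classes=2)
cls_score = torch.tensor([[1., -1.]])

print(loss(cls_score, torch.tensor([0])))         # index label, shape (N,)
print(loss(cls_score, torch.tensor([[1., 0.]])))  # one-hot label, shape (N, C)
# both paths smooth the target to [[0.9545, 0.0455]] and give ~0.2179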
@@ -1,5 +1,6 @@
 import functools
 
+import torch
 import torch.nn.functional as F
 
 
@@ -96,3 +97,24 @@ def weighted_loss(loss_func):
         return loss
 
     return wrapper
+
+
+def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
+    """This function converts target class indices to one-hot vectors,
+    given the number of classes.
+
+    Args:
+        targets (Tensor): The ground truth label of the prediction
+            with shape (N, 1)
+        classes (int): the number of classes.
+
+    Returns:
+        Tensor: The one-hot encoded targets.
+    """
+    assert (torch.max(targets).item() <
+            classes), 'Class Index must be less than number of classes'
+    one_hot_targets = torch.zeros((targets.shape[0], classes),
+                                  dtype=torch.long,
+                                  device=targets.device)
+    one_hot_targets.scatter_(1, targets.long(), 1)
+    return one_hot_targets
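A quick check of the helper's contract, (N, 1) integer indices in, (N, classes) long one-hot out; it is exported from mmcls.models.losses per the __init__.py change above:

import torch
from mmcls.models.losses import convert_to_one_hot

targets = torch.tensor([[0], [2]])
print(convert_to_one_hot(targets, 3))
# tensor([[1, 0, 0],
#         [0, 0, 1]])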
@@ -28,3 +28,29 @@ def test_image_classifier():
 
     losses = img_classifier.forward_train(imgs, label)
     assert losses['loss'].item() > 0
+
+
+def test_image_classifier_with_label_smooth_loss():
+
+    # Test label smooth loss in ImageClassifier
+    model_cfg = dict(
+        backbone=dict(
+            type='ResNet_CIFAR',
+            depth=50,
+            num_stages=4,
+            out_indices=(3, ),
+            style='pytorch'),
+        neck=dict(type='GlobalAveragePooling'),
+        head=dict(
+            type='MultiLabelLinearClsHead',
+            num_classes=10,
+            in_channels=2048,
+            loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1)),
+        train_cfg=dict(mixup=dict(alpha=1.0, num_classes=10)))
+    img_classifier = ImageClassifier(**model_cfg)
+    img_classifier.init_weights()
+    imgs = torch.randn(16, 3, 32, 32)
+    label = torch.randint(0, 10, (16, ))
+
+    losses = img_classifier.forward_train(imgs, label)
+    assert losses['loss'].item() > 0
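This test is the integration check for the one-hot path: mixup in train_cfg blends pairs of one-hot labels into soft (N, C) targets before they reach the loss, so LabelSmoothLoss must accept one-hot-like input here rather than class indices. A sketch of the label shape mixup produces (illustrative values, not the PR's mixup code):

import torch

lam = 0.3                          # mixing coefficient drawn from Beta(alpha, alpha)
y1 = torch.tensor([[1., 0.]])      # one-hot label of a sample
y2 = torch.tensor([[0., 1.]])      # one-hot label of its mixup partner
mixed = lam * y1 + (1 - lam) * y2  # soft target [[0.3, 0.7]], shape (N, C)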
@@ -82,10 +82,10 @@ def test_cross_entropy_loss():
         reduction='mean',
         loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(50.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(100.))
     # test soft_ce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(25.))
+        loss(cls_score, label, weight=weight), torch.tensor(50.))
 
 
 def test_focal_loss():
@@ -105,3 +105,32 @@ def test_focal_loss():
     # test focal_loss with weight
     assert torch.allclose(
         loss(cls_score, label, weight=weight), torch.tensor(0.8522 / 2))
+
+
+def test_label_smooth_loss():
+    # test label smooth loss
+    cls_score = torch.tensor([[1., -1.]])
+    label = torch.tensor([0])
+
+    loss_cfg = dict(
+        type='LabelSmoothLoss',
+        reduction='mean',
+        label_smooth_val=0.1,
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert loss(cls_score, label) - 0.2179 <= 0.0001
+
+    # test label smooth loss with weight
+    cls_score = torch.tensor([[1., -1.], [1., -1.]])
+    label = torch.tensor([0, 1])
+    weight = torch.tensor([0.5, 0.5])
+
+    loss_cfg = dict(
+        type='LabelSmoothLoss',
+        reduction='mean',
+        label_smooth_val=0.1,
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight),
+        loss(cls_score, label) / 2)
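The first assertion matches the worked arithmetic at the top of this page (~0.2179). The second holds for any per-sample losses: with reduction='mean', weights scale each sample's loss before averaging, so uniform weights of 0.5 halve the unweighted mean exactly. A minimal check of that identity:

import torch

per_sample = torch.tensor([0.2179, 2.1269])  # arbitrary per-sample losses
weight = torch.tensor([0.5, 0.5])
assert torch.allclose((per_sample * weight).mean(), per_sample.mean() / 2)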