[Feature] Add focal loss (#1024)

* [Feature] add focal loss
* fix the bug of 'none' reduction type
* refine the implementation
* add class_weight and ignore_index; support different alpha values for different classes
* fixed some bugs
* fix bugs
* add comments
* modify test
* Update mmseg/models/losses/focal_loss.py
* update test_focal_loss.py
* modified the implementation
* update focal_loss.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
parent 504a5c6bd2
commit 07cd6c98e0
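Since the new loss is registered in the LOSSES registry (see focal_loss.py below), it can be selected from a config dict. A minimal sketch of how a segmentation head might opt in; the surrounding key names follow common mmseg config conventions and are illustrative, not part of this commit:

# Hypothetical config fragment: plug FocalLoss into a decode head.
loss_decode = dict(
    type='FocalLoss',
    use_sigmoid=True,   # only the sigmoid form is supported
    gamma=2.0,          # modulating factor
    alpha=0.5,          # or a per-class list, each value in [0, 1]
    loss_weight=1.0)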
mmseg/models/losses/__init__.py
@@ -3,11 +3,13 @@ from .accuracy import Accuracy, accuracy
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
 from .dice_loss import DiceLoss
+from .focal_loss import FocalLoss
 from .lovasz_loss import LovaszLoss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss

 __all__ = [
     'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
     'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
+    'FocalLoss'
 ]
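With the export added to `__all__`, the loss is importable directly from the package:

from mmseg.models.losses import FocalLoss

criterion = FocalLoss(gamma=2.0, alpha=0.5)  # registered, importable loss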
mmseg/models/losses/focal_loss.py
@@ -0,0 +1,327 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/open-mmlab/mmdetection
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from ..builder import LOSSES
from .utils import weight_reduce_loss


# This method is used when CUDA is not available
def py_sigmoid_focal_loss(pred,
                          target,
                          one_hot_target=None,
                          weight=None,
                          gamma=2.0,
                          alpha=0.5,
                          class_weight=None,
                          valid_mask=None,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction with
            shape (N, C).
        one_hot_target (None): Placeholder. It should be None.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    if isinstance(alpha, list):
        alpha = pred.new_tensor(alpha)
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * one_minus_pt.pow(gamma)

    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    final_weight = torch.ones(1, pred.size(1)).type_as(loss)
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss
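A quick numerical cross-check of the pure-PyTorch path above (an editor's sketch, not part of the diff; it assumes the file is importable as mmseg.models.losses.focal_loss). With reduction='none' and no extra weights, each element should equal the textbook focal term -alpha_t * (1 - p_t)**gamma * log(p_t):

import torch
import torch.nn.functional as F
from mmseg.models.losses.focal_loss import py_sigmoid_focal_loss

pred = torch.randn(8, 4)                                    # (N, C) logits
target = F.one_hot(torch.randint(0, 4, (8, )), 4).float()   # (N, C) labels
loss = py_sigmoid_focal_loss(
    pred, target, gamma=2.0, alpha=0.5, reduction='none')

p = pred.sigmoid()
pt = p * target + (1 - p) * (1 - target)  # prob of the true binary label
manual = -0.5 * (1 - pt).pow(2.0) * pt.clamp(min=1e-12).log()
assert torch.allclose(loss, manual, atol=1e-6)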

def sigmoid_focal_loss(pred,
                       target,
                       one_hot_target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.5,
                       class_weight=None,
                       valid_mask=None,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the CUDA version of `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction. Its
            shape should be (N, ).
        one_hot_target (torch.Tensor): The learning label with shape (N, C).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float | list[float], optional): A balanced form for Focal Loss.
            Defaults to 0.5.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid
            samples and uses 0 to mark the ignored samples. Default: None.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    final_weight = torch.ones(1, pred.size(1)).type_as(pred)
    if isinstance(alpha, list):
        # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if
        # a list is given, we set the input alpha as 0.5. This means setting
        # equal weight for foreground class and background class. By
        # multiplying the loss by 2, the effect of setting alpha as 0.5 is
        # undone. The alpha of type list is used to regulate the loss in the
        # post-processing process.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, 0.5, None, 'none') * 2
        alpha = pred.new_tensor(alpha)
        final_weight = final_weight * (
            alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target))
    else:
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape and weight.size(0) == loss.size(0):
            # For most cases, weight is of shape (N, ),
            # which means it does not have the second axis num_class
            weight = weight.view(-1, 1)
        assert weight.dim() == loss.dim()
        final_weight = final_weight * weight
    if class_weight is not None:
        final_weight = final_weight * pred.new_tensor(class_weight)
    if valid_mask is not None:
        final_weight = final_weight * valid_mask
    loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor)
    return loss
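The comment block above encodes an identity worth spelling out: with alpha=0.5 the mmcv op halves both the positive and negative terms, so multiplying its output by 2 recovers the unweighted loss, and the (alpha, 1 - alpha) factor then re-applies a per-class alpha. A sketch that checks the scalar and list paths agree when the list is constant (not part of the diff; needs a CUDA-enabled mmcv build):

import torch
import torch.nn.functional as F
from mmseg.models.losses.focal_loss import sigmoid_focal_loss

if torch.cuda.is_available():
    pred = torch.randn(8, 4, device='cuda')
    target = torch.randint(0, 4, (8, ), device='cuda')
    one_hot = F.one_hot(target, num_classes=4)
    scalar = sigmoid_focal_loss(
        pred, target, one_hot, alpha=0.25, reduction='none')
    listed = sigmoid_focal_loss(
        pred, target, one_hot, alpha=[0.25] * 4, reduction='none')
    assert torch.allclose(scalar, listed, atol=1e-6)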

@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal'):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is used for
                sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float | list[float], optional): A balanced form for Focal
                Loss. Defaults to 0.5. When a list is provided, the length
                of the list should be equal to the number of classes.
                Please be careful that this parameter is not the
                class-wise weight but the weight of a binary classification
                problem. This binary classification problem regards the
                pixels which belong to one class as the foreground
                and the other pixels as the background, each element in
                the list is the weight of the corresponding foreground class.
                The value of alpha or each element of alpha should be a float
                in the interval [0, 1]. If you want to specify the class-wise
                weight, please use the `class_weight` parameter.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            loss_name (str, optional): Name of the loss item. If you want this
                loss item to be included into the backward graph, `loss_` must
                be the prefix of the name. Defaults to 'loss_focal'.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, \
            'AssertionError: Only sigmoid focal loss supported now.'
        assert reduction in ('none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert isinstance(alpha, (float, list)), \
            'AssertionError: alpha should be of type float or list'
        assert isinstance(gamma, float), \
            'AssertionError: gamma should be of type float'
        assert isinstance(loss_weight, float), \
            'AssertionError: loss_weight should be of type float'
        assert isinstance(loss_name, str), \
            'AssertionError: loss_name should be of type str'
        assert isinstance(class_weight, list) or class_weight is None, \
            'AssertionError: class_weight must be None or of type list'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.class_weight = class_weight
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=255,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape
                (N, C) where C = number of classes, or
                (N, C, d_1, d_2, ..., d_K) with K≥1 in the
                case of K-dimensional loss.
            target (torch.Tensor): The ground truth. If containing class
                indices, shape (N) where each value is 0≤targets[i]≤C−1,
                or (N, d_1, d_2, ..., d_K) with K≥1 in the case of
                K-dimensional loss. If containing class probabilities,
                same shape as the input.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            ignore_index (int, optional): The label index to be ignored.
                Default: 255.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert isinstance(ignore_index, int), \
            'ignore_index must be of type int'
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            "AssertionError: reduction should be 'none', 'mean' or " \
            "'sum'"
        assert pred.shape == target.shape or \
               (pred.size(0) == target.size(0) and
                pred.shape[2:] == target.shape[1:]), \
               "The shape of pred doesn't match the shape of target"

        original_shape = pred.shape

        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()

        if original_shape == target.shape:
            # target with shape [B, C, d_1, d_2, ...]
            # transform its shape into [N, C]
            # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k]
            target = target.transpose(0, 1)
            # [C, B, d_1, d_2, ..., d_k] -> [C, N]
            target = target.reshape(target.size(0), -1)
            # [C, N] -> [N, C]
            target = target.transpose(0, 1).contiguous()
        else:
            # target with shape [B, d_1, d_2, ...]
            # transform its shape into [N, ]
            target = target.view(-1).contiguous()
            valid_mask = (target != ignore_index).view(-1, 1)
            # avoid raising error when using F.one_hot()
            target = torch.where(target == ignore_index, target.new_tensor(0),
                                 target)

        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            num_classes = pred.size(1)
            if torch.cuda.is_available() and pred.is_cuda:
                if target.dim() == 1:
                    one_hot_target = F.one_hot(
                        target, num_classes=num_classes)
                else:
                    one_hot_target = target
                    target = target.argmax(dim=1)
                    valid_mask = (target != ignore_index).view(-1, 1)
                calculate_loss_func = sigmoid_focal_loss
            else:
                one_hot_target = None
                if target.dim() == 1:
                    target = F.one_hot(target, num_classes=num_classes)
                else:
                    valid_mask = (target.argmax(dim=1) != ignore_index).view(
                        -1, 1)
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                one_hot_target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                class_weight=self.class_weight,
                valid_mask=valid_mask,
                reduction=reduction,
                avg_factor=avg_factor)

            if reduction == 'none':
                # [N, C] -> [C, N]
                loss_cls = loss_cls.transpose(0, 1)
                # [C, N] -> [C, B, d1, d2, ...]
                # original_shape: [B, C, d1, d2, ...]
                loss_cls = loss_cls.reshape(original_shape[1],
                                            original_shape[0],
                                            *original_shape[2:])
                # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...]
                loss_cls = loss_cls.transpose(0, 1).contiguous()
        else:
            raise NotImplementedError
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
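Before the tests, an end-to-end sketch of the new module's API (the shapes mirror the tests below; the config values are illustrative, not taken from the diff):

import torch
from mmseg.models import build_loss

loss_cfg = dict(
    type='FocalLoss',
    gamma=2.0,
    alpha=[0.25, 0.5, 0.75, 0.9],       # per-class alpha, one entry per class
    class_weight=[1.0, 1.0, 2.0, 2.0],  # separate class-wise weighting
    loss_weight=1.0)
focal_loss = build_loss(loss_cfg)

logits = torch.rand(2, 4, 8, 8)            # (B, C, H, W) predictions
labels = torch.randint(0, 4, (2, 8, 8))    # (B, H, W) class indices
loss = focal_loss(logits, labels, ignore_index=255)   # scalar ('mean')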

test_focal_loss.py
@@ -0,0 +1,216 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F

from mmseg.models import build_loss


# test focal loss with use_sigmoid=False
def test_use_sigmoid():
    # can't init with use_sigmoid=False
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', use_sigmoid=False)
        build_loss(loss_cfg)

    # can't forward with use_sigmoid=False
    with pytest.raises(NotImplementedError):
        loss_cfg = dict(type='FocalLoss', use_sigmoid=True)
        focal_loss = build_loss(loss_cfg)
        focal_loss.use_sigmoid = False
        fake_pred = torch.rand(3, 4, 5, 6)
        fake_target = torch.randint(0, 4, (3, 5, 6))
        focal_loss(fake_pred, fake_target)


# reduction type must be 'none', 'mean' or 'sum'
def test_wrong_reduction_type():
    # can't init with wrong reduction
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', reduction='test')
        build_loss(loss_cfg)

    # can't forward with wrong reduction override
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss')
        focal_loss = build_loss(loss_cfg)
        fake_pred = torch.rand(3, 4, 5, 6)
        fake_target = torch.randint(0, 4, (3, 5, 6))
        focal_loss(fake_pred, fake_target, reduction_override='test')


# test focal loss can handle input parameters with
# unacceptable types
def test_unacceptable_parameters():
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', gamma='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', alpha='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', class_weight='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', loss_weight='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', loss_name=123)
        build_loss(loss_cfg)


# test if focal loss can be correctly initialized
def test_init_focal_loss():
    loss_cfg = dict(
        type='FocalLoss',
        use_sigmoid=True,
        gamma=3.0,
        alpha=3.0,
        class_weight=[1, 2, 3, 4],
        reduction='sum')
    focal_loss = build_loss(loss_cfg)
    assert focal_loss.use_sigmoid is True
    assert focal_loss.gamma == 3.0
    assert focal_loss.alpha == 3.0
    assert focal_loss.reduction == 'sum'
    assert focal_loss.class_weight == [1, 2, 3, 4]
    assert focal_loss.loss_weight == 1.0
    assert focal_loss.loss_name == 'loss_focal'


# test reduction override
def test_reduction_override():
    loss_cfg = dict(type='FocalLoss', reduction='mean')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss = focal_loss(fake_pred, fake_target, reduction_override='none')
    assert loss.shape == fake_pred.shape


# test wrong pred and target shape
def test_wrong_pred_and_target_shape():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 2, 2))
    fake_target = F.one_hot(fake_target, num_classes=4)
    fake_target = fake_target.permute(0, 3, 1, 2)
    with pytest.raises(AssertionError):
        focal_loss(fake_pred, fake_target)


# test forward with different shape of target
def test_forward_with_different_shape_of_target():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)

    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss1 = focal_loss(fake_pred, fake_target)

    fake_target = F.one_hot(fake_target, num_classes=4)
    fake_target = fake_target.permute(0, 3, 1, 2)
    loss2 = focal_loss(fake_pred, fake_target)
    assert loss1 == loss2


# test forward with weight
def test_forward_with_weight():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    weight = torch.rand(3 * 5 * 6, 1)
    loss1 = focal_loss(fake_pred, fake_target, weight=weight)

    weight2 = weight.view(-1)
    loss2 = focal_loss(fake_pred, fake_target, weight=weight2)

    weight3 = weight.expand(3 * 5 * 6, 4)
    loss3 = focal_loss(fake_pred, fake_target, weight=weight3)
    assert loss1 == loss2 == loss3


# test none reduction type
def test_none_reduction_type():
    loss_cfg = dict(type='FocalLoss', reduction='none')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss = focal_loss(fake_pred, fake_target)
    assert loss.shape == fake_pred.shape


# test the usage of class weight
def test_class_weight():
    loss_cfg_cw = dict(
        type='FocalLoss', reduction='none', class_weight=[1.0, 2.0, 3.0, 4.0])
    loss_cfg = dict(type='FocalLoss', reduction='none')
    focal_loss_cw = build_loss(loss_cfg_cw)
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss_cw = focal_loss_cw(fake_pred, fake_target)
    loss = focal_loss(fake_pred, fake_target)
    weight = torch.tensor([1, 2, 3, 4]).view(1, 4, 1, 1)
    assert (loss * weight == loss_cw).all()


# test ignore index
def test_ignore_index():
    loss_cfg = dict(type='FocalLoss', reduction='none')
    # ignore_index within C classes
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 5, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    dim1 = torch.randint(0, 3, (4, ))
    dim2 = torch.randint(0, 5, (4, ))
    dim3 = torch.randint(0, 6, (4, ))
    fake_target[dim1, dim2, dim3] = 4
    loss1 = focal_loss(fake_pred, fake_target, ignore_index=4)
    one_hot_target = F.one_hot(fake_target, num_classes=5)
    one_hot_target = one_hot_target.permute(0, 3, 1, 2)
    loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=4)
    assert (loss1 == loss2).all()
    assert (loss1[dim1, :, dim2, dim3] == 0).all()
    assert (loss2[dim1, :, dim2, dim3] == 0).all()

    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss1 = focal_loss(fake_pred, fake_target, ignore_index=2)
    one_hot_target = F.one_hot(fake_target, num_classes=4)
    one_hot_target = one_hot_target.permute(0, 3, 1, 2)
    loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=2)
    ignore_mask = one_hot_target == 2
    assert (loss1 == loss2).all()
    assert torch.sum(loss1 * ignore_mask) == 0
    assert torch.sum(loss2 * ignore_mask) == 0

    # ignore index is not in prediction's classes
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    dim1 = torch.randint(0, 3, (4, ))
    dim2 = torch.randint(0, 5, (4, ))
    dim3 = torch.randint(0, 6, (4, ))
    fake_target[dim1, dim2, dim3] = 255
    loss1 = focal_loss(fake_pred, fake_target, ignore_index=255)
    assert (loss1[dim1, :, dim2, dim3] == 0).all()


# test list alpha
def test_alpha():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)
    alpha_float = 0.4
    alpha = [0.4, 0.4, 0.4, 0.4]
    alpha2 = [0.1, 0.3, 0.2, 0.1]
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    focal_loss.alpha = alpha_float
    loss1 = focal_loss(fake_pred, fake_target)
    focal_loss.alpha = alpha
    loss2 = focal_loss(fake_pred, fake_target)
    assert loss1 == loss2
    focal_loss.alpha = alpha2
    focal_loss(fake_pred, fake_target)