# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmpretrain.models import build_loss


def test_asymmetric_loss():
    # test asymmetric_loss
    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
    label = torch.Tensor([[1, 0, 1], [0, 1, 0]])
    weight = torch.tensor([0.5, 0.5])

    loss_cfg = dict(
        type='AsymmetricLoss',
        gamma_pos=1.0,
        gamma_neg=4.0,
        clip=0.05,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(3.80845 / 3))

    # test asymmetric_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(3.80845 / 6))
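
    # Sanity note on the constants above (a hand derivation, not from the
    # implementation): with sigmoid scores [0.9933, 0.0067, 0.5] per row,
    # the six per-element ASL terms sum to roughly 7.6169 = 2 * 3.80845, so
    # the elementwise mean is 3.80845 / 3, and the 0.5 sample weights halve
    # it to 3.80845 / 6.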

    # test asymmetric_loss without clip
    loss_cfg = dict(
        type='AsymmetricLoss',
        gamma_pos=1.0,
        gamma_neg=4.0,
        clip=None,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(5.1186 / 3))

    # test asymmetric_loss with softmax for single label task
    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
    label = torch.Tensor([0, 1])
    weight = torch.tensor([0.5, 0.5])
    loss_cfg = dict(
        type='AsymmetricLoss',
        gamma_pos=0.0,
        gamma_neg=0.0,
        clip=None,
        reduction='mean',
        loss_weight=1.0,
        use_sigmoid=False,
        eps=1e-8)
    loss = build_loss(loss_cfg)
    # test asymmetric_loss for single label task without weight
    assert torch.allclose(loss(cls_score, label), torch.tensor(2.5045))
    # test asymmetric_loss for single label task with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(2.5045 * 0.5))

    # test soft asymmetric_loss with softmax
    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
    label = torch.Tensor([[1, 0, 0], [0, 1, 0]])
    weight = torch.tensor([0.5, 0.5])
    loss_cfg = dict(
        type='AsymmetricLoss',
        gamma_pos=0.0,
        gamma_neg=0.0,
        clip=None,
        reduction='mean',
        loss_weight=1.0,
        use_sigmoid=False,
        eps=1e-8)
    loss = build_loss(loss_cfg)
    # test soft asymmetric_loss with softmax without weight
    assert torch.allclose(loss(cls_score, label), torch.tensor(2.5045))
    # test soft asymmetric_loss with softmax with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(2.5045 * 0.5))


def test_cross_entropy_loss():
    with pytest.raises(AssertionError):
        # use_sigmoid and use_soft could not be set simultaneously
        loss_cfg = dict(
            type='CrossEntropyLoss', use_sigmoid=True, use_soft=True)
        loss = build_loss(loss_cfg)

    # test ce_loss
    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
    label = torch.Tensor([0, 1]).long()
    class_weight = [0.3, 0.7]  # class 0 : 0.3, class 1 : 0.7
    weight = torch.tensor([0.6, 0.4])

    # test ce_loss without class weight
    loss_cfg = dict(type='CrossEntropyLoss', reduction='mean', loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
    # test ce_loss with sample weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(640.))
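
    # Where 1100 and 640 come from: the logits are saturated, so each
    # sample's cross entropy is approximately its logit margin, i.e. about
    # 2000 for [-1000, 1000] with label 0 and about 200 for [100, -100]
    # with label 1. The mean is (2000 + 200) / 2 = 1100, and with sample
    # weights (2000 * 0.6 + 200 * 0.4) / 2 = 640.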

    # test ce_loss with class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        reduction='mean',
        loss_weight=1.0,
        class_weight=class_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
    # test ce_loss with class weight and sample weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(208.))
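
    # Same margins scaled by the class weights:
    # (2000 * 0.3 + 200 * 0.7) / 2 = 370, and adding the sample weights
    # gives (2000 * 0.3 * 0.6 + 200 * 0.7 * 0.4) / 2 = 208.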

    # test bce_loss
    cls_score = torch.Tensor([[-200, 100], [500, -1000], [300, -300]])
    label = torch.Tensor([[1, 0], [0, 1], [1, 0]])
    weight = torch.Tensor([0.6, 0.4, 0.5])
    class_weight = [0.1, 0.9]  # class 0: 0.1, class 1: 0.9
    pos_weight = [0.1, 0.2]

    # test bce_loss without class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(300.))
    # test bce_loss with sample weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(130.))
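
    # The saturated per-element BCE values are roughly
    # [[200, 100], [500, 1000], [0, 0]]; their elementwise mean is
    # 1800 / 6 = 300, and the per-sample weights give
    # (300 * 0.6 + 1500 * 0.4 + 0 * 0.5) / 6 = 130.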

    # test bce_loss with class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0,
        class_weight=class_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(176.667))
    # test bce_loss with class weight and sample weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(74.333))
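
    # Scaling each column by its class weight gives per-sample sums of
    # about 110, 950 and 0, so the mean is 1060 / 6 = 176.667; with sample
    # weights, (110 * 0.6 + 950 * 0.4) / 6 = 74.333.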

    # test bce_loss with pos_weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0,
        pos_weight=pos_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(136.6667))
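
    # pos_weight scales only the positive elements: 200 * 0.1 + 100 = 120
    # for the first sample and 500 + 1000 * 0.2 = 700 for the second, so
    # the mean is 820 / 6 = 136.6667.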

    # test soft_ce_loss
    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
    label = torch.Tensor([[1.0, 0.0], [0.0, 1.0]])
    class_weight = [0.3, 0.7]  # class 0 : 0.3, class 1 : 0.7
    weight = torch.tensor([0.6, 0.4])

    # test soft_ce_loss without class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_soft=True,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
    # test soft_ce_loss with sample weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(640.))

    # test soft_ce_loss with class weight
    loss_cfg = dict(
        type='CrossEntropyLoss',
        use_soft=True,
        reduction='mean',
        loss_weight=1.0,
        class_weight=class_weight)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
    # test soft_ce_loss with class weight and sample weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(208.))
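
    # The soft labels here are exactly one-hot, so the expected values
    # match the hard-label cross entropy case above (1100, 640, 370, 208).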


def test_focal_loss():
    # test focal_loss
    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
    label = torch.Tensor([[1, 0, 1], [0, 1, 0]])
    weight = torch.tensor([0.5, 0.5])

    loss_cfg = dict(
        type='FocalLoss',
        gamma=2.0,
        alpha=0.25,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(loss(cls_score, label), torch.tensor(0.8522))
    # test focal_loss with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(0.8522 / 2))

    # test focal loss for single label task
    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
    label = torch.Tensor([0, 1])
    weight = torch.tensor([0.5, 0.5])
    assert torch.allclose(loss(cls_score, label), torch.tensor(0.86664125))
    # test focal_loss single label with weight
    assert torch.allclose(
        loss(cls_score, label, weight=weight), torch.tensor(0.86664125 / 2))
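
    # Where 0.8522 comes from (hand derivation): only the mislabeled second
    # row contributes significantly; with alpha=0.25 and gamma=2 the six
    # elementwise focal terms sum to about 5.113, and the elementwise mean
    # is 5.113 / 6 = 0.8522.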


def test_label_smooth_loss():
    # test label_smooth_val assertion
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='LabelSmoothLoss', label_smooth_val=1.0)
        build_loss(loss_cfg)

    with pytest.raises(AssertionError):
        loss_cfg = dict(type='LabelSmoothLoss', label_smooth_val='str')
        build_loss(loss_cfg)

    # test reduction assertion
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, reduction='unknown')
        build_loss(loss_cfg)

    # test mode assertion
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='unknown')
        build_loss(loss_cfg)

    # test original mode label smooth loss
    cls_score = torch.tensor([[1., -1.]])
    label = torch.tensor([0])

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    correct = 0.2269  # from timm
    assert abs(loss(cls_score, label) - correct) <= 0.0001
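
    # Cross-check of 0.2269: the smoothed target is [0.95, 0.05] and the
    # log-softmax of [1, -1] is [-0.1269, -2.1269], so the loss is
    # 0.95 * 0.1269 + 0.05 * 2.1269 = 0.2269.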

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        use_sigmoid=True,
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    correct = 0.3633  # from timm
    assert abs(loss(cls_score, label) - correct) <= 0.0001

    # test classy_vision mode label smooth loss
    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='classy_vision',
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    correct = 0.2178  # from ClassyVision
    assert abs(loss(cls_score, label) - correct) <= 0.0001

    # test multi_label mode label smooth loss
    cls_score = torch.tensor([[1., -1., 1]])
    label = torch.tensor([[1, 0, 1]])

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='multi_label',
        reduction='mean',
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    smooth_label = torch.tensor([[0.9, 0.1, 0.9]])
    correct = torch.nn.functional.binary_cross_entropy_with_logits(
        cls_score, smooth_label).mean()
    assert torch.allclose(loss(cls_score, label), correct)
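
    # multi_label mode smooths each binary target toward 0.1 / 0.9, so the
    # reference value is simply BCE-with-logits against the smoothed
    # targets.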

    # test label smooth loss with linearly combined (mixup-style) labels
    cls_score = torch.tensor([[1., -1., 0.]])
    label1 = torch.tensor([[1., 0., 0.]])
    label2 = torch.tensor([[0., 0., 1.]])
    label_mix = label1 * 0.6 + label2 * 0.4

    loss_cfg = dict(
        type='LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        reduction='mean',
        num_classes=3,
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    smooth_label1 = loss.original_smooth_label(label1)
    smooth_label2 = loss.original_smooth_label(label2)
    label_smooth_mix = smooth_label1 * 0.6 + smooth_label2 * 0.4
    correct = (-torch.log_softmax(cls_score, -1) * label_smooth_mix).sum()

    assert abs(loss(cls_score, label_mix) - correct) <= 0.0001

    # test label smooth loss with sample weight
    cls_score = torch.tensor([[1., -1.], [1., -1.]])
    label = torch.tensor([0, 1])
    weight = torch.tensor([0.5, 0.5])

    loss_cfg = dict(
        type='LabelSmoothLoss',
        reduction='mean',
        label_smooth_val=0.1,
        loss_weight=1.0)
    loss = build_loss(loss_cfg)
    assert torch.allclose(
        loss(cls_score, label, weight=weight),
        loss(cls_score, label) / 2)


# migrated from mmdetection with modifications
def test_seesaw_loss():
    # only the softmax version of Seesaw Loss is implemented
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='SeesawLoss', use_sigmoid=True, loss_weight=1.0)
        build_loss(loss_cfg)

    # test that cls_score.size(-1) == num_classes
    loss_cls_cfg = dict(
        type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    # the length of fake_pred should be num_classes = 2, not 3
    with pytest.raises(AssertionError):
        fake_pred = torch.Tensor([[-100, 100, -100]])
        fake_label = torch.Tensor([1]).long()
        loss_cls(fake_pred, fake_label)
    # the length of fake_pred should be num_classes = 2, not 4
    with pytest.raises(AssertionError):
        fake_pred = torch.Tensor([[-100, 100, -100, 100]])
        fake_label = torch.Tensor([1]).long()
        loss_cls(fake_pred, fake_label)

    # test the calculation without p and q
    loss_cls_cfg = dict(
        type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[-100, 100]])
    fake_label = torch.Tensor([1]).long()
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(0.))

    # test the calculation with p and without q
    loss_cls_cfg = dict(
        type='SeesawLoss', p=1.0, q=0.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[-100, 100]])
    fake_label = torch.Tensor([0]).long()
    loss_cls.cum_samples[0] = torch.exp(torch.Tensor([20]))
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(180.))
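
    # Sketch of the 180, assuming the standard Seesaw mitigation factor
    # (N_j / N_i) ** p clamped to at most 1: with cum_samples[0] = e**20
    # the logit of the rare negative class is shifted by
    # log(e**-20) = -20, so the cross entropy is about
    # (100 - 20) - (-100) = 180.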

    # test the calculation with q and without p
    loss_cls_cfg = dict(
        type='SeesawLoss', p=0.0, q=1.0, loss_weight=1.0, num_classes=2)
    loss_cls = build_loss(loss_cls_cfg)
    fake_pred = torch.Tensor([[-100, 100]])
    fake_label = torch.Tensor([0]).long()
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(200.) + torch.tensor(100.).log())
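
    # Sketch of the expected value, assuming the compensation factor
    # (score_j / score_i) ** q with the true-class score clamped at
    # eps = 1e-2: the true-class softmax score underflows to 0 and is
    # clamped, giving a factor of 1 / 0.01 = 100, so the loss is the
    # 200 logit margin plus log(100).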


def test_reconstruction_loss():
    # test L2 loss
    loss_config = dict(type='PixelReconstructionLoss', criterion='L2')
    loss = build_loss(loss_config)

    fake_pred = torch.rand((2, 196, 768))
    fake_target = torch.rand((2, 196, 768))
    fake_mask = torch.ones((2, 196))
    loss_value = loss(fake_pred, fake_target, fake_mask)

    assert isinstance(loss_value.item(), float)
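
    # The (2, 196, 768) shapes mimic MAE-style patch reconstruction
    # (196 patches of 768 values each), with the mask selecting which
    # patches contribute to the averaged loss.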

    # test L1 loss
    loss_config = dict(
        type='PixelReconstructionLoss', criterion='L1', channel=3)
    loss = build_loss(loss_config)

    fake_pred = torch.rand((2, 3, 192, 192))
    fake_target = torch.rand((2, 3, 192, 192))
    fake_mask = torch.ones((2, 1, 192, 192))
    loss_value = loss(fake_pred, fake_target, fake_mask)

    assert isinstance(loss_value.item(), float)

    with pytest.raises(NotImplementedError):
        loss_config = dict(type='PixelReconstructionLoss', criterion='L3')
        loss = build_loss(loss_config)