mmsegmentation/tests/test_models/test_losses/test_lovasz_loss.py

63 lines
2.0 KiB
Python

import pytest
import torch
def test_lovasz_loss():
    """Smoke-test ``LovaszLoss`` construction and forward calls.

    Checks that invalid configurations are rejected by ``build_loss`` with an
    ``AssertionError`` and that the loss can be called for every combination
    of ``loss_type`` ('multi_class' / 'binary') and ``per_image``
    (False / True) used below.
    """
    from mmseg.models import build_loss

    # loss_type should be 'binary' or 'multi_class' (the check is
    # case-sensitive, so 'Binary' must be rejected).
    # The config dict is built outside the ``raises`` context so that only
    # ``build_loss`` itself can satisfy the expected AssertionError.
    bad_type_cfg = dict(
        type='LovaszLoss',
        loss_type='Binary',
        reduction='none',
        loss_weight=1.0)
    with pytest.raises(AssertionError):
        build_loss(bad_type_cfg)

    # reduction should be 'none' when per_image is False.
    bad_reduction_cfg = dict(type='LovaszLoss', loss_type='multi_class')
    with pytest.raises(AssertionError):
        build_loss(bad_reduction_cfg)

    # test lovasz loss with loss_type = 'multi_class' and per_image = False
    loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    # Draw labels from all three classes; ``(torch.rand(...) * 2).long()``
    # could only ever produce 0 or 1, leaving class 2 untested.
    labels = torch.randint(0, 3, (1, 4, 4))
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'multi_class' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = torch.randint(0, 3, (1, 4, 4))
    lovasz_loss(logits, labels, ignore_index=None)

    # test lovasz loss with loss_type = 'binary' and per_image = False
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        reduction='none',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    # ``torch.rand(...).long()`` truncates values in [0, 1) to 0, so the
    # labels were always all-zero; sample real {0, 1} labels instead.
    labels = torch.randint(0, 2, (2, 4, 4))
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'binary' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = torch.randint(0, 2, (2, 4, 4))
    lovasz_loss(logits, labels, ignore_index=None)