mmsegmentation/tests/test_models/test_losses/test_dice_loss.py

31 lines
793 B
Python

import torch
def test_dice_lose():
from mmseg.models import build_loss
# test dice loss with loss_type = 'multi_class'
loss_cfg = dict(
type='DiceLoss',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)
# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 2, 4, 4)
labels = (torch.rand(8, 4, 4) * 2).long()
dice_loss(logits, labels)