31 lines
793 B
Python
31 lines
793 B
Python
import torch
|
|
|
|
|
|
def test_dice_lose():
|
|
from mmseg.models import build_loss
|
|
|
|
# test dice loss with loss_type = 'multi_class'
|
|
loss_cfg = dict(
|
|
type='DiceLoss',
|
|
reduction='none',
|
|
class_weight=[1.0, 2.0, 3.0],
|
|
loss_weight=1.0,
|
|
ignore_index=1)
|
|
dice_loss = build_loss(loss_cfg)
|
|
logits = torch.rand(8, 3, 4, 4)
|
|
labels = (torch.rand(8, 4, 4) * 3).long()
|
|
dice_loss(logits, labels)
|
|
|
|
# test dice loss with loss_type = 'binary'
|
|
loss_cfg = dict(
|
|
type='DiceLoss',
|
|
smooth=2,
|
|
exponent=3,
|
|
reduction='sum',
|
|
loss_weight=1.0,
|
|
ignore_index=0)
|
|
dice_loss = build_loss(loss_cfg)
|
|
logits = torch.rand(8, 2, 4, 4)
|
|
labels = (torch.rand(8, 4, 4) * 2).long()
|
|
dice_loss(logits, labels)
|