[Feature] Support multiple losses during training (#818)

* multiple losses

* fix lint error

* fix typos

* fix typos

* Adding Attribute

* Fixing loss_ prefix

* Fixing loss_ prefix

* Fixing loss_ prefix

* Add Same

* loss_name must has 'loss_' prefix

* Fix unittest

* Fix unittest

* Fix unittest

* Update mmseg/models/decode_heads/decode_head.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in:
MengzhangLI 2021-09-24 15:08:28 +08:00 committed by GitHub
parent e13076adef
commit 0fd3972c41
12 changed files with 297 additions and 33 deletions

View File

@ -50,3 +50,21 @@ model=dict(
``` ```
`class_weight` will be passed into `CrossEntropyLoss` as `weight` argument. Please refer to [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details. `class_weight` will be passed into `CrossEntropyLoss` as `weight` argument. Please refer to [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.
## Multiple Losses
For loss calculation, we support multiple losses training concurrently. Here is an example config of training `unet` on `DRIVE` dataset, whose loss function is `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`:
```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
```
In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.
Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.

View File

@ -49,3 +49,22 @@ model=dict(
``` ```
`class_weight` 将被作为 `weight` 参数,传递给 `CrossEntropyLoss`。详细信息请参照 [PyTorch 文档](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) 。 `class_weight` 将被作为 `weight` 参数,传递给 `CrossEntropyLoss`。详细信息请参照 [PyTorch 文档](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) 。
## 同时使用多种损失函数 (Multiple Losses)
对于训练时损失函数的计算,我们目前支持多个损失函数同时使用。 以 `unet` 使用 `DRIVE` 数据集训练为例,
使用 `CrossEntropyLoss``DiceLoss``1:3` 的加权和作为损失函数。配置文件写为:
```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
```
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`
注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。

View File

@ -62,12 +62,14 @@ class OHEMPixelSampler(BasePixelSampler):
threshold = max(min_threshold, self.thresh) threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else: else:
losses = self.context.loss_decode( losses = 0.0
seg_logit, for loss_module in self.context.loss_decode:
seg_label, losses += loss_module(
weight=None, seg_logit,
ignore_index=self.context.ignore_index, seg_label,
reduction_override='none') weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True) _, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1. valid_seg_weight[sort_indices[:batch_kept]] = 1.

View File

@ -33,10 +33,17 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
a list and passed into decode head. a list and passed into decode head.
None: Only one select feature map is allowed. None: Only one select feature map is allowed.
Default: None. Default: None.
loss_decode (dict): Config of decode loss. loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss'). Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255 masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler. sampler (dict|None): The config of segmentation map sampler.
Default: None. Default: None.
align_corners (bool): align_corners argument of F.interpolate. align_corners (bool): align_corners argument of F.interpolate.
@ -73,9 +80,20 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.act_cfg = act_cfg self.act_cfg = act_cfg
self.in_index = in_index self.in_index = in_index
self.loss_decode = build_loss(loss_decode)
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.align_corners = align_corners self.align_corners = align_corners
self.loss_decode = nn.ModuleList()
if isinstance(loss_decode, dict):
self.loss_decode.append(build_loss(loss_decode))
elif isinstance(loss_decode, (list, tuple)):
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')
if sampler is not None: if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self) self.sampler = build_pixel_sampler(sampler, context=self)
else: else:
@ -224,10 +242,19 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
else: else:
seg_weight = None seg_weight = None
seg_label = seg_label.squeeze(1) seg_label = seg_label.squeeze(1)
loss['loss_seg'] = self.loss_decode( for loss_decode in self.loss_decode:
seg_logit, if loss_decode.loss_name not in loss:
seg_label, loss[loss_decode.loss_name] = loss_decode(
weight=seg_weight, seg_logit,
ignore_index=self.ignore_index) seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
loss['acc_seg'] = accuracy(seg_logit, seg_label) loss['acc_seg'] = accuracy(seg_logit, seg_label)
return loss return loss

View File

@ -249,8 +249,9 @@ class PointHead(BaseCascadeDecodeHead):
def losses(self, point_logits, point_label): def losses(self, point_logits, point_label):
"""Compute segmentation loss.""" """Compute segmentation loss."""
loss = dict() loss = dict()
loss['loss_point'] = self.loss_decode( for loss_module in self.loss_decode:
point_logits, point_label, ignore_index=self.ignore_index) loss['point' + loss_module.loss_name] = loss_module(
point_logits, point_label, ignore_index=self.ignore_index)
loss['acc_point'] = accuracy(point_logits, point_label) loss['acc_point'] = accuracy(point_logits, point_label)
return loss return loss

View File

@ -150,6 +150,9 @@ class CrossEntropyLoss(nn.Module):
class_weight (list[float] | str, optional): Weight of each class. If in class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None. str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0. loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
""" """
def __init__(self, def __init__(self,
@ -157,7 +160,8 @@ class CrossEntropyLoss(nn.Module):
use_mask=False, use_mask=False,
reduction='mean', reduction='mean',
class_weight=None, class_weight=None,
loss_weight=1.0): loss_weight=1.0,
loss_name='loss_ce'):
super(CrossEntropyLoss, self).__init__() super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False) assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid self.use_sigmoid = use_sigmoid
@ -172,6 +176,7 @@ class CrossEntropyLoss(nn.Module):
self.cls_criterion = mask_cross_entropy self.cls_criterion = mask_cross_entropy
else: else:
self.cls_criterion = cross_entropy self.cls_criterion = cross_entropy
self._loss_name = loss_name
def forward(self, def forward(self,
cls_score, cls_score,
@ -197,3 +202,17 @@ class CrossEntropyLoss(nn.Module):
avg_factor=avg_factor, avg_factor=avg_factor,
**kwargs) **kwargs)
return loss_cls return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name

View File

@ -68,6 +68,9 @@ class DiceLoss(nn.Module):
str format, read them from a file. Defaults to None. str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Default to 1.0. loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255. ignore_index (int | None): The label index to be ignored. Default: 255.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_dice'.
""" """
def __init__(self, def __init__(self,
@ -77,6 +80,7 @@ class DiceLoss(nn.Module):
class_weight=None, class_weight=None,
loss_weight=1.0, loss_weight=1.0,
ignore_index=255, ignore_index=255,
loss_name='loss_dice',
**kwards): **kwards):
super(DiceLoss, self).__init__() super(DiceLoss, self).__init__()
self.smooth = smooth self.smooth = smooth
@ -85,6 +89,7 @@ class DiceLoss(nn.Module):
self.class_weight = get_class_weight(class_weight) self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight self.loss_weight = loss_weight
self.ignore_index = ignore_index self.ignore_index = ignore_index
self._loss_name = loss_name
def forward(self, def forward(self,
pred, pred,
@ -118,3 +123,17 @@ class DiceLoss(nn.Module):
class_weight=class_weight, class_weight=class_weight,
ignore_index=self.ignore_index) ignore_index=self.ignore_index)
return loss return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name

View File

@ -244,6 +244,9 @@ class LovaszLoss(nn.Module):
class_weight (list[float] | str, optional): Weight of each class. If in class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None. str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0. loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_lovasz'.
""" """
def __init__(self, def __init__(self,
@ -252,7 +255,8 @@ class LovaszLoss(nn.Module):
per_image=False, per_image=False,
reduction='mean', reduction='mean',
class_weight=None, class_weight=None,
loss_weight=1.0): loss_weight=1.0,
loss_name='loss_lovasz'):
super(LovaszLoss, self).__init__() super(LovaszLoss, self).__init__()
assert loss_type in ('binary', 'multi_class'), "loss_type should be \ assert loss_type in ('binary', 'multi_class'), "loss_type should be \
'binary' or 'multi_class'." 'binary' or 'multi_class'."
@ -271,6 +275,7 @@ class LovaszLoss(nn.Module):
self.reduction = reduction self.reduction = reduction
self.loss_weight = loss_weight self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight) self.class_weight = get_class_weight(class_weight)
self._loss_name = loss_name
def forward(self, def forward(self,
cls_score, cls_score,
@ -302,3 +307,17 @@ class LovaszLoss(nn.Module):
avg_factor=avg_factor, avg_factor=avg_factor,
**kwargs) **kwargs)
return loss_cls return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name

View File

@ -74,3 +74,92 @@ def test_decode_head():
assert head.input_transform == 'resize_concat' assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs) transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 48, 45, 45) assert transformed_inputs.shape == (1, 48, 45, 45)
# test multi-loss, loss_decode is dict
with pytest.raises(TypeError):
# loss_decode must be a dict or sequence of dict.
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)
# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss
# test multi-loss, loss_decode is list of dict, names of them are identical
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss_3 = head.losses(seg_logit=inputs, seg_label=target)
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']

View File

@ -20,7 +20,8 @@ def test_ce_loss():
type='CrossEntropyLoss', type='CrossEntropyLoss',
use_sigmoid=False, use_sigmoid=False,
class_weight=[0.8, 0.2], class_weight=[0.8, 0.2],
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg) loss_cls = build_loss(loss_cls_cfg)
fake_pred = torch.Tensor([[100, -100]]) fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long() fake_label = torch.Tensor([1]).long()
@ -38,7 +39,8 @@ def test_ce_loss():
type='CrossEntropyLoss', type='CrossEntropyLoss',
use_sigmoid=False, use_sigmoid=False,
class_weight=f'{tmp_file.name}.pkl', class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg) loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
@ -47,7 +49,8 @@ def test_ce_loss():
type='CrossEntropyLoss', type='CrossEntropyLoss',
use_sigmoid=False, use_sigmoid=False,
class_weight=f'{tmp_file.name}.npy', class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg) loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
tmp_file.close() tmp_file.close()
@ -74,4 +77,12 @@ def test_ce_loss():
torch.tensor(0.9354), torch.tensor(0.9354),
atol=1e-4) atol=1e-4)
# test cross entropy loss has name `loss_ce`
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
assert loss_cls.loss_name == 'loss_ce'
# TODO test use_mask # TODO test use_mask

View File

@ -11,7 +11,8 @@ def test_dice_lose():
reduction='none', reduction='none',
class_weight=[1.0, 2.0, 3.0], class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0, loss_weight=1.0,
ignore_index=1) ignore_index=1,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg) dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4) logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long() labels = (torch.rand(8, 4, 4) * 3).long()
@ -30,7 +31,8 @@ def test_dice_lose():
reduction='none', reduction='none',
class_weight=f'{tmp_file.name}.pkl', class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0, loss_weight=1.0,
ignore_index=1) ignore_index=1,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg) dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None) dice_loss(logits, labels, ignore_index=None)
@ -40,7 +42,8 @@ def test_dice_lose():
reduction='none', reduction='none',
class_weight=f'{tmp_file.name}.pkl', class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0, loss_weight=1.0,
ignore_index=1) ignore_index=1,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg) dice_loss = build_loss(loss_cfg)
dice_loss(logits, labels, ignore_index=None) dice_loss(logits, labels, ignore_index=None)
tmp_file.close() tmp_file.close()
@ -54,8 +57,21 @@ def test_dice_lose():
exponent=3, exponent=3,
reduction='sum', reduction='sum',
loss_weight=1.0, loss_weight=1.0,
ignore_index=0) ignore_index=0,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg) dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 2, 4, 4) logits = torch.rand(8, 2, 4, 4)
labels = (torch.rand(8, 4, 4) * 2).long() labels = (torch.rand(8, 4, 4) * 2).long()
dice_loss(logits, labels) dice_loss(logits, labels)
# test dice loss has name `loss_dice`
loss_cfg = dict(
type='DiceLoss',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0,
loss_name='loss_dice')
dice_loss = build_loss(loss_cfg)
assert dice_loss.loss_name == 'loss_dice'

View File

@ -12,16 +12,24 @@ def test_lovasz_loss():
type='LovaszLoss', type='LovaszLoss',
loss_type='Binary', loss_type='Binary',
reduction='none', reduction='none',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_lovasz')
build_loss(loss_cfg) build_loss(loss_cfg)
# reduction should be 'none' when per_image is False. # reduction should be 'none' when per_image is False.
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
loss_cfg = dict(type='LovaszLoss', loss_type='multi_class') loss_cfg = dict(
type='LovaszLoss',
loss_type='multi_class',
loss_name='loss_lovasz')
build_loss(loss_cfg) build_loss(loss_cfg)
# test lovasz loss with loss_type = 'multi_class' and per_image = False # test lovasz loss with loss_type = 'multi_class' and per_image = False
loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0) loss_cfg = dict(
type='LovaszLoss',
reduction='none',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg) lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(1, 3, 4, 4) logits = torch.rand(1, 3, 4, 4)
labels = (torch.rand(1, 4, 4) * 2).long() labels = (torch.rand(1, 4, 4) * 2).long()
@ -33,7 +41,8 @@ def test_lovasz_loss():
per_image=True, per_image=True,
reduction='mean', reduction='mean',
class_weight=[1.0, 2.0, 3.0], class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg) lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(1, 3, 4, 4) logits = torch.rand(1, 3, 4, 4)
labels = (torch.rand(1, 4, 4) * 2).long() labels = (torch.rand(1, 4, 4) * 2).long()
@ -52,7 +61,8 @@ def test_lovasz_loss():
per_image=True, per_image=True,
reduction='mean', reduction='mean',
class_weight=f'{tmp_file.name}.pkl', class_weight=f'{tmp_file.name}.pkl',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg) lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None) lovasz_loss(logits, labels, ignore_index=None)
@ -62,7 +72,8 @@ def test_lovasz_loss():
per_image=True, per_image=True,
reduction='mean', reduction='mean',
class_weight=f'{tmp_file.name}.npy', class_weight=f'{tmp_file.name}.npy',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg) lovasz_loss = build_loss(loss_cfg)
lovasz_loss(logits, labels, ignore_index=None) lovasz_loss(logits, labels, ignore_index=None)
tmp_file.close() tmp_file.close()
@ -74,7 +85,8 @@ def test_lovasz_loss():
type='LovaszLoss', type='LovaszLoss',
loss_type='binary', loss_type='binary',
reduction='none', reduction='none',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg) lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(2, 4, 4) logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long() labels = (torch.rand(2, 4, 4)).long()
@ -86,8 +98,20 @@ def test_lovasz_loss():
loss_type='binary', loss_type='binary',
per_image=True, per_image=True,
reduction='mean', reduction='mean',
loss_weight=1.0) loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg) lovasz_loss = build_loss(loss_cfg)
logits = torch.rand(2, 4, 4) logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long() labels = (torch.rand(2, 4, 4)).long()
lovasz_loss(logits, labels, ignore_index=None) lovasz_loss(logits, labels, ignore_index=None)
# test lovasz loss has name `loss_lovasz`
loss_cfg = dict(
type='LovaszLoss',
loss_type='binary',
per_image=True,
reduction='mean',
loss_weight=1.0,
loss_name='loss_lovasz')
lovasz_loss = build_loss(loss_cfg)
assert lovasz_loss.loss_name == 'loss_lovasz'