mirror of https://github.com/open-mmlab/mmsegmentation.git
[Feature] Support multiple losses during training (#818)
* multiple losses
* fix lint error
* fix typos
* fix typos
* Adding Attribute
* Fixing loss_ prefix
* Fixing loss_ prefix
* Fixing loss_ prefix
* Add Same
* loss_name must have 'loss_' prefix
* Fix unittest
* Fix unittest
* Fix unittest
* Update mmseg/models/decode_heads/decode_head.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
This commit is contained in: parent e13076adef, commit 0fd3972c41
@@ -50,3 +50,21 @@ model=dict(

`class_weight` will be passed into `CrossEntropyLoss` as the `weight` argument. Please refer to the [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.

## Multiple Losses

For loss calculation, we support training with multiple losses concurrently. Here is an example config for training `unet` on the `DRIVE` dataset, whose loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)
```

In this way, `loss_weight` and `loss_name` will be the weight and the name of the corresponding loss in the training log, respectively.

Note: if you want a loss item to be included in the backward graph, its name must carry the `loss_` prefix.
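The `loss_` prefix is what connects a named loss item to optimization: when the segmentor parses the dict returned by `losses()`, every item whose key carries the prefix is summed into the total training loss, while other items (such as `acc_seg`) are only logged. A minimal sketch of that summation, with a made-up `log_vars` dict and illustrative values:

```python
# Sketch of how named loss items combine into the total loss; the dict and
# its values are illustrative, not taken from a real training run.
log_vars = {
    'loss_ce': 0.75,    # summed: carries the 'loss_' prefix
    'loss_dice': 0.25,  # summed: carries the 'loss_' prefix
    'acc_seg': 0.92,    # logged only: no 'loss_' prefix
}

total_loss = sum(v for k, v in log_vars.items() if k.startswith('loss_'))
print(total_loss)  # 1.0 -- only the prefixed items reach backward()
```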
@@ -49,3 +49,22 @@ model=dict(

`class_weight` will be passed to `CrossEntropyLoss` as its `weight` argument. Please refer to the [PyTorch documentation](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.

## Multiple Losses

For loss calculation during training, we currently support using multiple losses at the same time. Taking training `unet` on the `DRIVE` dataset as an example, the loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`, and the config is written as:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)
```

In this way, `loss_weight` determines the weight of each loss during training, and `loss_name` the name it is logged under in the training log.

Note: `loss_name` must carry the `loss_` prefix so that the corresponding loss item is included in the backward graph.
```diff
@@ -62,12 +62,14 @@ class OHEMPixelSampler(BasePixelSampler):
                 threshold = max(min_threshold, self.thresh)
                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
             else:
-                losses = self.context.loss_decode(
-                    seg_logit,
-                    seg_label,
-                    weight=None,
-                    ignore_index=self.context.ignore_index,
-                    reduction_override='none')
+                losses = 0.0
+                for loss_module in self.context.loss_decode:
+                    losses += loss_module(
+                        seg_logit,
+                        seg_label,
+                        weight=None,
+                        ignore_index=self.context.ignore_index,
+                        reduction_override='none')
                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                 _, sort_indices = losses[valid_mask].sort(descending=True)
                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
```
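With multiple losses configured, OHEM ranks pixels by the elementwise sum of the per-pixel loss maps: `reduction_override='none'` keeps one unreduced map per loss, and the loop accumulates them before sorting. A small self-contained sketch of that accumulation and hard-pixel selection, with random tensors standing in for real loss maps (names, shapes, and `batch_kept` here are illustrative):

```python
import torch

# Two stand-ins for unreduced per-pixel loss maps (reduction_override='none'),
# e.g. one from cross-entropy and one from dice; values are random.
loss_maps = [torch.rand(2, 64, 64), torch.rand(2, 64, 64)]

# Elementwise accumulation, as in the loop over self.context.loss_decode
losses = 0.0
for loss_map in loss_maps:
    losses = losses + loss_map

# Rank valid pixels by accumulated loss and keep the hardest `batch_kept`
valid_mask = torch.ones(2, 64, 64, dtype=torch.bool)  # pretend all pixels valid
batch_kept = 1024
valid_seg_weight = torch.zeros(int(valid_mask.sum()))
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.  # weight 1 for hard pixels
```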
```diff
@@ -33,10 +33,17 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
             a list and passed into decode head.
             None: Only one select feature map is allowed.
             Default: None.
-        loss_decode (dict): Config of decode loss.
+        loss_decode (dict | Sequence[dict]): Config of decode loss.
+            The `loss_name` is property of corresponding loss function which
+            could be shown in training log. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ce'.
+            e.g. dict(type='CrossEntropyLoss'),
+            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+             dict(type='DiceLoss', loss_name='loss_dice')]
             Default: dict(type='CrossEntropyLoss').
         ignore_index (int | None): The label index to be ignored. When using
-            masked BCE loss, ignore_index should be set to None. Default: 255
+            masked BCE loss, ignore_index should be set to None. Default: 255.
         sampler (dict|None): The config of segmentation map sampler.
             Default: None.
         align_corners (bool): align_corners argument of F.interpolate.
```
```diff
@@ -73,9 +80,20 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         self.norm_cfg = norm_cfg
         self.act_cfg = act_cfg
         self.in_index = in_index
-        self.loss_decode = build_loss(loss_decode)
+
         self.ignore_index = ignore_index
         self.align_corners = align_corners
+        self.loss_decode = nn.ModuleList()
+
+        if isinstance(loss_decode, dict):
+            self.loss_decode.append(build_loss(loss_decode))
+        elif isinstance(loss_decode, (list, tuple)):
+            for loss in loss_decode:
+                self.loss_decode.append(build_loss(loss))
+        else:
+            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
+                but got {type(loss_decode)}')
+
         if sampler is not None:
             self.sampler = build_pixel_sampler(sampler, context=self)
         else:
```
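A side note on the container choice: keeping the losses in an `nn.ModuleList` rather than a plain Python list registers each loss as a submodule, so any parameters or buffers a loss carries follow the head across `.to()`/`.cuda()` calls and appear in `state_dict()`. A minimal sketch with a hypothetical loss that holds a class-weight buffer:

```python
import torch
import torch.nn as nn

class WeightedToyLoss(nn.Module):
    """Hypothetical loss holding a class-weight buffer."""

    def __init__(self, class_weight):
        super().__init__()
        self.register_buffer('class_weight', torch.tensor(class_weight))

head_like = nn.Module()
head_like.loss_decode = nn.ModuleList(
    [WeightedToyLoss([0.8, 0.2]), WeightedToyLoss([1.0, 3.0])])

# The buffers are tracked by the parent module and move with it on .to():
print(list(head_like.state_dict().keys()))
# ['loss_decode.0.class_weight', 'loss_decode.1.class_weight']
```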
```diff
@@ -224,10 +242,19 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
-        loss['loss_seg'] = self.loss_decode(
-            seg_logit,
-            seg_label,
-            weight=seg_weight,
-            ignore_index=self.ignore_index)
+        for loss_decode in self.loss_decode:
+            if loss_decode.loss_name not in loss:
+                loss[loss_decode.loss_name] = loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=seg_weight,
+                    ignore_index=self.ignore_index)
+            else:
+                loss[loss_decode.loss_name] += loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=seg_weight,
+                    ignore_index=self.ignore_index)
+
         loss['acc_seg'] = accuracy(seg_logit, seg_label)
         return loss
```
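The `if`/`else` above means that several configured losses sharing one `loss_name` are accumulated under a single log key, which is why the unit test further down expects `loss_3['loss_ce'] == 3 * loss['loss_ce']`. A toy illustration of the branch, with made-up values:

```python
# Toy illustration of the name-keyed accumulation; the values are made up.
loss = {}
for loss_name, value in [('loss_ce', 0.5), ('loss_ce', 0.5), ('loss_dice', 0.25)]:
    if loss_name not in loss:
        loss[loss_name] = value
    else:
        loss[loss_name] += value
print(loss)  # {'loss_ce': 1.0, 'loss_dice': 0.25}
```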
```diff
@@ -249,8 +249,9 @@ class PointHead(BaseCascadeDecodeHead):
     def losses(self, point_logits, point_label):
         """Compute segmentation loss."""
         loss = dict()
-        loss['loss_point'] = self.loss_decode(
-            point_logits, point_label, ignore_index=self.ignore_index)
+        for loss_module in self.loss_decode:
+            loss['point' + loss_module.loss_name] = loss_module(
+                point_logits, point_label, ignore_index=self.ignore_index)
         loss['acc_point'] = accuracy(point_logits, point_label)
         return loss
```
```diff
@@ -150,6 +150,9 @@ class CrossEntropyLoss(nn.Module):
         class_weight (list[float] | str, optional): Weight of each class. If in
             str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ce'.
     """

     def __init__(self,
@@ -157,7 +160,8 @@ class CrossEntropyLoss(nn.Module):
                  use_mask=False,
                  reduction='mean',
                  class_weight=None,
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 loss_name='loss_ce'):
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
@@ -172,6 +176,7 @@ class CrossEntropyLoss(nn.Module):
             self.cls_criterion = mask_cross_entropy
         else:
             self.cls_criterion = cross_entropy
+        self._loss_name = loss_name

     def forward(self,
                 cls_score,
@@ -197,3 +202,17 @@ class CrossEntropyLoss(nn.Module):
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
```
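The `loss_name` property shown above is the contract a loss must satisfy to participate in multi-loss training; `DiceLoss` and `LovaszLoss` below follow the same pattern. A sketch of the minimal skeleton for a custom loss (`MyLoss` is a hypothetical example, and it would still need to be registered in mmseg's `LOSSES` registry to be buildable from configs):

```python
import torch.nn as nn

class MyLoss(nn.Module):
    """Hypothetical custom loss following the loss_name pattern."""

    def __init__(self, loss_weight=1.0, loss_name='loss_my'):
        super().__init__()
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self, cls_score, label, **kwargs):
        # Placeholder objective; a real implementation goes here.
        return self.loss_weight * (cls_score - label.float()).abs().mean()

    @property
    def loss_name(self):
        # Keep the 'loss_' prefix so the item is summed into the total loss.
        return self._loss_name
```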
```diff
@@ -68,6 +68,9 @@ class DiceLoss(nn.Module):
             str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Default to 1.0.
         ignore_index (int | None): The label index to be ignored. Default: 255.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_dice'.
     """

     def __init__(self,
@@ -77,6 +80,7 @@ class DiceLoss(nn.Module):
                  class_weight=None,
                  loss_weight=1.0,
                  ignore_index=255,
+                 loss_name='loss_dice',
                  **kwards):
         super(DiceLoss, self).__init__()
         self.smooth = smooth
@@ -85,6 +89,7 @@ class DiceLoss(nn.Module):
         self.class_weight = get_class_weight(class_weight)
         self.loss_weight = loss_weight
         self.ignore_index = ignore_index
+        self._loss_name = loss_name

     def forward(self,
                 pred,
@@ -118,3 +123,17 @@ class DiceLoss(nn.Module):
             class_weight=class_weight,
             ignore_index=self.ignore_index)
         return loss
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
```
```diff
@@ -244,6 +244,9 @@ class LovaszLoss(nn.Module):
         class_weight (list[float] | str, optional): Weight of each class. If in
             str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_lovasz'.
     """

     def __init__(self,
@@ -252,7 +255,8 @@ class LovaszLoss(nn.Module):
                  per_image=False,
                  reduction='mean',
                  class_weight=None,
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 loss_name='loss_lovasz'):
         super(LovaszLoss, self).__init__()
         assert loss_type in ('binary', 'multi_class'), "loss_type should be \
             'binary' or 'multi_class'."
@@ -271,6 +275,7 @@ class LovaszLoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
         self.class_weight = get_class_weight(class_weight)
+        self._loss_name = loss_name

     def forward(self,
                 cls_score,
@@ -302,3 +307,17 @@ class LovaszLoss(nn.Module):
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
```
```diff
@@ -74,3 +74,92 @@ def test_decode_head():
     assert head.input_transform == 'resize_concat'
     transformed_inputs = head._transform_inputs(inputs)
     assert transformed_inputs.shape == (1, 48, 45, 45)
+
+    # test multi-loss, loss_decode is dict
+    with pytest.raises(TypeError):
+        # loss_decode must be a dict or sequence of dict.
+        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
+
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_ce' in loss
+
+    # test multi-loss, loss_decode is list of dict
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_1' in loss
+    assert 'loss_2' in loss
+
+    # 'loss_decode' must be a dict or sequence of dict
+    with pytest.raises(TypeError):
+        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
+    with pytest.raises(TypeError):
+        BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)
+
+    # test multi-loss, loss_decode is list of dict
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_2'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_3')))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_1' in loss
+    assert 'loss_2' in loss
+    assert 'loss_3' in loss
+
+    # test multi-loss, loss_decode is list of dict, names of them are identical
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_ce')))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss_3 = head.losses(seg_logit=inputs, seg_label=target)
+
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_ce' in loss
+    assert 'loss_ce' in loss_3
+    assert loss_3['loss_ce'] == 3 * loss['loss_ce']
```
```diff
@@ -20,7 +20,8 @@ def test_ce_loss():
         type='CrossEntropyLoss',
         use_sigmoid=False,
         class_weight=[0.8, 0.2],
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_ce')
     loss_cls = build_loss(loss_cls_cfg)
     fake_pred = torch.Tensor([[100, -100]])
     fake_label = torch.Tensor([1]).long()
@@ -38,7 +39,8 @@ def test_ce_loss():
         type='CrossEntropyLoss',
         use_sigmoid=False,
         class_weight=f'{tmp_file.name}.pkl',
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_ce')
     loss_cls = build_loss(loss_cls_cfg)
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

@@ -47,7 +49,8 @@ def test_ce_loss():
         type='CrossEntropyLoss',
         use_sigmoid=False,
         class_weight=f'{tmp_file.name}.npy',
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_ce')
     loss_cls = build_loss(loss_cls_cfg)
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
     tmp_file.close()
@@ -74,4 +77,12 @@ def test_ce_loss():
         torch.tensor(0.9354),
         atol=1e-4)

+    # test cross entropy loss has name `loss_ce`
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=False,
+        loss_weight=1.0,
+        loss_name='loss_ce')
+    loss_cls = build_loss(loss_cls_cfg)
+    assert loss_cls.loss_name == 'loss_ce'
     # TODO test use_mask
```
```diff
@@ -11,7 +11,8 @@ def test_dice_lose():
         reduction='none',
         class_weight=[1.0, 2.0, 3.0],
         loss_weight=1.0,
-        ignore_index=1)
+        ignore_index=1,
+        loss_name='loss_dice')
     dice_loss = build_loss(loss_cfg)
     logits = torch.rand(8, 3, 4, 4)
     labels = (torch.rand(8, 4, 4) * 3).long()
@@ -30,7 +31,8 @@ def test_dice_lose():
         reduction='none',
         class_weight=f'{tmp_file.name}.pkl',
         loss_weight=1.0,
-        ignore_index=1)
+        ignore_index=1,
+        loss_name='loss_dice')
     dice_loss = build_loss(loss_cfg)
     dice_loss(logits, labels, ignore_index=None)

@@ -40,7 +42,8 @@ def test_dice_lose():
         reduction='none',
         class_weight=f'{tmp_file.name}.pkl',
         loss_weight=1.0,
-        ignore_index=1)
+        ignore_index=1,
+        loss_name='loss_dice')
     dice_loss = build_loss(loss_cfg)
     dice_loss(logits, labels, ignore_index=None)
     tmp_file.close()
@@ -54,8 +57,21 @@ def test_dice_lose():
         exponent=3,
         reduction='sum',
         loss_weight=1.0,
-        ignore_index=0)
+        ignore_index=0,
+        loss_name='loss_dice')
     dice_loss = build_loss(loss_cfg)
     logits = torch.rand(8, 2, 4, 4)
     labels = (torch.rand(8, 4, 4) * 2).long()
     dice_loss(logits, labels)
+
+    # test dice loss has name `loss_dice`
+    loss_cfg = dict(
+        type='DiceLoss',
+        smooth=2,
+        exponent=3,
+        reduction='sum',
+        loss_weight=1.0,
+        ignore_index=0,
+        loss_name='loss_dice')
+    dice_loss = build_loss(loss_cfg)
+    assert dice_loss.loss_name == 'loss_dice'
```
```diff
@@ -12,16 +12,24 @@ def test_lovasz_loss():
             type='LovaszLoss',
             loss_type='Binary',
             reduction='none',
-            loss_weight=1.0)
+            loss_weight=1.0,
+            loss_name='loss_lovasz')
         build_loss(loss_cfg)

     # reduction should be 'none' when per_image is False.
     with pytest.raises(AssertionError):
-        loss_cfg = dict(type='LovaszLoss', loss_type='multi_class')
+        loss_cfg = dict(
+            type='LovaszLoss',
+            loss_type='multi_class',
+            loss_name='loss_lovasz')
         build_loss(loss_cfg)

     # test lovasz loss with loss_type = 'multi_class' and per_image = False
-    loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
+    loss_cfg = dict(
+        type='LovaszLoss',
+        reduction='none',
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
     lovasz_loss = build_loss(loss_cfg)
     logits = torch.rand(1, 3, 4, 4)
     labels = (torch.rand(1, 4, 4) * 2).long()
@@ -33,7 +41,8 @@ def test_lovasz_loss():
         per_image=True,
         reduction='mean',
         class_weight=[1.0, 2.0, 3.0],
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
     lovasz_loss = build_loss(loss_cfg)
     logits = torch.rand(1, 3, 4, 4)
     labels = (torch.rand(1, 4, 4) * 2).long()
@@ -52,7 +61,8 @@ def test_lovasz_loss():
         per_image=True,
         reduction='mean',
         class_weight=f'{tmp_file.name}.pkl',
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
     lovasz_loss = build_loss(loss_cfg)
     lovasz_loss(logits, labels, ignore_index=None)

@@ -62,7 +72,8 @@ def test_lovasz_loss():
         per_image=True,
         reduction='mean',
         class_weight=f'{tmp_file.name}.npy',
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
     lovasz_loss = build_loss(loss_cfg)
     lovasz_loss(logits, labels, ignore_index=None)
     tmp_file.close()
@@ -74,7 +85,8 @@ def test_lovasz_loss():
         type='LovaszLoss',
         loss_type='binary',
         reduction='none',
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
     lovasz_loss = build_loss(loss_cfg)
     logits = torch.rand(2, 4, 4)
     labels = (torch.rand(2, 4, 4)).long()
@@ -86,8 +98,20 @@ def test_lovasz_loss():
         loss_type='binary',
         per_image=True,
         reduction='mean',
-        loss_weight=1.0)
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
     lovasz_loss = build_loss(loss_cfg)
     logits = torch.rand(2, 4, 4)
     labels = (torch.rand(2, 4, 4)).long()
     lovasz_loss(logits, labels, ignore_index=None)
+
+    # test lovasz loss has name `loss_lovasz`
+    loss_cfg = dict(
+        type='LovaszLoss',
+        loss_type='binary',
+        per_image=True,
+        reduction='mean',
+        loss_weight=1.0,
+        loss_name='loss_lovasz')
+    lovasz_loss = build_loss(loss_cfg)
+    assert lovasz_loss.loss_name == 'loss_lovasz'
```