Mirror of https://github.com/open-mmlab/mmsegmentation.git, synced 2025-06-03 22:03:48 +08:00
[Enhance] Support reading class_weight from file in loss functions to help MMDet3D (#513)

* support reading class_weight from file in loss function
* add unit test of loss with class_weight from file
* minor fix
* move get_class_weight to utils

parent ce56e68d30
commit 771ca7d3e0
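In practice this means a loss config no longer has to hard-code per-class weights inline; it can point at a file that the new get_class_weight helper loads when the loss is built. A minimal sketch of such a config, assuming a hypothetical weights file class_weights.npy computed offline (the path and values are only illustrative, not part of this commit):

    # hypothetical decode-head loss config; 'class_weights.npy' is a made-up path
    loss_decode = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight='class_weights.npy',  # loaded via np.load when the loss is built
        loss_weight=1.0)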
mmseg/models/losses/cross_entropy_loss.py

@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import LOSSES
-from .utils import weight_reduce_loss
+from .utils import get_class_weight, weight_reduce_loss
 
 
 def cross_entropy(pred,
@@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module):
             Defaults to False.
         reduction (str, optional): . Defaults to 'mean'.
             Options are "none", "mean" and "sum".
-        class_weight (list[float], optional): Weight of each class.
-            Defaults to None.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
     """
 
@@ -163,7 +163,7 @@ class CrossEntropyLoss(nn.Module):
         self.use_mask = use_mask
         self.reduction = reduction
         self.loss_weight = loss_weight
-        self.class_weight = class_weight
+        self.class_weight = get_class_weight(class_weight)
 
         if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
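The effect of the __init__ change above is that the weights are resolved once at construction time, so a list and a file path behave identically afterwards. A small sketch, assuming a hypothetical weights.pkl that holds the same list:

    from mmseg.models.losses import CrossEntropyLoss

    # both instances end up with the same self.class_weight;
    # 'weights.pkl' is a hypothetical file containing [0.8, 0.2]
    loss_a = CrossEntropyLoss(use_sigmoid=False, class_weight=[0.8, 0.2])
    loss_b = CrossEntropyLoss(use_sigmoid=False, class_weight='weights.pkl')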
mmseg/models/losses/dice_loss.py

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import LOSSES
-from .utils import weighted_loss
+from .utils import get_class_weight, weighted_loss
 
 
 @weighted_loss
@@ -63,8 +63,8 @@ class DiceLoss(nn.Module):
         reduction (str, optional): The method used to reduce the loss. Options
             are "none", "mean" and "sum". This parameter only works when
             per_image is True. Default: 'mean'.
-        class_weight (list[float], optional): The weight for each class.
-            Default: None.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Default to 1.0.
         ignore_index (int | None): The label index to be ignored. Default: 255.
     """
@@ -81,7 +81,7 @@ class DiceLoss(nn.Module):
         self.smooth = smooth
         self.exponent = exponent
         self.reduction = reduction
-        self.class_weight = class_weight
+        self.class_weight = get_class_weight(class_weight)
         self.loss_weight = loss_weight
         self.ignore_index = ignore_index
 
mmseg/models/losses/lovasz_loss.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import LOSSES
-from .utils import weight_reduce_loss
+from .utils import get_class_weight, weight_reduce_loss
 
 
 def lovasz_grad(gt_sorted):
@@ -240,8 +240,8 @@ class LovaszLoss(nn.Module):
         reduction (str, optional): The method used to reduce the loss. Options
             are "none", "mean" and "sum". This parameter only works when
             per_image is True. Default: 'mean'.
-        class_weight (list[float], optional): The weight for each class.
-            Default: None.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
     """
 
@@ -269,7 +269,7 @@ class LovaszLoss(nn.Module):
         self.per_image = per_image
         self.reduction = reduction
         self.loss_weight = loss_weight
-        self.class_weight = class_weight
+        self.class_weight = get_class_weight(class_weight)
 
     def forward(self,
                 cls_score,
mmseg/models/losses/utils.py

@@ -1,8 +1,28 @@
 import functools
 
+import mmcv
+import numpy as np
 import torch.nn.functional as F
 
 
+def get_class_weight(class_weight):
+    """Get class weight for loss function.
+
+    Args:
+        class_weight (list[float] | str | None): If class_weight is a str,
+            take it as a file name and read from it.
+    """
+    if isinstance(class_weight, str):
+        # take it as a file path
+        if class_weight.endswith('.npy'):
+            class_weight = np.load(class_weight)
+        else:
+            # pkl, json or yaml
+            class_weight = mmcv.load(class_weight)
+
+    return class_weight
+
+
 def reduce_loss(loss, reduction):
     """Reduce loss as specified.
 
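The helper added above dispatches on the file suffix: '.npy' goes through np.load, anything else through mmcv.load, which infers pkl/json/yaml from the extension; non-string inputs pass through untouched. A quick usage sketch (the file names are made up for illustration):

    import mmcv
    import numpy as np
    from mmseg.models.losses.utils import get_class_weight

    np.save('w.npy', np.array([0.8, 0.2]))   # hypothetical .npy file
    mmcv.dump([0.8, 0.2], 'w.json')          # hypothetical .json file

    print(get_class_weight('w.npy'))     # -> array([0.8, 0.2]) via np.load
    print(get_class_weight('w.json'))    # -> [0.8, 0.2] via mmcv.load
    print(get_class_weight([0.8, 0.2]))  # -> returned unchanged
    print(get_class_weight(None))        # -> None, i.e. no class weighting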
Unit tests (test_ce_loss):

@@ -25,6 +25,34 @@ def test_ce_loss():
     fake_label = torch.Tensor([1]).long()
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
 
+    # test loss with class weights from file
+    import os
+    import tempfile
+    import mmcv
+    import numpy as np
+    tmp_file = tempfile.NamedTemporaryFile()
+
+    mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=False,
+        class_weight=f'{tmp_file.name}.pkl',
+        loss_weight=1.0)
+    loss_cls = build_loss(loss_cls_cfg)
+    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+
+    np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2]))  # from npy file
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=False,
+        class_weight=f'{tmp_file.name}.npy',
+        loss_weight=1.0)
+    loss_cls = build_loss(loss_cls_cfg)
+    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+    tmp_file.close()
+    os.remove(f'{tmp_file.name}.pkl')
+    os.remove(f'{tmp_file.name}.npy')
+
     loss_cls_cfg = dict(
         type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
     loss_cls = build_loss(loss_cls_cfg)
Unit tests (test_dice_lose):

@@ -16,6 +16,36 @@ def test_dice_lose():
     labels = (torch.rand(8, 4, 4) * 3).long()
     dice_loss(logits, labels)
 
+    # test loss with class weights from file
+    import os
+    import tempfile
+    import mmcv
+    import numpy as np
+    tmp_file = tempfile.NamedTemporaryFile()
+
+    mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
+    loss_cfg = dict(
+        type='DiceLoss',
+        reduction='none',
+        class_weight=f'{tmp_file.name}.pkl',
+        loss_weight=1.0,
+        ignore_index=1)
+    dice_loss = build_loss(loss_cfg)
+    dice_loss(logits, labels, ignore_index=None)
+
+    np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0]))  # from npy file
+    loss_cfg = dict(
+        type='DiceLoss',
+        reduction='none',
+        class_weight=f'{tmp_file.name}.npy',
+        loss_weight=1.0,
+        ignore_index=1)
+    dice_loss = build_loss(loss_cfg)
+    dice_loss(logits, labels, ignore_index=None)
+    tmp_file.close()
+    os.remove(f'{tmp_file.name}.pkl')
+    os.remove(f'{tmp_file.name}.npy')
+
     # test dice loss with loss_type = 'binary'
     loss_cfg = dict(
         type='DiceLoss',
Unit tests (test_lovasz_loss):

@@ -38,6 +38,36 @@ def test_lovasz_loss():
     labels = (torch.rand(1, 4, 4) * 2).long()
     lovasz_loss(logits, labels, ignore_index=None)
 
+    # test loss with class weights from file
+    import os
+    import tempfile
+    import mmcv
+    import numpy as np
+    tmp_file = tempfile.NamedTemporaryFile()
+
+    mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
+    loss_cfg = dict(
+        type='LovaszLoss',
+        per_image=True,
+        reduction='mean',
+        class_weight=f'{tmp_file.name}.pkl',
+        loss_weight=1.0)
+    lovasz_loss = build_loss(loss_cfg)
+    lovasz_loss(logits, labels, ignore_index=None)
+
+    np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0]))  # from npy file
+    loss_cfg = dict(
+        type='LovaszLoss',
+        per_image=True,
+        reduction='mean',
+        class_weight=f'{tmp_file.name}.npy',
+        loss_weight=1.0)
+    lovasz_loss = build_loss(loss_cfg)
+    lovasz_loss(logits, labels, ignore_index=None)
+    tmp_file.close()
+    os.remove(f'{tmp_file.name}.pkl')
+    os.remove(f'{tmp_file.name}.npy')
+
     # test lovasz loss with loss_type = 'binary' and per_image = False
     loss_cfg = dict(
         type='LovaszLoss',