[Enhance] Support reading class_weight from file in loss functions to help MMDet3D (#513)
* support reading class_weight from file in loss function * add unit test of loss with class_weight from file * minor fix * move get_class_weight to utilspull/529/head
parent
768f704517
commit
cef8a4f611
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from .utils import weight_reduce_loss
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
def cross_entropy(pred,
|
||||
|
@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module):
|
|||
Defaults to False.
|
||||
reduction (str, optional): . Defaults to 'mean'.
|
||||
Options are "none", "mean" and "sum".
|
||||
class_weight (list[float], optional): Weight of each class.
|
||||
Defaults to None.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
|
@ -163,7 +163,7 @@ class CrossEntropyLoss(nn.Module):
|
|||
self.use_mask = use_mask
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = class_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
|
||||
if self.use_sigmoid:
|
||||
self.cls_criterion = binary_cross_entropy
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from .utils import weighted_loss
|
||||
from .utils import get_class_weight, weighted_loss
|
||||
|
||||
|
||||
@weighted_loss
|
||||
|
@ -63,8 +63,8 @@ class DiceLoss(nn.Module):
|
|||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float, optional): Weight of the loss. Default to 1.0.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
"""
|
||||
|
@ -81,7 +81,7 @@ class DiceLoss(nn.Module):
|
|||
self.smooth = smooth
|
||||
self.exponent = exponent
|
||||
self.reduction = reduction
|
||||
self.class_weight = class_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self.loss_weight = loss_weight
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from .utils import weight_reduce_loss
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
def lovasz_grad(gt_sorted):
|
||||
|
@ -240,8 +240,8 @@ class LovaszLoss(nn.Module):
|
|||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
|
@ -269,7 +269,7 @@ class LovaszLoss(nn.Module):
|
|||
self.per_image = per_image
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = class_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
|
||||
def forward(self,
|
||||
cls_score,
|
||||
|
|
|
@ -1,8 +1,28 @@
|
|||
import functools
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_class_weight(class_weight):
|
||||
"""Get class weight for loss function.
|
||||
|
||||
Args:
|
||||
class_weight (list[float] | str | None): If class_weight is a str,
|
||||
take it as a file name and read from it.
|
||||
"""
|
||||
if isinstance(class_weight, str):
|
||||
# take it as a file path
|
||||
if class_weight.endswith('.npy'):
|
||||
class_weight = np.load(class_weight)
|
||||
else:
|
||||
# pkl, json or yaml
|
||||
class_weight = mmcv.load(class_weight)
|
||||
|
||||
return class_weight
|
||||
|
||||
|
||||
def reduce_loss(loss, reduction):
|
||||
"""Reduce loss as specified.
|
||||
|
||||
|
|
|
@ -25,6 +25,34 @@ def test_ce_loss():
|
|||
fake_label = torch.Tensor([1]).long()
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
|
||||
|
||||
# test loss with class weights from file
|
||||
import os
|
||||
import tempfile
|
||||
import mmcv
|
||||
import numpy as np
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
|
||||
mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
class_weight=f'{tmp_file.name}.pkl',
|
||||
loss_weight=1.0)
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
|
||||
|
||||
np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2])) # from npy file
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
class_weight=f'{tmp_file.name}.npy',
|
||||
loss_weight=1.0)
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
|
||||
tmp_file.close()
|
||||
os.remove(f'{tmp_file.name}.pkl')
|
||||
os.remove(f'{tmp_file.name}.npy')
|
||||
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
|
|
|
@ -16,6 +16,36 @@ def test_dice_lose():
|
|||
labels = (torch.rand(8, 4, 4) * 3).long()
|
||||
dice_loss(logits, labels)
|
||||
|
||||
# test loss with class weights from file
|
||||
import os
|
||||
import tempfile
|
||||
import mmcv
|
||||
import numpy as np
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
|
||||
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
|
||||
loss_cfg = dict(
|
||||
type='DiceLoss',
|
||||
reduction='none',
|
||||
class_weight=f'{tmp_file.name}.pkl',
|
||||
loss_weight=1.0,
|
||||
ignore_index=1)
|
||||
dice_loss = build_loss(loss_cfg)
|
||||
dice_loss(logits, labels, ignore_index=None)
|
||||
|
||||
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
|
||||
loss_cfg = dict(
|
||||
type='DiceLoss',
|
||||
reduction='none',
|
||||
class_weight=f'{tmp_file.name}.pkl',
|
||||
loss_weight=1.0,
|
||||
ignore_index=1)
|
||||
dice_loss = build_loss(loss_cfg)
|
||||
dice_loss(logits, labels, ignore_index=None)
|
||||
tmp_file.close()
|
||||
os.remove(f'{tmp_file.name}.pkl')
|
||||
os.remove(f'{tmp_file.name}.npy')
|
||||
|
||||
# test dice loss with loss_type = 'binary'
|
||||
loss_cfg = dict(
|
||||
type='DiceLoss',
|
||||
|
|
|
@ -38,6 +38,36 @@ def test_lovasz_loss():
|
|||
labels = (torch.rand(1, 4, 4) * 2).long()
|
||||
lovasz_loss(logits, labels, ignore_index=None)
|
||||
|
||||
# test loss with class weights from file
|
||||
import os
|
||||
import tempfile
|
||||
import mmcv
|
||||
import numpy as np
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
|
||||
mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl') # from pkl file
|
||||
loss_cfg = dict(
|
||||
type='LovaszLoss',
|
||||
per_image=True,
|
||||
reduction='mean',
|
||||
class_weight=f'{tmp_file.name}.pkl',
|
||||
loss_weight=1.0)
|
||||
lovasz_loss = build_loss(loss_cfg)
|
||||
lovasz_loss(logits, labels, ignore_index=None)
|
||||
|
||||
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
|
||||
loss_cfg = dict(
|
||||
type='LovaszLoss',
|
||||
per_image=True,
|
||||
reduction='mean',
|
||||
class_weight=f'{tmp_file.name}.npy',
|
||||
loss_weight=1.0)
|
||||
lovasz_loss = build_loss(loss_cfg)
|
||||
lovasz_loss(logits, labels, ignore_index=None)
|
||||
tmp_file.close()
|
||||
os.remove(f'{tmp_file.name}.pkl')
|
||||
os.remove(f'{tmp_file.name}.npy')
|
||||
|
||||
# test lovasz loss with loss_type = 'binary' and per_image = False
|
||||
loss_cfg = dict(
|
||||
type='LovaszLoss',
|
||||
|
|
Loading…
Reference in New Issue