[Enhance] Support reading class_weight from file in loss functions to help MMDet3D (#513)

* support reading class_weight from file in loss function

* add unit test of loss with class_weight from file

* minor fix

* move get_class_weight to utils
Ziyi Wu 2021-04-29 16:04:15 +08:00 committed by GitHub
parent 768f704517
commit cef8a4f611
7 changed files with 120 additions and 12 deletions
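
In practice, the change means a config can point class_weight at a weight file instead of listing the values inline; '.npy' paths are read with np.load, while pkl/json/yaml paths are read with mmcv.load. A minimal config sketch (the decode-head key, file path and values are illustrative, not part of this commit):

# Hypothetical config snippet: class_weight passed as a file path.
loss_decode = dict(
    type='CrossEntropyLoss',
    use_sigmoid=False,
    class_weight='path/to/class_weights.npy',  # resolved by get_class_weight at build time
    loss_weight=1.0)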


@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import LOSSES
-from .utils import weight_reduce_loss
+from .utils import get_class_weight, weight_reduce_loss
 
 
 def cross_entropy(pred,
@@ -146,8 +146,8 @@ class CrossEntropyLoss(nn.Module):
             Defaults to False.
         reduction (str, optional): . Defaults to 'mean'.
             Options are "none", "mean" and "sum".
-        class_weight (list[float], optional): Weight of each class.
-            Defaults to None.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
     """
 
@@ -163,7 +163,7 @@ class CrossEntropyLoss(nn.Module):
         self.use_mask = use_mask
         self.reduction = reduction
         self.loss_weight = loss_weight
-        self.class_weight = class_weight
+        self.class_weight = get_class_weight(class_weight)
 
         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy


@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import LOSSES
-from .utils import weighted_loss
+from .utils import get_class_weight, weighted_loss
 
 
 @weighted_loss
@@ -63,8 +63,8 @@ class DiceLoss(nn.Module):
         reduction (str, optional): The method used to reduce the loss. Options
             are "none", "mean" and "sum". This parameter only works when
             per_image is True. Default: 'mean'.
-        class_weight (list[float], optional): The weight for each class.
-            Default: None.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Default to 1.0.
         ignore_index (int | None): The label index to be ignored. Default: 255.
     """
@@ -81,7 +81,7 @@ class DiceLoss(nn.Module):
         self.smooth = smooth
         self.exponent = exponent
         self.reduction = reduction
-        self.class_weight = class_weight
+        self.class_weight = get_class_weight(class_weight)
         self.loss_weight = loss_weight
         self.ignore_index = ignore_index
 


@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import LOSSES
-from .utils import weight_reduce_loss
+from .utils import get_class_weight, weight_reduce_loss
 
 
 def lovasz_grad(gt_sorted):
@@ -240,8 +240,8 @@ class LovaszLoss(nn.Module):
         reduction (str, optional): The method used to reduce the loss. Options
             are "none", "mean" and "sum". This parameter only works when
             per_image is True. Default: 'mean'.
-        class_weight (list[float], optional): The weight for each class.
-            Default: None.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
     """
 
@@ -269,7 +269,7 @@ class LovaszLoss(nn.Module):
         self.per_image = per_image
         self.reduction = reduction
         self.loss_weight = loss_weight
-        self.class_weight = class_weight
+        self.class_weight = get_class_weight(class_weight)
 
     def forward(self,
                 cls_score,


@@ -1,8 +1,28 @@
 import functools
 
+import mmcv
+import numpy as np
 import torch.nn.functional as F
 
 
+def get_class_weight(class_weight):
+    """Get class weight for loss function.
+
+    Args:
+        class_weight (list[float] | str | None): If class_weight is a str,
+            take it as a file name and read from it.
+    """
+    if isinstance(class_weight, str):
+        # take it as a file path
+        if class_weight.endswith('.npy'):
+            class_weight = np.load(class_weight)
+        else:
+            # pkl, json or yaml
+            class_weight = mmcv.load(class_weight)
+
+    return class_weight
+
+
 def reduce_loss(loss, reduction):
     """Reduce loss as specified.


@@ -25,6 +25,34 @@ def test_ce_loss():
     fake_label = torch.Tensor([1]).long()
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
 
+    # test loss with class weights from file
+    import os
+    import tempfile
+    import mmcv
+    import numpy as np
+    tmp_file = tempfile.NamedTemporaryFile()
+    mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=False,
+        class_weight=f'{tmp_file.name}.pkl',
+        loss_weight=1.0)
+    loss_cls = build_loss(loss_cls_cfg)
+    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+
+    np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2]))  # from npy file
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=False,
+        class_weight=f'{tmp_file.name}.npy',
+        loss_weight=1.0)
+    loss_cls = build_loss(loss_cls_cfg)
+    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+    tmp_file.close()
+    os.remove(f'{tmp_file.name}.pkl')
+    os.remove(f'{tmp_file.name}.npy')
+
     loss_cls_cfg = dict(
         type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
     loss_cls = build_loss(loss_cls_cfg)


@@ -16,6 +16,36 @@ def test_dice_lose():
     labels = (torch.rand(8, 4, 4) * 3).long()
     dice_loss(logits, labels)
 
+    # test loss with class weights from file
+    import os
+    import tempfile
+    import mmcv
+    import numpy as np
+    tmp_file = tempfile.NamedTemporaryFile()
+    mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
+    loss_cfg = dict(
+        type='DiceLoss',
+        reduction='none',
+        class_weight=f'{tmp_file.name}.pkl',
+        loss_weight=1.0,
+        ignore_index=1)
+    dice_loss = build_loss(loss_cfg)
+    dice_loss(logits, labels, ignore_index=None)
+
+    np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0]))  # from npy file
+    loss_cfg = dict(
+        type='DiceLoss',
+        reduction='none',
+        class_weight=f'{tmp_file.name}.npy',
+        loss_weight=1.0,
+        ignore_index=1)
+    dice_loss = build_loss(loss_cfg)
+    dice_loss(logits, labels, ignore_index=None)
+    tmp_file.close()
+    os.remove(f'{tmp_file.name}.pkl')
+    os.remove(f'{tmp_file.name}.npy')
+
     # test dice loss with loss_type = 'binary'
     loss_cfg = dict(
         type='DiceLoss',


@@ -38,6 +38,36 @@ def test_lovasz_loss():
     labels = (torch.rand(1, 4, 4) * 2).long()
     lovasz_loss(logits, labels, ignore_index=None)
 
+    # test loss with class weights from file
+    import os
+    import tempfile
+    import mmcv
+    import numpy as np
+    tmp_file = tempfile.NamedTemporaryFile()
+    mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
+    loss_cfg = dict(
+        type='LovaszLoss',
+        per_image=True,
+        reduction='mean',
+        class_weight=f'{tmp_file.name}.pkl',
+        loss_weight=1.0)
+    lovasz_loss = build_loss(loss_cfg)
+    lovasz_loss(logits, labels, ignore_index=None)
+
+    np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0]))  # from npy file
+    loss_cfg = dict(
+        type='LovaszLoss',
+        per_image=True,
+        reduction='mean',
+        class_weight=f'{tmp_file.name}.npy',
+        loss_weight=1.0)
+    lovasz_loss = build_loss(loss_cfg)
+    lovasz_loss(logits, labels, ignore_index=None)
+    tmp_file.close()
+    os.remove(f'{tmp_file.name}.pkl')
+    os.remove(f'{tmp_file.name}.npy')
+
     # test lovasz loss with loss_type = 'binary' and per_image = False
     loss_cfg = dict(
         type='LovaszLoss',