# Copyright (c) OpenMMLab. All rights reserved.
# migrate from mmdetection with modifications
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score,
                   labels,
                   weight,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
"""Calculate the Seesaw CrossEntropy loss.
|
|
|
|
Args:
|
|
cls_score (torch.Tensor): The prediction with shape (N, C),
|
|
C is the number of classes.
|
|
labels (torch.Tensor): The learning label of the prediction.
|
|
weight (torch.Tensor): Sample-wise loss weight.
|
|
cum_samples (torch.Tensor): Cumulative samples for each category.
|
|
num_classes (int): The number of classes.
|
|
p (float): The ``p`` in the mitigation factor.
|
|
q (float): The ``q`` in the compenstation factor.
|
|
eps (float): The minimal value of divisor to smooth
|
|
the computation of compensation factor
|
|
reduction (str, optional): The method used to reduce the loss.
|
|
avg_factor (int, optional): Average factor that is used to average
|
|
the loss. Defaults to None.
|
|
|
|
Returns:
|
|
torch.Tensor: The calculated loss
|
|
"""
|
|
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

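    # Note on the weighting scheme (a summary of the two blocks below, in the
    # notation of the Seesaw Loss paper): for a sample whose ground-truth
    # class is i, the logit of every other class j is scaled by
    # S_ij = M_ij * C_ij, where the mitigation factor
    # M_ij = min(1, (N_j / N_i) ** p) is built from the accumulated per-class
    # sample counts N, and the compensation factor
    # C_ij = max(1, (sigma_j / sigma_i) ** q) is built from the softmax
    # scores sigma of the (detached) current prediction.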
    # mitigation factor
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index
                                                               )  # M_{ij}
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

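    # Illustrative note (an interpretation of the lines below, not part of the
    # original file): for a sample with ground-truth class i and logits z,
    # adding log(S_ij) to every non-target logit and then taking a standard
    # cross entropy is equivalent to the per-sample loss
    #   loss = -log(exp(z_i) / (exp(z_i) + sum_{j != i} S_ij * exp(z_j)))
    # where S_ij is that sample's row of ``seesaw_weights`` assembled above.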
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss


@LOSSES.register_module()
class SeesawLoss(nn.Module):
"""Implementation of seesaw loss.
|
|
|
|
Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
|
|
<https://arxiv.org/abs/2008.10032>`_
|
|
|
|
Args:
|
|
use_sigmoid (bool): Whether the prediction uses sigmoid of softmax.
|
|
Only False is supported. Defaults to False.
|
|
p (float): The ``p`` in the mitigation factor.
|
|
Defaults to 0.8.
|
|
q (float): The ``q`` in the compenstation factor.
|
|
Defaults to 2.0.
|
|
num_classes (int): The number of classes.
|
|
Default to 1000 for the ImageNet dataset.
|
|
eps (float): The minimal value of divisor to smooth
|
|
the computation of compensation factor, default to 1e-2.
|
|
reduction (str): The method that reduces the loss to a scalar.
|
|
Options are "none", "mean" and "sum". Default to "mean".
|
|
loss_weight (float): The weight of the loss. Defaults to 1.0
|
|
"""
|
|
|
|
    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category
        self.register_buffer('cum_samples',
                             torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self,
                cls_score,
                labels,
                weight=None,
                avg_factor=None,
                reduction_override=None):
"""Forward function.
|
|
|
|
Args:
|
|
cls_score (torch.Tensor): The prediction with shape (N, C).
|
|
labels (torch.Tensor): The learning label of the prediction.
|
|
weight (torch.Tensor, optional): Sample-wise loss weight.
|
|
avg_factor (int, optional): Average factor that is used to average
|
|
the loss. Defaults to None.
|
|
reduction (str, optional): The method used to reduce the loss.
|
|
Options are "none", "mean" and "sum".
|
|
Returns:
|
|
torch.Tensor: The calculated loss
|
|
"""
|
|
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but got "{reduction_override}".'
        assert cls_score.size(0) == labels.view(-1).size(0), \
            f'Expected `labels` shape [{cls_score.size(0)}], ' \
            f'but got {list(labels.size())}'
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes, \
            f'The channel number of output ({cls_score.size(-1)}) does ' \
            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'

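        # Note (not part of the original file): ``cum_samples`` is a
        # registered buffer, so the per-class counts accumulated below persist
        # across forward calls and are saved and restored with the state dict.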
        # accumulate the samples for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()

        if weight is not None:
            weight = weight.float()
        else:
            weight = labels.new_ones(labels.size(), dtype=torch.float)

        # calculate the classification loss
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score, labels, weight, self.cum_samples, self.num_classes,
            self.p, self.q, self.eps, reduction, avg_factor)

        return loss_cls
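

# A minimal usage sketch (illustrative only; the batch size, class count and
# hyper-parameter values below are made-up for demonstration). In practice
# this loss is typically built from a config dict through the LOSSES registry
# rather than instantiated by hand, and this module cannot be run as a script
# because of the relative imports above:
#
#     loss_fn = SeesawLoss(p=0.8, q=2.0, num_classes=10)
#     cls_score = torch.randn(4, 10)        # (N, C) raw logits
#     labels = torch.randint(0, 10, (4, ))  # (N, ) ground-truth indices
#     loss = loss_fn(cls_score, labels)     # scalar loss ('mean' reduction)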