232 lines
7.9 KiB
Python
232 lines
7.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from abc import ABC, abstractmethod
from typing import Union

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
|
|
|
|
|
|
class BaseMatchCost(ABC):
    """Base match cost class.

    Concrete costs must implement :meth:`__call__`. Inheriting from
    :class:`abc.ABC` makes ``@abstractmethod`` actually enforced, so this
    base class (or a subclass that forgets ``__call__``) cannot be
    instantiated and silently return ``None`` as a cost.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        self.weight = weight

    @abstractmethod
    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
|
|
|
|
|
|
@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
    """Classification cost based on softmax class probabilities.

    The cost of matching a prediction to a ground truth is the negative
    softmax probability that the prediction assigns to the ground truth's
    label, scaled by ``weight``.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> self = ClassificationCost()
        >>> pred_instances = InstanceData(scores=torch.rand(4, 3))
        >>> gt_instances = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> self(pred_instances, gt_instances).shape
        torch.Size([4, 3])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): "scores" inside is
                predicted classification logits, of shape
                (num_queries, num_class).
            gt_instances (InstanceData): "labels" inside should have
                shape (num_gt, ).

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"
        pred_scores = pred_instances.scores
        gt_labels = gt_instances.labels

        # Logits -> probabilities over the last (class) dimension; selecting
        # the columns of each gt label yields a (num_queries, num_gt) matrix.
        pred_scores = pred_scores.softmax(-1)
        cls_cost = -pred_scores[:, gt_labels]

        return cls_cost * self.weight
|
|
|
|
|
|
@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
    """Cost of mask assignments based on dice losses.

    Args:
        pred_act (bool): Whether to apply sigmoid to mask_pred.
            Defaults to False.
        eps (float): Smoothing term added to both the numerator and the
            denominator of the dice ratio, so empty masks do not divide
            by zero. Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss
            in which the power of the number in the denominator is
            the first power. If False, use the second power that
            is adopted by K-Net and SOLO. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """Compute the pairwise dice loss between predictions and gts.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *)
                store 0 or 1, 0 for negative class and 1 for
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        # torch.einsum does not promote mismatched dtypes, so both operands
        # are brought to a common floating dtype (at least float32). For
        # float32 predictions this is a no-op.
        dtype = torch.promote_types(mask_preds.dtype, torch.float32)
        mask_preds = mask_preds.flatten(1).to(dtype)
        gt_masks = gt_masks.flatten(1).to(dtype)
        # Pairwise (doubled) intersections, shape (num_queries, num_gt).
        numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
        if self.naive_dice:
            denominator = (mask_preds.sum(-1)[:, None] +
                           gt_masks.sum(-1)[None, :])
        else:
            denominator = (mask_preds.pow(2).sum(-1)[:, None] +
                           gt_masks.pow(2).sum(-1)[None, :])
        loss = 1 - (numerator + self.eps) / (denominator + self.eps)
        return loss

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        if self.pred_act:
            # With pred_act set, the masks are treated as logits.
            pred_masks = pred_masks.sigmoid()
        dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
        return dice_cost * self.weight
|
|
|
|
|
|
@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
    """Cost of mask assignments based on binary cross entropy.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or
            softmax. Only sigmoid (``True``) is implemented; calling the
            cost with ``False`` raises ``NotImplementedError``.
            Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """Pairwise mean binary cross entropy between predictions and gts.

        Args:
            cls_pred (Tensor): Prediction logits with shape
                (num_queries, 1, *) or (num_queries, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *). Each element weights the positive BCE
                term by its value and the negative term by one minus it.

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
        """
        cls_pred = cls_pred.flatten(1).float()
        gt_labels = gt_labels.flatten(1).float()
        n = cls_pred.shape[1]
        # Element-wise BCE assuming target 1 (pos) and target 0 (neg),
        # computed once per query and combined with every gt via einsum.
        pos = F.binary_cross_entropy_with_logits(
            cls_pred, torch.ones_like(cls_pred), reduction='none')
        neg = F.binary_cross_entropy_with_logits(
            cls_pred, torch.zeros_like(cls_pred), reduction='none')
        cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
            torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
        # Mean over elements. Guard n == 0: the einsums then produce zeros,
        # and dividing by 0 would turn the cost matrix into NaN (0 / 0).
        cls_cost = cls_cost / max(n, 1)

        return cls_cost

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must
                contain ``masks``.

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).

        Raises:
            NotImplementedError: If ``use_sigmoid`` is False.
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks
        if self.use_sigmoid:
            cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
        else:
            raise NotImplementedError

        return cls_cost * self.weight
|