mmsegmentation/mmseg/models/assigners/match_cost.py

[Feature] Support Side Adapter Network (#3232)

## Motivation

Support SAN for open-vocabulary semantic segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)

Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added parameters to the ViT backbone for implementing the CLIP image encoder.
- Added text encoder code.
- Added multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask classification models such as SAN and MaskFormer, removing the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

## Use cases

### Convert Models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth). Use tools/model_converters/san2mmseg.py to convert an official model into mmseg style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

Use the CLIP model provided by OpenAI to train SAN. The CLIP models can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt). Use tools/model_converters/clip2mmseg.py to convert a model into mmseg style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the coco-stuff164k dataset:

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the coco-stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison Results

### Train on COCO-Stuff164k

| Model           | Implementation | mIoU  | mAcc  | pAcc  |
| --------------- | -------------- | ----- | ----- | ----- |
| san-vit-base16  | official       | 41.93 | 56.73 | 67.69 |
|                 | mmseg          | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official       | 45.57 | 59.52 | 69.76 |
|                 | mmseg          | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

| Model           | Implementation | mIoU  | mAcc  | pAcc  |
| --------------- | -------------- | ----- | ----- | ----- |
| san-vit-base16  | official       | 54.05 | 72.96 | 77.77 |
|                 | mmseg          | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official       | 57.53 | 77.56 | 78.89 |
|                 | mmseg          | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

| Model           | Implementation | mIoU  | mAcc  | pAcc  |
| --------------- | -------------- | ----- | ----- | ----- |
| san-vit-base16  | official       | 93.86 | 96.61 | 97.11 |
|                 | mmseg          | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official       | 95.17 | 97.61 | 97.63 |
|                 | mmseg          | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
2023-09-20 21:20:26 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Union

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS


class BaseMatchCost:
    """Base match cost class.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        self.weight = weight

    @abstractmethod
    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        pass


@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
    """Classification cost based on softmax probabilities.

    For each (prediction, ground truth) pair, the cost is the negated
    softmax probability that the prediction assigns to the ground-truth
    label, so a more confident correct prediction yields a lower cost.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> import torch
        >>> self = ClassificationCost()
        >>> pred_instances = InstanceData(scores=torch.rand(4, 3))
        >>> gt_instances = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> self(pred_instances, gt_instances).shape
        torch.Size([4, 3])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): "scores" holds the predicted
                classification logits, of shape (num_queries, num_classes).
            gt_instances (InstanceData): "labels" should have
                shape (num_gt, ).

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"
        pred_scores = pred_instances.scores
        gt_labels = gt_instances.labels

        pred_scores = pred_scores.softmax(-1)
        cls_cost = -pred_scores[:, gt_labels]

        return cls_cost * self.weight


@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
    """Cost of mask assignments based on dice losses.

    Args:
        pred_act (bool): Whether to apply sigmoid to the predicted masks.
            Defaults to False.
        eps (float): Small constant added to both the numerator and the
            denominator of the dice coefficient for numerical stability.
            Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss, whose
            denominator sums the mask values to the first power. If False,
            use the squared form adopted by K-Net and SOLO.
            Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """Compute the pairwise dice cost between predicted and
        ground-truth masks.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *),
                storing 0 or 1: 0 for the negative class and 1 for the
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        mask_preds = mask_preds.flatten(1)
        gt_masks = gt_masks.flatten(1).float()
        numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
        if self.naive_dice:
            denominator = mask_preds.sum(-1)[:, None] + \
                gt_masks.sum(-1)[None, :]
        else:
            denominator = mask_preds.pow(2).sum(1)[:, None] + \
                gt_masks.pow(2).sum(1)[None, :]
        loss = 1 - (numerator + self.eps) / (denominator + self.eps)
        return loss

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        if self.pred_act:
            pred_masks = pred_masks.sigmoid()
        dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
        return dice_cost * self.weight


@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
    """Binary cross-entropy cost between predicted mask logits and
    ground-truth masks.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Only ``True`` is currently supported.
            Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """Compute the pairwise mean binary cross-entropy between every
        prediction and every ground truth.

        Args:
            cls_pred (Tensor): The prediction with shape (num_queries, 1, *)
                or (num_queries, *).
            gt_labels (Tensor): Binary ground-truth masks used as the
                learning targets, with shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
        """
        cls_pred = cls_pred.flatten(1).float()
        gt_labels = gt_labels.flatten(1).float()
        n = cls_pred.shape[1]
        pos = F.binary_cross_entropy_with_logits(
            cls_pred, torch.ones_like(cls_pred), reduction='none')
        neg = F.binary_cross_entropy_with_logits(
            cls_pred, torch.zeros_like(cls_pred), reduction='none')
        cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
            torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
        cls_cost = cls_cost / n

        return cls_cost

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must contain
                ``masks``.

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks
        if self.use_sigmoid:
            cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
        else:
            raise NotImplementedError

        return cls_cost * self.weight
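
`ClassificationCost` turns classification logits into a cost by applying a softmax and negating the probability of each ground-truth label. The following is a minimal standalone sketch of that computation in plain PyTorch, without the registry or `InstanceData` plumbing; the function name `classification_cost` is hypothetical and not part of mmseg.

```python
import torch


def classification_cost(logits: torch.Tensor, gt_labels: torch.Tensor,
                        weight: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: negated softmax probability of each gt class.

    logits: (num_queries, num_classes) raw class scores.
    gt_labels: (num_gt,) integer class indices.
    Returns a (num_queries, num_gt) cost matrix.
    """
    probs = logits.softmax(-1)
    # Indexing with gt_labels gathers, for every query, the probability of
    # each ground-truth object's class; negating makes "likely" mean "cheap".
    return -probs[:, gt_labels] * weight


logits = torch.tensor([[2.0, 0.0, 0.0],    # query 0 favours class 0
                       [0.0, 2.0, 0.0]])   # query 1 favours class 1
gt_labels = torch.tensor([1, 0])           # gt 0 is class 1, gt 1 is class 0
cost = classification_cost(logits, gt_labels)
print(cost.argmin(dim=1))  # tensor([1, 0]): each query's cheapest gt
```

Because every entry is a negated probability, the cost lies in [-1, 0] before weighting, which keeps it on a scale comparable to the dice cost.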
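
`DiceCost._binary_mask_dice_loss` computes every pairwise dice cost at once: `torch.einsum('nc,mc->nm', ...)` sums the pixel-wise products for each (query, ground truth) pair. The sketch below re-implements that computation as a standalone function and checks it against an explicit double loop over pairs. The names `dice_cost` and `dice_cost_loop` are illustrative only, and the random inputs assume predictions already in [0, 1], as they would be after `pred_act=True` applies a sigmoid.

```python
import torch


def dice_cost(mask_preds, gt_masks, eps=1e-3, naive_dice=True):
    """Vectorized pairwise dice cost, following the DiceCost logic above."""
    mask_preds = mask_preds.flatten(1)          # (num_queries, P)
    gt_masks = gt_masks.flatten(1).float()      # (num_gt, P)
    # 'nc,mc->nm': sum over pixels c of p[n, c] * g[m, c] for every pair.
    numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
    if naive_dice:
        denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
    else:
        denominator = (mask_preds.pow(2).sum(-1)[:, None]
                       + gt_masks.pow(2).sum(-1)[None, :])
    return 1 - (numerator + eps) / (denominator + eps)


def dice_cost_loop(mask_preds, gt_masks, eps=1e-3):
    """Reference: the same naive dice cost, one (query, gt) pair at a time."""
    out = torch.empty(mask_preds.shape[0], gt_masks.shape[0])
    for i, p in enumerate(mask_preds.flatten(1)):
        for j, g in enumerate(gt_masks.flatten(1).float()):
            out[i, j] = 1 - (2 * (p * g).sum() + eps) / (p.sum() + g.sum() + eps)
    return out


torch.manual_seed(0)
preds = torch.rand(3, 4, 4)      # 3 soft mask predictions
gts = torch.rand(2, 4, 4) > 0.5  # 2 binary ground-truth masks
print(torch.allclose(dice_cost(preds, gts), dice_cost_loop(preds, gts),
                     atol=1e-6))  # True
```

The einsum form avoids building a (num_queries, num_gt, P) intermediate, which matters when P is the number of pixels in a full-resolution mask.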
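
`CrossEntropyLossCost._binary_cross_entropy` evaluates the per-pixel BCE once against an all-ones target (`pos`) and once against an all-zeros target (`neg`), then uses `einsum` with the ground-truth mask to pick the right term for every pair. The sketch below reproduces that trick in a hypothetical standalone `bce_cost` function and compares it with calling `F.binary_cross_entropy_with_logits` (default `reduction='mean'`) on each pair separately. It assumes the sigmoid path, the only one the class implements.

```python
import torch
import torch.nn.functional as F


def bce_cost(mask_logits, gt_masks):
    """Hypothetical helper: pairwise mean BCE between logits and binary masks."""
    mask_logits = mask_logits.flatten(1).float()  # (num_queries, P)
    gt_masks = gt_masks.flatten(1).float()        # (num_gt, P)
    num_points = mask_logits.shape[1]
    # Per-pixel loss if the target pixel were 1 (pos) or 0 (neg).
    pos = F.binary_cross_entropy_with_logits(
        mask_logits, torch.ones_like(mask_logits), reduction='none')
    neg = F.binary_cross_entropy_with_logits(
        mask_logits, torch.zeros_like(mask_logits), reduction='none')
    # Use pos where the gt pixel is 1 and neg where it is 0, summed over P.
    cost = (torch.einsum('nc,mc->nm', pos, gt_masks)
            + torch.einsum('nc,mc->nm', neg, 1 - gt_masks))
    return cost / num_points  # mean over pixels


torch.manual_seed(0)
mask_logits = torch.randn(3, 4, 4)
gt_masks = torch.rand(2, 4, 4) > 0.5
cost = bce_cost(mask_logits, gt_masks)
# Reference: one BCE call per (query, gt) pair, averaged over its pixels.
reference = torch.stack([
    torch.stack([F.binary_cross_entropy_with_logits(p.flatten(),
                                                    g.flatten().float())
                 for g in gt_masks])
    for p in mask_logits])
print(torch.allclose(cost, reference, atol=1e-6))  # True
```

The two agree because, for a binary target g, the BCE equals g times the loss against target 1 plus (1 - g) times the loss against target 0.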
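
Each match cost above returns a weighted (num_preds, num_gts) matrix, and the commit message mentions a `hungarian_assigner` that consumes them. The following sketch shows that final step under stated assumptions: the weighted costs are simply summed, and the minimum-cost bipartite matching is solved with SciPy's `linear_sum_assignment`. It is an illustration, not a copy of mmseg's assigner, and `hungarian_match` is a hypothetical helper.

```python
import torch
from scipy.optimize import linear_sum_assignment


def hungarian_match(cost_matrices):
    """Hypothetical helper: sum weighted costs and solve the assignment.

    Each entry of cost_matrices has shape (num_queries, num_gt) and is
    assumed to be weighted already (as the match-cost classes return).
    Returns (query_indices, gt_indices) as int64 tensors.
    """
    total = torch.stack(cost_matrices).sum(0)
    # linear_sum_assignment minimizes the total cost; with more queries
    # than ground truths, every gt is matched and extra queries are left out.
    rows, cols = linear_sum_assignment(total.detach().cpu().numpy())
    return (torch.as_tensor(rows, dtype=torch.long),
            torch.as_tensor(cols, dtype=torch.long))


cls_cost = torch.tensor([[-0.9, -0.1],
                         [-0.2, -0.8],
                         [-0.3, -0.3]])
dice_cost = torch.tensor([[0.1, 0.9],
                          [0.8, 0.2],
                          [0.5, 0.5]])
query_idx, gt_idx = hungarian_match([cls_cost, dice_cost])
print(query_idx.tolist(), gt_idx.tolist())  # [0, 1] [0, 1]
```

With three queries and two ground-truth objects, query 2 stays unmatched; in MaskFormer-style training such queries are typically supervised as "no object".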