## Motivation

Support SAN for Open-Vocabulary Semantic Segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the parameters of the backbone ViT for implementing the image encoder of CLIP.
- Added text encoder code.
- Added segmentor multimodal encoder-decoder code for open-vocabulary semantic segmentation.
- Added SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask classification models such as SAN and MaskFormer, and removed the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

## Use cases

### Convert Models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use tools/model_converters/san2mmseg.py to convert the official model into mmseg style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

Use the CLIP models provided by OpenAI to train SAN. They can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use tools/model_converters/clip2mmseg.py to convert a model into mmseg style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison Results

### Train on COCO-Stuff164k

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
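For a quick local check after converting a checkpoint, the Python inference API can also be used instead of tools/test.py. A minimal sketch (the checkpoint and image paths below are placeholders):

```python
from mmseg.apis import inference_model, init_model

# Config added in this PR plus a checkpoint converted with san2mmseg.py/clip2mmseg.py.
config_file = './configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint_file = 'san-vit-base16_converted.pth'  # placeholder path

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # any RGB test image
print(result.pred_sem_seg.data.shape)  # (1, H, W) tensor of predicted class ids
```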
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch.nn.functional as F
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor


@MODELS.register_module()
class MultimodalEncoderDecoder(BaseSegmentor):
    """Multimodal Encoder-Decoder segmentors.

    Multimodal segmentation architecture is used for open-vocabulary
    semantic segmentation by combining visual and language pretrained
    models. It consists of an image_encoder (backbone) to extract visual
    features, a text_encoder to extract text features, and a decode_head
    to generate semantic maps.
    Note that the deep supervision during training is implemented in the
    decode head.

    1. The ``loss`` method is used to calculate the loss of the model,
    which includes two steps: (1) Extract features to obtain the feature maps
    (2) Call the decode head loss function to forward the decode head model
    and calculate losses.

    .. code:: text

     loss(): extract_feat() -> _decode_head_forward_train()
     _decode_head_forward_train(): decode_head.loss()

    2. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) Run the inference function to obtain the
    list of seg_logits (2) Call the post-processing function to obtain the
    list of ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.

    .. code:: text

     predict(): inference() -> postprocess_result()
     inference(): whole_inference()/slide_inference()
     whole_inference()/slide_inference(): encode_decode()
     encode_decode(): extract_feat() -> decode_head.predict()

    3. The ``_forward`` method is used to output the tensor by running the
    model, which includes two steps: (1) Extract features to obtain the
    feature maps (2) Call the decode head forward function to forward the
    decode head model.

    .. code:: text

     _forward(): extract_feat() -> decode_head.forward()

    Args:
        image_encoder (ConfigType): The config for the visual encoder of
            segmentor.
        text_encoder (ConfigType): The config for the text encoder of
            segmentor.
        decode_head (ConfigType): The config for the decode head of
            segmentor.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. Defaults to None.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        pretrained (str, optional): The path for pretrained model.
            Defaults to None.
        asymetric_input (bool): Whether to use inputs of different sizes for
            the image encoder and the decode head. Defaults to True.
        encoder_resolution (float): Resize scale of input images for the
            image encoder. Defaults to None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
    """  # noqa: E501
    def __init__(self,
                 image_encoder: ConfigType,
                 text_encoder: ConfigType,
                 decode_head: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 pretrained: Optional[str] = None,
                 asymetric_input: bool = True,
                 encoder_resolution: float = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        if pretrained is not None:
            image_encoder.init_cfg = dict(
                type='Pretrained_Part', checkpoint=pretrained)
            text_encoder.init_cfg = dict(
                type='Pretrained_Part', checkpoint=pretrained)
            decode_head.init_cfg = dict(
                type='Pretrained_Part', checkpoint=pretrained)
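        # Note: a single converted SAN/CLIP checkpoint partially initialises
        # the image encoder, text encoder and decode head above via the
        # 'Pretrained_Part' init strategy.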

        if asymetric_input:
            assert encoder_resolution is not None, \
                'if asymetric_input is set to True, ' \
                'encoder_resolution must be a certain value'
        self.asymetric_input = asymetric_input
        self.encoder_resolution = encoder_resolution
        self.image_encoder = MODELS.build(image_encoder)
        self.text_encoder = MODELS.build(text_encoder)
        self._init_decode_head(decode_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head: ConfigType) -> None:
        """Initialize ``decode_head``"""
        self.decode_head = MODELS.build(decode_head)
        self.align_corners = self.decode_head.align_corners
        self.num_classes = self.decode_head.num_classes
        self.out_channels = self.decode_head.out_channels

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract visual features from images."""
        x = self.image_encoder(inputs)
        return x

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode the names of classes with text_encoder and encode images
        with image_encoder.

        Then decode the class embedding and visual feature into a semantic
        segmentation map of the same size as the input.
        """
        classifier_embeds = self.text_encoder()
        clip_inputs = inputs
        if self.asymetric_input:
            clip_inputs = F.interpolate(
                inputs, scale_factor=self.encoder_resolution, mode='bilinear')
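        # Note (illustrative): with encoder_resolution=0.5 and a 640x640 input,
        # the CLIP image encoder receives a 320x320 tensor, while the decode
        # head below still gets the original 640x640 ``inputs``.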
        x = self.image_encoder(clip_inputs)
        seg_logits = self.decode_head.predict([inputs, x, classifier_embeds],
                                              batch_img_metas, self.test_cfg)

        return seg_logits

    def _decode_head_forward_train(self, inputs: List[Tensor],
                                   data_samples: SampleList) -> dict:
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.loss(inputs, data_samples,
                                            self.train_cfg)

        losses.update(add_prefix(loss_decode, 'decode'))
        return losses

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        classifier_embeds = self.text_encoder()
        clip_inputs = inputs
        if self.asymetric_input:
            clip_inputs = F.interpolate(
                inputs, scale_factor=self.encoder_resolution, mode='bilinear')
        x = self.image_encoder(clip_inputs)

        losses = dict()

        loss_decode = self._decode_head_forward_train(
            [inputs, x, classifier_embeds], data_samples)
        losses.update(loss_decode)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_sem_seg`` (PixelData): Prediction of semantic
              segmentation.
            - ``seg_logits`` (PixelData): Predicted logits of semantic
              segmentation before normalization.
        """
        if data_samples is not None:
            batch_img_metas = [
                data_sample.metainfo for data_sample in data_samples
            ]
        else:
            batch_img_metas = [
                dict(
                    ori_shape=inputs.shape[2:],
                    img_shape=inputs.shape[2:],
                    pad_shape=inputs.shape[2:],
                    padding_size=[0, 0, 0, 0])
            ] * inputs.shape[0]

        seg_logits = self.inference(inputs, batch_img_metas)

        return self.postprocess_result(seg_logits, data_samples)

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        x = self.extract_feat(inputs)
        return self.decode_head.forward(x)

    def slide_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): the tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
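        # Illustrative arithmetic: with h_img=640, h_crop=512 and h_stride=341,
        # h_grids = max(640 - 512 + 340, 0) // 341 + 1 = 2, i.e. two
        # overlapping window rows; overlaps are averaged via ``count_mat``.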
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # change the image shape to patch shape
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                # the output of encode_decode is seg logits tensor map
                # with shape [N, C, H, W]
                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat

        return seg_logits

    def whole_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference with full image.

        Args:
            inputs (Tensor): The tensor should have a shape NxCxHxW, which
                contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        seg_logits = self.encode_decode(inputs, batch_img_metas)

        return seg_logits

    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
        """Inference with slide/whole style.

        Args:
            inputs (Tensor): The input image of shape (N, 3, H, W).
            batch_img_metas (List[dict]): List of image metainfo where each
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', 'pad_shape', and 'padding_size'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        assert self.test_cfg.mode in ['slide', 'whole']
        ori_shape = batch_img_metas[0]['ori_shape']
        assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
        if self.test_cfg.mode == 'slide':
            seg_logit = self.slide_inference(inputs, batch_img_metas)
        else:
            seg_logit = self.whole_inference(inputs, batch_img_metas)

        return seg_logit
    def aug_test(self, inputs, batch_img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # to save memory, we get augmented seg logit inplace
        seg_logit = self.inference(inputs[0], batch_img_metas[0])
        for i in range(1, len(inputs)):
            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i])
            seg_logit += cur_seg_logit
        seg_logit /= len(inputs)
        seg_pred = seg_logit.argmax(dim=1)
        # unravel batch dim
        seg_pred = list(seg_pred)
        return seg_pred
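For orientation, the constructor arguments above map onto a model config of roughly the following shape. This is a structural sketch only: the component configs are abbreviated and `encoder_resolution=0.5` is an assumed example value; the real settings live in configs/san/san-vit-b16_coco-stuff164k-640x640.py.

```python
# Structural sketch of a MultimodalEncoderDecoder config (not the shipped one).
model = dict(
    type='MultimodalEncoderDecoder',
    pretrained='pretrain/converted_san_or_clip.pth',  # placeholder path
    asymetric_input=True,
    encoder_resolution=0.5,  # assumed example value
    image_encoder=dict(),    # CLIP ViT image-encoder config, omitted here
    text_encoder=dict(),     # CLIP text-encoder config, omitted here
    decode_head=dict(),      # SideAdapterNetwork decode-head config, omitted here
    test_cfg=dict(mode='whole'))
```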