## Motivation

Support SAN for Open-Vocabulary Semantic Segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the ViT backbone parameters needed to implement the CLIP image encoder.
- Added text encoder code.
- Added multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask classification models such as SAN and MaskFormer, and removed the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

## Use cases

### Convert Models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use `tools/model_converters/san2mmseg.py` to convert an official model into mmseg style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

SAN is trained with the CLIP models released by OpenAI, which can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use `tools/model_converters/clip2mmseg.py` to convert a CLIP model into mmseg style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset (a Python API sketch follows this description):

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison Results

### Train on COCO-Stuff164k

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
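
After converting a checkpoint, it can also be run through the high-level Python APIs instead of `tools/test.py`. Below is a minimal sketch assuming mmseg 1.x's `init_model`/`inference_model`; the checkpoint filename and image path are placeholders:

```python
from mmseg.apis import inference_model, init_model

config = './configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint = './san-vit-b16_mmseg.pth'  # placeholder: output of san2mmseg.py
img = 'demo/demo.png'  # placeholder image path

# build the model from the config and load the converted weights
model = init_model(config, checkpoint, device='cuda:0')
# run inference; the result is a SegDataSample whose pred_sem_seg field
# holds the per-pixel class indices
result = inference_model(model, img)
print(result.pred_sem_seg.data.shape)
```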
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.optim import OptimWrapper
from mmengine.structures import PixelData
from torch import nn
from torch.optim import SGD

from mmseg.models import SegDataPreProcessor
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample


def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions.
        num_classes (int): number of semantic classes.
    """
    (N, C, H, W) = input_shape

    imgs = torch.randn(*input_shape)
    segs = torch.randint(
        low=0, high=num_classes - 1, size=(N, H, W), dtype=torch.long)

    img_metas = [{
        'img_shape': (H, W),
        'ori_shape': (H, W),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    data_samples = [
        SegDataSample(
            gt_sem_seg=PixelData(data=segs[i]), metainfo=img_metas[i])
        for i in range(N)
    ]

    mm_inputs = {'imgs': torch.FloatTensor(imgs), 'data_samples': data_samples}

    return mm_inputs
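

# Illustrative usage sketch (not part of the original file): `_demo_mm_inputs`
# returns a dict that a SegDataPreProcessor can consume. With
# input_shape=(2, 3, 8, 16) the images form a (2, 3, 8, 16) float tensor and
# each of the two data samples carries an (8, 16) ground-truth map.
def _demo_mm_inputs_example():
    mm_inputs = _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=19)
    assert mm_inputs['imgs'].shape == (2, 3, 8, 16)
    assert len(mm_inputs['data_samples']) == 2
    assert mm_inputs['data_samples'][0].gt_sem_seg.data.shape[-2:] == (8, 16)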


@MODELS.register_module()
class ExampleBackbone(nn.Module):

    def __init__(self, out_indices=None):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)
        self.out_indices = out_indices

    def init_weights(self, pretrained=None):
        pass

    def forward(self, x):
        if self.out_indices is None:
            return [self.conv(x)]
        else:
            outs = []
            for i in self.out_indices:
                outs.append(self.conv(x))
            return outs


@MODELS.register_module()
class ExampleDecodeHead(BaseDecodeHead):

    def __init__(self, num_classes=19, out_channels=None, **kwargs):
        super().__init__(
            3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])


@MODELS.register_module()
class ExampleTextEncoder(nn.Module):

    def __init__(self, vocabulary=None, output_dims=None):
        super().__init__()
        self.vocabulary = vocabulary
        self.output_dims = output_dims

    def forward(self):
        return torch.randn((len(self.vocabulary), self.output_dims))


@MODELS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):

    def __init__(self):
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs, prev_out):
        return self.cls_seg(inputs[0])
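

# Illustrative usage sketch (not part of the original file): since the example
# modules above are registered, test configs can reference them by name and
# build them through the MODELS registry.
def _registry_build_example():
    backbone = MODELS.build(dict(type='ExampleBackbone'))
    assert isinstance(backbone, ExampleBackbone)
    text_encoder = MODELS.build(
        dict(
            type='ExampleTextEncoder',
            vocabulary=['cat', 'dog'],
            output_dims=8))
    # one random embedding per vocabulary entry
    assert text_encoder().shape == (2, 8)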


def _segmentor_forward_train_test(segmentor):
    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes
    # build demo inputs (note: the default batch_size is 1; models with
    # BatchNorm need at least 2 samples in train mode)
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)

    # convert to cuda Tensor if applicable
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()

    # check data preprocessor
    if not hasattr(segmentor,
                   'data_preprocessor') or segmentor.data_preprocessor is None:
        segmentor.data_preprocessor = SegDataPreProcessor()

    mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
    imgs = mm_inputs.pop('imgs')
    data_samples = mm_inputs.pop('data_samples')

    # create optimizer wrapper
    optimizer = SGD(segmentor.parameters(), lr=0.1)
    optim_wrapper = OptimWrapper(optimizer)

    # Test forward train
    losses = segmentor.forward(imgs, data_samples, mode='loss')
    assert isinstance(losses, dict)

    # Test train_step
    data_batch = dict(inputs=imgs, data_samples=data_samples)
    outputs = segmentor.train_step(data_batch, optim_wrapper)
    assert isinstance(outputs, dict)
    assert 'loss' in outputs

    # Test val_step
    with torch.no_grad():
        segmentor.eval()
        data_batch = dict(inputs=imgs, data_samples=data_samples)
        outputs = segmentor.val_step(data_batch)
        assert isinstance(outputs, list)

    # Test forward simple test
    with torch.no_grad():
        segmentor.eval()
        data_batch = dict(inputs=imgs, data_samples=data_samples)
        results = segmentor.forward(imgs, data_samples, mode='tensor')
        assert isinstance(results, torch.Tensor)


def _segmentor_predict(segmentor):
    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes
    # build demo inputs (note: the default batch_size is 1; models with
    # BatchNorm need at least 2 samples in train mode)
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)

    # convert to cuda Tensor if applicable
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()

    # check data preprocessor
    if not hasattr(segmentor,
                   'data_preprocessor') or segmentor.data_preprocessor is None:
        segmentor.data_preprocessor = SegDataPreProcessor()

    mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
    imgs = mm_inputs.pop('imgs')
    data_samples = mm_inputs.pop('data_samples')

    # Test predict
    with torch.no_grad():
        segmentor.eval()
        data_batch = dict(inputs=imgs, data_samples=data_samples)
        outputs = segmentor.predict(**data_batch)
        assert isinstance(outputs, list)
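

# Illustrative usage sketch (not part of the original file), mirroring the
# pattern used by the segmentor tests that import these helpers: build an
# EncoderDecoder from the example modules above and exercise both helpers.
def _segmentor_helpers_example():
    cfg = dict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    segmentor = MODELS.build(cfg)
    _segmentor_forward_train_test(segmentor)
    _segmentor_predict(segmentor)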