## Motivation

Support SAN for Open-Vocabulary Semantic Segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the backbone ViT parameters needed to implement the CLIP image encoder.
- Added the text encoder code.
- Added the multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added the SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask-classification models such as SAN and MaskFormer, and removed the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

## Use cases

### Convert models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use tools/model_converters/san2mmseg.py to convert an official model into mmseg style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

SAN is trained with the CLIP models released by OpenAI. They can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use tools/model_converters/clip2mmseg.py to convert a CLIP model into mmseg style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset (a minimal Python inference sketch follows this description):

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison Results

### Train on COCO-Stuff164k

| Model           | Implementation | mIoU  | mAcc  | pAcc  |
| --------------- | -------------- | ----- | ----- | ----- |
| san-vit-base16  | official       | 41.93 | 56.73 | 67.69 |
|                 | mmseg          | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official       | 45.57 | 59.52 | 69.76 |
|                 | mmseg          | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

| Model           | Implementation | mIoU  | mAcc  | pAcc  |
| --------------- | -------------- | ----- | ----- | ----- |
| san-vit-base16  | official       | 54.05 | 72.96 | 77.77 |
|                 | mmseg          | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official       | 57.53 | 77.56 | 78.89 |
|                 | mmseg          | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

| Model           | Implementation | mIoU  | mAcc  | pAcc  |
| --------------- | -------------- | ----- | ----- | ----- |
| san-vit-base16  | official       | 93.86 | 96.61 | 97.11 |
|                 | mmseg          | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official       | 95.17 | 97.61 | 97.63 |
|                 | mmseg          | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
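For reference, here is a minimal Python sketch of running inference with a converted checkpoint, assuming the standard `mmseg.apis` interface; the checkpoint name and image path are placeholders, not files shipped with this change.

```python
# Minimal inference sketch. Assumes the standard mmseg.apis interface;
# the checkpoint name and image path below are placeholders.
from mmseg.apis import inference_model, init_model, show_result_pyplot

config_file = 'configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint_file = 'san-vit-b16_converted.pth'  # e.g. the output of san2mmseg.py

# Build the SAN segmentor from the config and load the converted weights.
model = init_model(config_file, checkpoint_file, device='cuda:0')

# Run segmentation on a single image.
result = inference_model(model, 'demo/demo.png')

# Render and save the predicted segmentation map.
show_result_pyplot(model, 'demo/demo.png', result, show=False,
                   out_file='san_result.png')
```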
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import Config
from mmengine.structures import PixelData

from mmseg.models.decode_heads import SideAdapterCLIPHead
from mmseg.structures import SegDataSample
from .utils import list_to_cuda


def test_san_head():
    H, W = (64, 64)
    clip_channels = 64
    img_channels = 4
    num_queries = 40
    out_dims = 64
    num_classes = 19
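    # A deliberately small head config: side adapter network (san_cfg), CLIP
    # mask generator (maskgen_cfg) and MaskFormer-style training settings
    # (Hungarian matching with classification, mask BCE and Dice costs).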
    cfg = dict(
        num_classes=num_classes,
        deep_supervision_idxs=[4],
        san_cfg=dict(
            in_channels=img_channels,
            embed_dims=128,
            clip_channels=clip_channels,
            num_queries=num_queries,
            cfg_encoder=dict(num_encode_layer=4, mlp_ratio=2, num_heads=2),
            cfg_decoder=dict(
                num_heads=4,
                num_layers=1,
                embed_channels=32,
                mlp_channels=32,
                num_mlp=2,
                rescale=True)),
        maskgen_cfg=dict(
            sos_token_num=num_queries,
            embed_dims=clip_channels,
            out_dims=out_dims,
            num_heads=4,
            mlp_ratio=2),
        train_cfg=dict(
            num_points=100,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
                ])),
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_cls_ce',
                loss_weight=2.0,
                class_weight=[1.0] * num_classes + [0.1]),
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                loss_name='loss_mask_ce',
                loss_weight=5.0),
            dict(
                type='DiceLoss',
                ignore_index=None,
                naive_dice=True,
                eps=1,
                loss_name='loss_mask_dice',
                loss_weight=5.0)
        ])

    cfg = Config(cfg)
    head = SideAdapterCLIPHead(**cfg)

    inputs = torch.rand((2, img_channels, H, W))
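    # Fake multi-level CLIP features: each of the four levels holds a spatial
    # feature map and a per-image global embedding for a batch of two images.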
    clip_feature = [[
        torch.rand((2, clip_channels, H // 2, W // 2)),
        torch.rand((2, clip_channels))
    ] for _ in range(4)]
    class_embed = torch.rand((num_classes + 1, out_dims))

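    # Fake ground-truth samples: random semantic masks plus image metadata.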
    data_samples = []
    for _ in range(2):
        data_sample = SegDataSample()
        img_meta = {}
        img_meta['img_shape'] = (H, W)
        img_meta['ori_shape'] = (H, W)
        data_sample.gt_sem_seg = PixelData(
            data=torch.randint(0, num_classes, (1, H, W)))
        data_sample.set_metainfo(img_meta)
        data_samples.append(data_sample)

    batch_img_metas = []
    for data_sample in data_samples:
        batch_img_metas.append(data_sample.metainfo)

    if torch.cuda.is_available():
        head = head.cuda()
        data = list_to_cuda([inputs, clip_feature, class_embed])
        for data_sample in data_samples:
            data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
    else:
        data = [inputs, clip_feature, class_embed]

    # loss test
    loss_dict = head.loss(data, data_samples, None)
    assert isinstance(loss_dict, dict)

    # prediction test
    with torch.no_grad():
        seg_logits = head.predict(data, batch_img_metas, None)
        assert seg_logits.shape == torch.Size((2, num_classes, H, W))