angiecao 608e319eb6
[Feature] Support Side Adapter Network (#3232)
## Motivation
Support SAN for Open-Vocabulary Semantic Segmentation
Paper: [Side Adapter Network for Open-Vocabulary Semantic
Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification
- Added the ViT backbone parameters needed to implement the CLIP image
  encoder.
- Added the text encoder code.
- Added the multimodal encoder-decoder segmentor code for open-vocabulary
  semantic segmentation (see the config sketch after this list).
- Added the SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask classification models such as SAN and
  MaskFormer, and removed the dependency on MMDetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN
  decode head and hungarian_assigner.
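
How these pieces fit together in a config, as a rough, illustrative sketch: the component names below are assumed from the classes introduced in this PR, and the authoritative settings live under `configs/san/`.

```python
# Illustrative sketch only; see configs/san/ for the real config files.
model = dict(
    type='MultimodalEncoderDecoder',               # open-vocabulary segmentor (assumed name)
    image_encoder=dict(type='VisionTransformer'),  # CLIP ViT image encoder
    text_encoder=dict(type='CLIPTextEncoder'),     # text encoder for class names (assumed name)
    decode_head=dict(type='SideAdapterCLIPHead'))  # SAN decode head (see the unit test below)
```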

## Use cases
### Convert Models
**Pretrained SAN model**
The official pretrained models can be downloaded from
[san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth)
and
[san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use `tools/model_converters/san2mmseg.py` to convert an official model into the mmseg format:
`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**
SAN is trained with the CLIP models released by OpenAI. The CLIP models can be downloaded from
[ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt)
and
[ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use `tools/model_converters/clip2mmseg.py` to convert a model into the mmseg format:
`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`
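
After converting either checkpoint, a quick sanity check that it loads and that the keys look mmseg-style can save a failed training run later. A minimal sketch; the path is a placeholder for whatever `<OUTPUT_PATH>` you used:

```python
import torch

# Placeholder path: the <OUTPUT_PATH> passed to the converter script.
ckpt = torch.load('pretrain/san_vit_b_16_mmseg.pth', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)  # handle checkpoints with or without a 'state_dict' wrapper
print(f'{len(state_dict)} parameter tensors')
print(sorted(state_dict)[:5])  # eyeball a few remapped key names
```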

### Inference
Test the san_vit-base-16 model on the COCO-Stuff164k dataset:
`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`
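
For single-image inference, the high-level Python API can be used instead of `tools/test.py`. A brief sketch, with a placeholder checkpoint path and the demo image from the repository:

```python
from mmseg.apis import inference_model, init_model

config = 'configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint = 'work_dirs/san-vit-b16/iter_60000.pth'  # placeholder for <TRAINED_MODEL_PATH>
model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # SegDataSample with pred_sem_seg
```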

### Train
Train the san_vit-base-16 model on the COCO-Stuff164k dataset:
`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`
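
The `--cfg-options` flag above just sets `model.pretrained` in the loaded config; the same override can be inspected or applied programmatically with mmengine's `Config`. A small sketch with a placeholder path:

```python
from mmengine import Config

cfg = Config.fromfile('configs/san/san-vit-b16_coco-stuff164k-640x640.py')
cfg.model.pretrained = 'pretrain/san_vit_b_16_mmseg.pth'  # placeholder for <PRETRAINED_MODEL_PATH>
print(cfg.model.type)  # confirm which segmentor this config builds
```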

## Comparison Results
### Train on COCO-Stuff164k
| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official  | 41.93 | 56.73 | 67.69 |
|                 | mmseg | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official  | 45.57 | 59.52 | 69.76 |
|                 | mmseg | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context
| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official  | 54.05 | 72.96 | 77.77 |
|                 | mmseg | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official  | 57.53 | 77.56 | 78.89 |
|                 | mmseg | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug
| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official  | 93.86 | 96.61 | 97.11 |
|                 | mmseg | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official  | 95.17 | 97.61 | 97.63 |
|                 | mmseg | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
2023-09-20 21:20:26 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import Config
from mmengine.structures import PixelData

from mmseg.models.decode_heads import SideAdapterCLIPHead
from mmseg.structures import SegDataSample
from .utils import list_to_cuda


def test_san_head():
    H, W = (64, 64)
    clip_channels = 64
    img_channels = 4
    num_queries = 40
    out_dims = 64
    num_classes = 19
    # Small-scale decode head config for a fast unit test.
    cfg = dict(
        num_classes=num_classes,
        deep_supervision_idxs=[4],
        san_cfg=dict(
            in_channels=img_channels,
            embed_dims=128,
            clip_channels=clip_channels,
            num_queries=num_queries,
            cfg_encoder=dict(num_encode_layer=4, mlp_ratio=2, num_heads=2),
            cfg_decoder=dict(
                num_heads=4,
                num_layers=1,
                embed_channels=32,
                mlp_channels=32,
                num_mlp=2,
                rescale=True)),
        maskgen_cfg=dict(
            sos_token_num=num_queries,
            embed_dims=clip_channels,
            out_dims=out_dims,
            num_heads=4,
            mlp_ratio=2),
        train_cfg=dict(
            num_points=100,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='HungarianAssigner',
                match_costs=[
                    dict(type='ClassificationCost', weight=2.0),
                    dict(
                        type='CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
                ])),
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_cls_ce',
                loss_weight=2.0,
                class_weight=[1.0] * num_classes + [0.1]),
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                loss_name='loss_mask_ce',
                loss_weight=5.0),
            dict(
                type='DiceLoss',
                ignore_index=None,
                naive_dice=True,
                eps=1,
                loss_name='loss_mask_dice',
                loss_weight=5.0)
        ])
    cfg = Config(cfg)
    head = SideAdapterCLIPHead(**cfg)

    # Dummy inputs: a batch of 2 images and four levels of CLIP features,
    # each level given as a (dense feature map, global token) pair.
    inputs = torch.rand((2, img_channels, H, W))
    clip_feature = [[
        torch.rand((2, clip_channels, H // 2, W // 2)),
        torch.rand((2, clip_channels))
    ] for _ in range(4)]
    class_embed = torch.rand((num_classes + 1, out_dims))

    # Ground-truth data samples for the two images.
    data_samples = []
    for _ in range(2):
        data_sample = SegDataSample()
        img_meta = dict(img_shape=(H, W), ori_shape=(H, W))
        data_sample.gt_sem_seg = PixelData(
            data=torch.randint(0, num_classes, (1, H, W)))
        data_sample.set_metainfo(img_meta)
        data_samples.append(data_sample)
    batch_img_metas = [data_sample.metainfo for data_sample in data_samples]

    if torch.cuda.is_available():
        head = head.cuda()
        data = list_to_cuda([inputs, clip_feature, class_embed])
        for data_sample in data_samples:
            data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
    else:
        data = [inputs, clip_feature, class_embed]

    # loss test
    loss_dict = head.loss(data, data_samples, None)
    assert isinstance(loss_dict, dict)

    # prediction test
    with torch.no_grad():
        seg_logits = head.predict(data, batch_img_metas, None)
        assert seg_logits.shape == torch.Size((2, num_classes, H, W))