[Feature] Support Side Adapter Network (#3232)

## Motivation
Support SAN for Open-Vocabulary Semantic Segmentation
Paper: [Side Adapter Network for Open-Vocabulary Semantic
Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification
- Added the ViT backbone parameters needed to implement the CLIP image encoder.
- Added the text encoder code.
- Added the multimodal encoder-decoder segmentor for open-vocabulary semantic segmentation (a minimal composition sketch follows this list).
- Added the SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask-classification models such as SAN and MaskFormer, and removed the dependency on MMDetection.
- Added unit tests for the text encoder, the multimodal encoder-decoder, the SAN decode head, and hungarian_assigner.
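
The new pieces compose through the standard MMSegmentation config system. The snippet below is a minimal, illustrative sketch of how the new segmentor, text encoder, and decode head fit together; the field values are abbreviated, and the full model config added by this PR (shown further down) is the authoritative version.

```python
# Illustrative sketch only; see the full model config in this PR for the
# complete field values.
model = dict(
    type='MultimodalEncoderDecoder',      # new segmentor
    image_encoder=dict(                   # CLIP-style ViT image encoder
        type='VisionTransformer',
        pre_norm=True,
        frozen_exclude=['pos_embed']),
    text_encoder=dict(                    # new CLIP text encoder
        type='CLIPTextEncoder',
        templates='vild',
        cat_bg=True),
    decode_head=dict(                     # new SAN decode head
        type='SideAdapterCLIPHead',
        san_cfg=dict(num_queries=100),
        maskgen_cfg=dict(sos_token_num=100)))
```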

## Use cases
### Convert Models
**Pretrained SAN model**
The official pretrained models can be downloaded from
[san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth)
and
[san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use tools/model_converters/san2mmseg.py to convert an official model into the
MMSegmentation format:
`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**
SAN is trained with the CLIP model released by OpenAI. The CLIP model can
be downloaded from
[ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt)
and
[ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use tools/model_converters/clip2mmseg.py to convert the model into the
MMSegmentation format:
`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`
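
To sanity-check a converted checkpoint (optional), it can be loaded with plain PyTorch and its keys inspected; `<OUTPUT_PATH>` below is the file written by either converter.

```python
import torch

# Load the converted checkpoint produced by san2mmseg.py or clip2mmseg.py.
ckpt = torch.load('<OUTPUT_PATH>', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)  # unwrap if the converter nests it
print(f'{len(state_dict)} tensors')
# For a converted SAN checkpoint, expect key prefixes such as
# 'image_encoder.' and 'decode_head.'.
print(sorted(state_dict)[:5])
```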

### Inference
Test the san_vit-base-16 model on the COCO-Stuff164k dataset:
`python tools/test.py
./configs/san/san-vit-b16_coco-stuff164k-640x640.py
<TRAINED_MODEL_PATH>`
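
Inference can also be run from Python; the sketch below assumes the standard MMSegmentation 1.x `init_model`/`inference_model` helpers and uses placeholder paths.

```python
from mmseg.apis import inference_model, init_model

config_file = './configs/san/san-vit-b16_coco-stuff164k-640x640.py'
checkpoint_file = '<TRAINED_MODEL_PATH>'

# Build the model and run single-image open-vocabulary inference.
model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, '<IMAGE_PATH>')   # returns a SegDataSample
print(result.pred_sem_seg.data.shape)             # predicted semantic map
```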

### Train
Train the san_vit-base-16 model on the COCO-Stuff164k dataset:
`python tools/train.py
./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options
model.pretrained=<PRETRAINED_MODEL_PATH>`
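
For multi-GPU training, the standard launcher can be used as well; the GPU count below follows the 4-GPU default noted in the COCO-Stuff164k config, and the paths are placeholders.
`bash tools/dist_train.sh ./configs/san/san-vit-b16_coco-stuff164k-640x640.py 4 --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`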

## Comparison Results
### Train on COCO-Stuff164k
| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official  | 41.93 | 56.73 | 67.69 |
|                 | mmseg | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official  | 45.57 | 59.52 | 69.76 |
|                 | mmseg | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context
| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official  | 54.05 | 72.96 | 77.77 |
|                 | mmseg | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official  | 57.53 | 77.56 | 78.89 |
|                 | mmseg | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug
| Model           | Source   | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official  | 93.86 | 96.61 | 97.11 |
|                 | mmseg | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official  | 95.17 | 97.61 | 97.63 |
|                 | mmseg | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
42 changed files with 4114 additions and 29 deletions


@ -0,0 +1,137 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size_divisor=640,
test_cfg=dict(size_divisor=32))
num_classes = 171
model = dict(
type='MultimodalEncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained='pretrain/clip_vit_base_patch16_224.pth',
asymetric_input=True,
encoder_resolution=0.5,
image_encoder=dict(
type='VisionTransformer',
img_size=(224, 224),
patch_size=16,
patch_pad=0,
in_channels=3,
embed_dims=768,
num_layers=9,
num_heads=12,
mlp_ratio=4,
out_origin=True,
out_indices=(2, 5, 8),
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
with_cls_token=True,
output_cls_token=True,
patch_bias=False,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
act_cfg=dict(type='QuickGELU'),
norm_eval=False,
interpolate_mode='bicubic',
frozen_exclude=['pos_embed']),
text_encoder=dict(
type='CLIPTextEncoder',
dataset_name=None,
templates='vild',
embed_dims=512,
num_layers=12,
num_heads=8,
mlp_ratio=4,
output_dims=512,
cache_feature=True,
cat_bg=True,
norm_cfg=dict(type='LN', eps=1e-5)
),
decode_head=dict(
type='SideAdapterCLIPHead',
num_classes=num_classes,
deep_supervision_idxs=[7],
san_cfg=dict(
in_channels=3,
clip_channels=768,
embed_dims=240,
patch_size=16,
patch_bias=True,
num_queries=100,
cfg_encoder=dict(
num_encode_layer=8,
num_heads=6,
mlp_ratio=4
),
fusion_index=[0, 1, 2, 3],
cfg_decoder=dict(
num_heads=12,
num_layers=1,
embed_channels=256,
mlp_channels=256,
num_mlp=3,
rescale=True),
norm_cfg=dict(type='LN', eps=1e-6),
),
maskgen_cfg=dict(
sos_token_format='cls_token',
sos_token_num=100,
cross_attn=False,
num_layers=3,
embed_dims=768,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
out_dims=512,
final_norm=True,
act_cfg=dict(type='QuickGELU'),
norm_cfg=dict(type='LN', eps=1e-5),
frozen_exclude=[]
),
align_corners=False,
train_cfg=dict(
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='ClassificationCost', weight=2.0),
dict(
type='CrossEntropyLossCost',
weight=5.0,
use_sigmoid=True),
dict(
type='DiceCost',
weight=5.0,
pred_act=True,
eps=1.0)
])),
loss_decode=[dict(type='CrossEntropyLoss',
loss_name='loss_cls_ce',
loss_weight=2.0,
class_weight=[1.0] * num_classes + [0.1]),
dict(type='CrossEntropyLoss',
use_sigmoid=True,
loss_name='loss_mask_ce',
loss_weight=5.0),
dict(type='DiceLoss',
ignore_index=None,
naive_dice=True,
eps=1,
loss_name='loss_mask_dice',
loss_weight=5.0)
]),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole')) # yapf: disable


@ -0,0 +1,47 @@
# SAN
> [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/MendelXu/SAN">Official Repo</a>
## Abstract
<!-- [ABSTRACT] -->
This paper presents a new framework for open-vocabulary semantic segmentation with the pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting attention bias which is applied in the CLIP model to recognize the class of masks. This decoupled design benefits CLIP in recognizing the class of mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to be adapted to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware. Our approach is fast, accurate, and only adds a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms other counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed. We hope our approach will serve as a solid baseline and help ease future research in open-vocabulary semantic segmentation.
<!-- [IMAGE] -->
<div align=center>
<img src="https://github.com/MendelXu/SAN/blob/main/resources/arch.png" width="800"/>
</div>
## Results and models
### COCO-Stuff164k
| Method | Backbone | Pretrained | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | ------------ | --------- | ------- | -------- | -------------- | ------ | ----- | ------------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| SAN | ViT-B_16 | CLIP_ViT-B16 | 640x640 | 60000 | 12.61 | - | V100 | 41.93 | 41.77 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906-fd0a7684.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906.log) |
| SAN | ViT-L_14 | CLIP_ViT-L14 | 640x640 | 60000 | 22.84 | - | V100 | 45.78 | 43.99 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907-a11e098f.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907.log) |
## Notes
The pretrained weights in config files are converted from open_clip models using tools/model_converters/clip2mmseg.py.
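For example (the output filename here is illustrative; point the `pretrained` field of your config at whatever name you choose):
```shell
python tools/model_converters/clip2mmseg.py ViT-B-16.pt clip_vit-base-patch16-224.pth
```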
## Citation
```bibtex
@inproceedings{xu2023side,
title={Side adapter network for open-vocabulary semantic segmentation},
author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={2945--2954},
year={2023}
}
```


@ -0,0 +1,82 @@
_base_ = [
'../_base_/models/san_vit-b16.py', '../_base_/datasets/coco-stuff164k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomChoiceResize',
scales=[int(640 * x * 0.1) for x in range(5, 16)],
resize_type='ResizeShortestEdge',
max_size=2560),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.0),
dict(type='PhotoMetricDistortion'),
dict(type='RandomFlip', prob=0.5),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# By default, models are trained on 4 GPUs with 8 images per GPU
train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-base-patch16-224_3rdparty-d08f8887.pth' # noqa
data_preprocessor = dict(
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
size_divisor=640,
test_cfg=dict(size_divisor=32))
model = dict(
pretrained=pretrained,
text_encoder=dict(dataset_name='coco-stuff164k'),
decode_head=dict(num_classes=171))
# training schedule for 60k
train_cfg = dict(
type='IterBasedTrainLoop',
max_iters=60000,
val_interval=500,
val_begin=55000)
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
by_epoch=False,
interval=10000,
save_best='mIoU'))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'img_encoder': dict(lr_mult=0.1, decay_mult=1.0),
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}),
loss_scale='dynamic',
clip_grad=dict(max_norm=0.01, norm_type=2))
param_scheduler = [
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=0,
end=60000,
by_epoch=False,
)
]


@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/san_vit-b16.py',
'../_base_/datasets/pascal_context_59.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
data_preprocessor = dict(
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
size_divisor=640,
test_cfg=dict(size_divisor=32))
model = dict(
data_preprocessor=data_preprocessor,
pretrained='pretrain/vit_base_patch16_224.pth',
text_encoder=dict(dataset_name='pascal_context'),
decode_head=dict(num_classes=59))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]


@ -0,0 +1,65 @@
_base_ = [
'../_base_/models/san_vit-b16.py',
'../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
metainfo = dict(
classes=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
palette=[[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(
batch_size=1, dataset=dict(metainfo=metainfo, pipeline=test_pipeline))
test_dataloader = val_dataloader
data_preprocessor = dict(
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
size_divisor=640,
test_cfg=dict(size_divisor=32))
model = dict(
data_preprocessor=data_preprocessor,
pretrained='pretrain/vit_base_patch16_224.pth',
text_encoder=dict(dataset_name='voc'),
decode_head=dict(num_classes=20))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]


@ -0,0 +1,36 @@
_base_ = ['./san-vit-b16_coco-stuff164k-640x640.py']
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-large-patch14-336_3rdparty-0b5df9cb.pth' # noqa
model = dict(
type='MultimodalEncoderDecoder',
pretrained=pretrained,
encoder_resolution=0.7,
image_encoder=dict(
type='VisionTransformer',
img_size=(336, 336),
patch_size=14,
patch_pad=0,
embed_dims=1024,
num_layers=18,
num_heads=16,
out_indices=(5, 11, 17),
),
text_encoder=dict(
type='CLIPTextEncoder',
embed_dims=768,
num_layers=12,
num_heads=12,
output_dims=768,
),
decode_head=dict(
type='SideAdapterCLIPHead',
san_cfg=dict(clip_channels=1024, cfg_decoder=dict(num_heads=16)),
maskgen_cfg=dict(
num_layers=6,
embed_dims=1024,
num_heads=16,
out_dims=768,
)))
# By default, models are trained on 8 GPUs with 4 images per GPU
train_dataloader = dict(batch_size=4)


@ -0,0 +1,32 @@
_base_ = ['./san-vit-b16_pascal_context-640x640.py']
model = dict(
type='MultimodalEncoderDecoder',
pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
encoder_resolution=0.7,
image_encoder=dict(
type='VisionTransformer',
img_size=(336, 336),
patch_size=14,
patch_pad=0,
embed_dims=1024,
num_layers=18,
num_heads=16,
out_indices=(5, 11, 17),
),
text_encoder=dict(
type='CLIPTextEncoder',
embed_dims=768,
num_layers=12,
num_heads=12,
output_dims=768,
),
decode_head=dict(
type='SideAdapterCLIPHead',
san_cfg=dict(clip_channels=1024, cfg_decoder=dict(num_heads=16)),
maskgen_cfg=dict(
num_layers=6,
embed_dims=1024,
num_heads=16,
out_dims=768,
)))


@ -0,0 +1,32 @@
_base_ = ['./san-vit-b16_voc12aug-640x640.py']
model = dict(
type='MultimodalEncoderDecoder',
pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
encoder_resolution=0.7,
image_encoder=dict(
type='VisionTransformer',
img_size=(336, 336),
patch_size=14,
patch_pad=0,
embed_dims=1024,
num_layers=18,
num_heads=16,
out_indices=(5, 11, 17),
),
text_encoder=dict(
type='CLIPTextEncoder',
embed_dims=768,
num_layers=12,
num_heads=12,
output_dims=768,
),
decode_head=dict(
type='SideAdapterCLIPHead',
san_cfg=dict(clip_channels=1024, cfg_decoder=dict(num_heads=16)),
maskgen_cfg=dict(
num_layers=6,
embed_dims=1024,
num_heads=16,
out_dims=768,
)))


@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigners import * # noqa: F401,F403
from .backbones import * # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
build_head, build_loss, build_segmentor)
@ -7,6 +8,7 @@ from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
from .text_encoder import * # noqa: F401,F403
__all__ = [
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',


@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_assigner import BaseAssigner
from .hungarian_assigner import HungarianAssigner
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost
__all__ = [
'BaseAssigner',
'HungarianAssigner',
'ClassificationCost',
'CrossEntropyLossCost',
'DiceCost',
]


@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Optional
from mmengine.structures import InstanceData
class BaseAssigner(metaclass=ABCMeta):
"""Base assigner that assigns masks to ground truth class labels."""
@abstractmethod
def assign(self,
pred_instances: InstanceData,
gt_instances: InstanceData,
gt_instances_ignore: Optional[InstanceData] = None,
**kwargs):
"""Assign masks to either a ground truth class label or a negative
label."""


@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast
from mmseg.registry import TASK_UTILS
from .base_assigner import BaseAssigner
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
"""Computes one-to-one matching between prediction masks and ground truth.
This class uses bipartite matching-based assignment to compute an
assignment between the prediction masks and the ground truth. The
assignment result is based on the weighted sum of match costs. The
Hungarian algorithm is used to calculate the best matching with the
minimum cost. The prediction masks that are not matched are classified
as background.
Args:
match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
"""
def __init__(
self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
ConfigDict]
) -> None:
if isinstance(match_costs, dict):
match_costs = [match_costs]
elif isinstance(match_costs, list):
assert len(match_costs) > 0, \
'match_costs must not be a empty list.'
self.match_costs = [
TASK_UTILS.build(match_cost) for match_cost in match_costs
]
def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
**kwargs):
"""Computes one-to-one matching based on the weighted costs.
This method assigns each query prediction to a ground truth or
background. The assignment first calculates the cost for each
category assigned to each query mask, and then uses the
Hungarian algorithm to calculate the minimum cost as the best
match.
Args:
pred_instances (InstanceData): Instances of model
predictions. It includes "masks", with shape
(n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
gt_instances (InstanceData): Ground truth of instance
annotations. It includes "labels", with shape (k, ),
and "masks", with shape (k, h, w) or (k, l).
Returns:
matched_quiery_inds (Tensor): The indexes of matched queries.
matched_label_inds (Tensor): The indexes of matched labels.
"""
# compute weighted cost
cost_list = []
with autocast(enabled=False):
for match_cost in self.match_costs:
cost = match_cost(
pred_instances=pred_instances, gt_instances=gt_instances)
cost_list.append(cost)
cost = torch.stack(cost_list).sum(dim=0)
device = cost.device
# do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost)
matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device)
matched_label_inds = torch.from_numpy(matched_label_inds).to(device)
return matched_quiery_inds, matched_label_inds


@ -0,0 +1,231 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Union
import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import TASK_UTILS
class BaseMatchCost:
"""Base match cost class.
Args:
weight (Union[float, int]): Cost weight. Defaults to 1.
"""
def __init__(self, weight: Union[float, int] = 1.) -> None:
self.weight = weight
@abstractmethod
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (InstanceData): Instances of model predictions.
It often includes "labels" and "scores".
gt_instances (InstanceData): Ground truth of instance
annotations. It usually includes "labels".
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
pass
@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
"""ClsSoftmaxCost.
Args:
weight (Union[float, int]): Cost weight. Defaults to 1.
Examples:
>>> from mmseg.models.assigners import ClassificationCost
>>> import torch
>>> self = ClassificationCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3430, -0.3525, -0.3045],
[-0.3077, -0.2931, -0.3992],
[-0.3664, -0.3455, -0.2881],
[-0.3343, -0.2701, -0.3956]])
"""
def __init__(self, weight: Union[float, int] = 1) -> None:
super().__init__(weight=weight)
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (InstanceData): "scores" inside is
predicted classification logits, of shape
(num_queries, num_class).
gt_instances (InstanceData): "labels" inside should have
shape (num_gt, ).
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
assert hasattr(pred_instances, 'scores'), \
"pred_instances must contain 'scores'"
assert hasattr(gt_instances, 'labels'), \
"gt_instances must contain 'labels'"
pred_scores = pred_instances.scores
gt_labels = gt_instances.labels
pred_scores = pred_scores.softmax(-1)
cls_cost = -pred_scores[:, gt_labels]
return cls_cost * self.weight
@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
"""Cost of mask assignments based on dice losses.
Args:
pred_act (bool): Whether to apply sigmoid to mask_pred.
Defaults to False.
eps (float): Defaults to 1e-3.
naive_dice (bool): If True, use the naive dice loss
in which the power of the number in the denominator is
the first power. If False, use the second power that
is adopted by K-Net and SOLO. Defaults to True.
weight (Union[float, int]): Cost weight. Defaults to 1.
"""
def __init__(self,
pred_act: bool = False,
eps: float = 1e-3,
naive_dice: bool = True,
weight: Union[float, int] = 1.) -> None:
super().__init__(weight=weight)
self.pred_act = pred_act
self.eps = eps
self.naive_dice = naive_dice
def _binary_mask_dice_loss(self, mask_preds: Tensor,
gt_masks: Tensor) -> Tensor:
"""
Args:
mask_preds (Tensor): Mask prediction in shape (num_queries, *).
gt_masks (Tensor): Ground truth in shape (num_gt, *)
store 0 or 1, 0 for negative class and 1 for
positive class.
Returns:
Tensor: Dice cost matrix in shape (num_queries, num_gt).
"""
mask_preds = mask_preds.flatten(1)
gt_masks = gt_masks.flatten(1).float()
numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
if self.naive_dice:
denominator = mask_preds.sum(-1)[:, None] + \
gt_masks.sum(-1)[None, :]
else:
denominator = mask_preds.pow(2).sum(1)[:, None] + \
gt_masks.pow(2).sum(1)[None, :]
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
return loss
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (InstanceData): Predicted instances which
must contain "masks".
gt_instances (InstanceData): Ground truth which must contain
"mask".
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
assert hasattr(pred_instances, 'masks'), \
"pred_instances must contain 'masks'"
assert hasattr(gt_instances, 'masks'), \
"gt_instances must contain 'masks'"
pred_masks = pred_instances.masks
gt_masks = gt_instances.masks
if self.pred_act:
pred_masks = pred_masks.sigmoid()
dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
return dice_cost * self.weight
@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
"""CrossEntropyLossCost.
Args:
use_sigmoid (bool): Whether the prediction uses sigmoid
or softmax. Defaults to True.
weight (Union[float, int]): Cost weight. Defaults to 1.
"""
def __init__(self,
use_sigmoid: bool = True,
weight: Union[float, int] = 1.) -> None:
super().__init__(weight=weight)
self.use_sigmoid = use_sigmoid
def _binary_cross_entropy(self, cls_pred: Tensor,
gt_labels: Tensor) -> Tensor:
"""
Args:
cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
(num_queries, *).
gt_labels (Tensor): The learning label of prediction with
shape (num_gt, *).
Returns:
Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
"""
cls_pred = cls_pred.flatten(1).float()
gt_labels = gt_labels.flatten(1).float()
n = cls_pred.shape[1]
pos = F.binary_cross_entropy_with_logits(
cls_pred, torch.ones_like(cls_pred), reduction='none')
neg = F.binary_cross_entropy_with_logits(
cls_pred, torch.zeros_like(cls_pred), reduction='none')
cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
cls_cost = cls_cost / n
return cls_cost
def __call__(self, pred_instances: InstanceData,
gt_instances: InstanceData, **kwargs) -> Tensor:
"""Compute match cost.
Args:
pred_instances (:obj:`InstanceData`): Predicted instances which
must contain ``masks``.
gt_instances (:obj:`InstanceData`): Ground truth which must contain
``masks``.
Returns:
Tensor: Match Cost matrix of shape (num_preds, num_gts).
"""
assert hasattr(pred_instances, 'masks'), \
"pred_instances must contain 'masks'"
assert hasattr(gt_instances, 'masks'), \
"gt_instances must contain 'masks'"
pred_masks = pred_instances.masks
gt_masks = gt_instances.masks
if self.use_sigmoid:
cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
else:
raise NotImplementedError
return cls_cost * self.weight


@ -132,12 +132,16 @@ class VisionTransformer(BaseModule):
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
patch_pad (str | int | None): The padding method in patch embedding.
Default: 'corner'.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_origin (bool): Whether to output the original input embedding.
Default: False
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qkv_bias (bool): enable bias for qkv if True. Default: True.
@ -154,8 +158,12 @@ class VisionTransformer(BaseModule):
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_bias (bool): Whether to use bias in the convolution of the PatchEmbed
Block. Default: False.
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
pre_norm (bool): Whether to add a norm before Transformer Layers.
Default: False.
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Default: False.
interpolate_mode (str): Select the interpolate mode for position
@ -167,6 +175,8 @@ class VisionTransformer(BaseModule):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
frozen_exclude (List): List of parameters that are not to be frozen.
Default: ["all"], "all" means there are no frozen parameters.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
@ -175,11 +185,13 @@ class VisionTransformer(BaseModule):
def __init__(self,
img_size=224,
patch_size=16,
patch_pad='corner',
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_origin=False,
out_indices=-1,
qkv_bias=True,
drop_rate=0.,
@ -190,11 +202,14 @@ class VisionTransformer(BaseModule):
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
patch_bias=False,
pre_norm=False,
final_norm=False,
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
with_cp=False,
frozen_exclude=['all'],
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
@ -227,6 +242,8 @@ class VisionTransformer(BaseModule):
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
self.out_origin = out_origin
self.frozen_exclude = frozen_exclude
self.patch_embed = PatchEmbed(
in_channels=in_channels,
@ -234,7 +251,8 @@ class VisionTransformer(BaseModule):
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding='corner',
padding=patch_pad,
bias=patch_bias,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None,
)
@ -248,6 +266,12 @@ class VisionTransformer(BaseModule):
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.pre_norm = pre_norm
if self.pre_norm:
self.pre_ln_name, pre_ln = build_norm_layer(
norm_cfg, embed_dims, postfix='_pre')
self.add_module(self.pre_ln_name, pre_ln)
if isinstance(out_indices, int):
if out_indices == -1:
@ -285,20 +309,36 @@ class VisionTransformer(BaseModule):
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self._freeze()
@property
def pre_ln(self):
return getattr(self, self.pre_ln_name)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if self.init_cfg.get('type') == 'Pretrained':
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
elif self.init_cfg.get('type') == 'Pretrained_Part':
state_dict = checkpoint.copy()
para_prefix = 'image_encoder'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
state_dict.pop(k)
if para_prefix in k:
state_dict[k[prefix_len:]] = v
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
@ -334,6 +374,13 @@ class VisionTransformer(BaseModule):
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def _freeze(self):
if 'all' in self.frozen_exclude:
return
for name, param in self.named_parameters():
if not any([exclude in name for exclude in self.frozen_exclude]):
param.requires_grad = False
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
@ -409,7 +456,23 @@ class VisionTransformer(BaseModule):
# Remove class token for transformer encoder input
x = x[:, 1:]
if self.pre_norm:
x = self.pre_ln(x)
outs = []
if self.out_origin:
if self.with_cls_token:
# Remove class token and reshape token for decoder head
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:


@ -25,6 +25,7 @@ from .pid_head import PIDHead
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
from .san_head import SideAdapterCLIPHead
from .segformer_head import SegformerHead
from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
@ -43,5 +44,5 @@ __all__ = [
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead'
'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
]


@ -0,0 +1,733 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.ops import point_sample
from mmengine.dist import all_reduce
from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn import functional as F
from mmseg.models.backbones.vit import TransformerEncoderLayer
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, MatchMasks, SampleList,
seg_data_to_instance_data)
from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
get_uncertain_point_coords_with_randomness, resize)
from .decode_head import BaseDecodeHead
class MLPMaskDecoder(nn.Module):
"""Module for decoding query and visual features with MLP layers to
generate the attention biases and the mask proposals."""
def __init__(
self,
*,
in_channels: int,
total_heads: int = 1,
total_layers: int = 1,
embed_channels: int = 256,
mlp_channels: int = 256,
mlp_num_layers: int = 3,
rescale_attn_bias: bool = False,
):
super().__init__()
self.total_heads = total_heads
self.total_layers = total_layers
dense_affine_func = partial(nn.Conv2d, kernel_size=1)
# Query Branch
self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
mlp_num_layers)
# Pixel Branch
self.pix_mlp = MLP(
in_channels,
mlp_channels,
embed_channels,
mlp_num_layers,
affine_func=dense_affine_func,
)
# Attention Bias Branch
self.attn_mlp = MLP(
in_channels,
mlp_channels,
embed_channels * self.total_heads * self.total_layers,
mlp_num_layers,
affine_func=dense_affine_func,
)
if rescale_attn_bias:
self.bias_scaling = nn.Linear(1, 1)
else:
self.bias_scaling = nn.Identity()
def forward(self, query: torch.Tensor,
x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward function.
Args:
query (Tensor): Query Tokens [B,N,C].
x (Tensor): Visual features [B,C,H,W]
Return:
mask_preds (Tensor): Mask proposals.
attn_bias (List[Tensor]): List of attention bias.
"""
query = self.query_mlp(query)
pix = self.pix_mlp(x)
b, c, h, w = pix.shape
# predict mask
mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
# generate attn bias
attn = self.attn_mlp(x)
attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
attn_bias = attn_bias.chunk(self.total_layers, dim=1)
attn_bias = [attn.squeeze(1) for attn in attn_bias]
return mask_preds, attn_bias
class SideAdapterNetwork(nn.Module):
"""Side Adapter Network for predicting mask proposals and attention bias.
Args:
in_channels (int): Number of input channels. Default: 3.
clip_channels (int): Number of channels of visual features.
Default: 768.
embed_dims (int): embedding dimension. Default: 240.
patch_size (int): The patch size. Default: 16.
patch_bias (bool): Whether to use bias in patch embedding.
Default: True.
num_queries (int): Number of queries for mask proposals.
Default: 100.
fusion_index (List[int]): The layer number of the encode
transformer to fuse with the CLIP feature.
Default: [0, 1, 2, 3].
cfg_encoder (ConfigType): Configs for the encode layers.
cfg_decoder (ConfigType): Configs for the decode layers.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
"""
def __init__(
self,
in_channels: int = 3,
clip_channels: int = 768,
embed_dims: int = 240,
patch_size: int = 16,
patch_bias: bool = True,
num_queries: int = 100,
fusion_index: list = [0, 1, 2, 3],
cfg_encoder: ConfigType = ...,
cfg_decoder: ConfigType = ...,
norm_cfg: dict = dict(type='LN'),
):
super().__init__()
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding=0,
input_size=(640, 640),
bias=patch_bias,
norm_cfg=None,
init_cfg=None,
)
ori_h, ori_w = self.patch_embed.init_out_size
num_patches = ori_h * ori_w
self.pos_embed = nn.Parameter(
torch.randn(1, num_patches, embed_dims) * .02)
self.query_pos_embed = nn.Parameter(
torch.zeros(1, num_queries, embed_dims))
self.query_embed = nn.Parameter(
torch.zeros(1, num_queries, embed_dims))
encode_layers = []
for i in range(cfg_encoder.num_encode_layer):
encode_layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=cfg_encoder.num_heads,
feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
norm_cfg=norm_cfg))
self.encode_layers = nn.ModuleList(encode_layers)
conv_clips = []
for i in range(len(fusion_index)):
conv_clips.append(
nn.Sequential(
LayerNorm2d(clip_channels),
ConvModule(
clip_channels,
embed_dims,
kernel_size=1,
norm_cfg=None,
act_cfg=None)))
self.conv_clips = nn.ModuleList(conv_clips)
self.fusion_index = fusion_index
self.mask_decoder = MLPMaskDecoder(
in_channels=embed_dims,
total_heads=cfg_decoder.num_heads,
total_layers=cfg_decoder.num_layers,
embed_channels=cfg_decoder.embed_channels,
mlp_channels=cfg_decoder.mlp_channels,
mlp_num_layers=cfg_decoder.num_mlp,
rescale_attn_bias=cfg_decoder.rescale)
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.query_embed, std=0.02)
nn.init.normal_(self.query_pos_embed, std=0.02)
for i in range(len(self.conv_clips)):
caffe2_xavier_init(self.conv_clips[i][1].conv)
def fuse_clip(self, fused_index: int, x: torch.Tensor,
clip_feature: torch.Tensor, hwshape: Tuple[int,
int], L: int):
"""Fuse CLIP feature and visual tokens."""
fused_clip = (resize(
self.conv_clips[fused_index](clip_feature.contiguous()),
size=hwshape,
mode='bilinear',
align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:,
...].shape)
x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
return x
def encode_feature(self, image: torch.Tensor,
clip_features: List[torch.Tensor],
deep_supervision_idxs: List[int]) -> List[List]:
"""Encode images by a lightweight vision transformer."""
assert len(self.fusion_index) == len(clip_features)
x, hwshape = self.patch_embed(image)
ori_h, ori_w = self.patch_embed.init_out_size
pos_embed = self.pos_embed
if self.pos_embed.shape[1] != x.shape[1]:
# resize the position embedding
pos_embed = (
resize(
self.pos_embed.reshape(1, ori_h, ori_w,
-1).permute(0, 3, 1, 2),
size=hwshape,
mode='bicubic',
align_corners=False,
).flatten(2).permute(0, 2, 1))
pos_embed = torch.cat([
self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
],
dim=1)
x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
x = x + pos_embed
L = hwshape[0] * hwshape[1]
fused_index = 0
if self.fusion_index[fused_index] == 0:
x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
fused_index += 1
outs = []
for index, block in enumerate(self.encode_layers, start=1):
x = block(x)
if index < len(self.fusion_index
) and index == self.fusion_index[fused_index]:
x = self.fuse_clip(fused_index, x,
clip_features[fused_index][0], hwshape, L)
fused_index += 1
x_query = x[:, :-L, ...]
x_feat = x[:, -L:, ...].permute(0, 2, 1)\
.reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])
if index in deep_supervision_idxs or index == len(
self.encode_layers):
outs.append({'query': x_query, 'x': x_feat})
if index < len(self.encode_layers):
x = x + pos_embed
return outs
def decode_feature(self, features):
mask_embeds = []
attn_biases = []
for feature in features:
mask_embed, attn_bias = self.mask_decoder(**feature)
mask_embeds.append(mask_embed)
attn_biases.append(attn_bias)
return mask_embeds, attn_biases
def forward(
self, image: torch.Tensor, clip_features: List[torch.Tensor],
deep_supervision_idxs: List[int]
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
"""Forward function."""
features = self.encode_feature(image, clip_features,
deep_supervision_idxs)
mask_embeds, attn_biases = self.decode_feature(features)
return mask_embeds, attn_biases
class RecWithAttnbias(nn.Module):
"""Mask recognition module by applying the attention biases to rest deeper
CLIP layers.
Args:
sos_token_format (str): The format of sos token. It should be
chosen from ["cls_token", "learnable_token", "pos_embedding"].
Default: 'cls_token'.
sos_token_num (int): Number of sos tokens. It should be equal to
the number of queries. Default: 100.
num_layers (int): Number of remaining CLIP layers used for mask recognition.
Default: 3.
cross_attn (bool): Whether to use cross attention to update the sos token.
Default: False.
embed_dims (int): The feature dimension of CLIP layers.
Default: 768.
num_heads (int): Parallel attention heads of CLIP layers.
Default: 12.
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4.
qkv_bias (bool): Whether to use bias in multihead-attention.
Default: True.
out_dims (int): Number of channels of the output mask proposals.
It should be equal to the out_dims of text_encoder.
Default: 512.
final_norm (bool): Whether to use a norm layer for the sos token. Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
frozen_exclude (List): List of parameters that are not to be frozen.
"""
def __init__(self,
sos_token_format: str = 'cls_token',
sos_token_num: int = 100,
num_layers: int = 3,
cross_attn: bool = False,
embed_dims: int = 768,
num_heads: int = 12,
mlp_ratio: int = 4,
num_fcs: int = 2,
qkv_bias: bool = True,
out_dims: int = 512,
final_norm: bool = True,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
frozen_exclude: List = []):
super().__init__()
assert sos_token_format in [
'cls_token', 'learnable_token', 'pos_embedding'
]
self.sos_token_format = sos_token_format
self.sos_token_num = sos_token_num
self.frozen_exclude = frozen_exclude
self.cross_attn = cross_attn
self.num_layers = num_layers
self.num_heads = num_heads
if sos_token_format in ['learnable_token', 'pos_embedding']:
self.sos_token = nn.Parameter(
torch.randn(sos_token_num, 1, self.proj.shape[0]))
self.frozen.append('sos_token')
layers = []
for i in range(num_layers):
layers.append(
BaseTransformerLayer(
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=embed_dims,
num_heads=num_heads,
batch_first=False,
bias=qkv_bias),
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
feedforward_channels=mlp_ratio * embed_dims,
act_cfg=act_cfg),
operation_order=('norm', 'self_attn', 'norm', 'ffn')))
self.layers = nn.ModuleList(layers)
self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
self.proj = nn.Linear(embed_dims, out_dims, bias=False)
self.final_norm = final_norm
self._freeze()
def init_weights(self, rec_state_dict):
if hasattr(self, 'sos_token'):
normal_init(self.sos_token, std=0.02)
if rec_state_dict is not None:
load_state_dict(self, rec_state_dict, strict=False, logger=None)
else:
super().init_weights()
def _freeze(self):
if 'all' in self.frozen_exclude:
return
for name, param in self.named_parameters():
if not any([exclude in name for exclude in self.frozen_exclude]):
param.requires_grad = False
def _build_attn_biases(self, attn_biases, target_shape):
formatted_attn_biases = []
for attn_bias in attn_biases:
# convert it to proper format: N*num_head,L,L
# attn_bias: [N, num_head/1, num_sos,H,W]
n, num_head, num_sos, h, w = attn_bias.shape
# reshape and downsample
attn_bias = F.adaptive_max_pool2d(
attn_bias.reshape(n, num_head * num_sos, h, w),
output_size=target_shape)
attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)
true_num_head = self.num_heads
assert (num_head == 1 or num_head
== true_num_head), f'num_head={num_head} is not supported.'
if num_head == 1:
attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
L = attn_bias.shape[-1]
if self.cross_attn:
# [n*num_head, num_sos, L]
formatted_attn_biases.append(attn_bias)
else:
# [n*num_head, num_sos+1+L, num_sos+1+L]
new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
num_sos + 1 + L)
new_attn_bias[:, :num_sos] = -100
new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
new_attn_bias[:num_sos, num_sos] = -100
new_attn_bias = (
new_attn_bias[None, ...].expand(n * true_num_head, -1,
-1).clone())
new_attn_bias[..., :num_sos, -L:] = attn_bias
formatted_attn_biases.append(new_attn_bias)
if len(formatted_attn_biases) == 1:
formatted_attn_biases = [
formatted_attn_biases[0] for _ in range(self.num_layers)
]
return formatted_attn_biases
def forward(self, bias: List[Tensor], feature: List[Tensor]):
"""Forward function to recognize the category of masks
Args:
bias (List[Tensor]): Attention bias for transformer layers
feature (List[Tensor]): Output of the image encoder,
including cls_token and img_feature.
"""
cls_token = feature[1].unsqueeze(0)
img_feature = feature[0]
b, c, h, w = img_feature.shape
# construct clip shadow features
x = torch.cat(
[cls_token,
img_feature.reshape(b, c, -1).permute(2, 0, 1)])
# construct sos token
if self.sos_token_format == 'cls_token':
sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
elif self.sos_token_format == 'learnable_token':
sos_token = self.sos_token.expand(-1, b, -1)
elif self.sos_token_format == 'pos_embedding':
sos_token = self.sos_token.expand(-1, b, -1) + cls_token
# construct attn bias
attn_biases = self._build_attn_biases(bias, target_shape=(h, w))
if self.cross_attn:
for i, block in enumerate(self.layers):
if self.cross_attn:
sos_token = cross_attn_layer(
block,
sos_token,
x[1:, ],
attn_biases[i],
)
if i < len(self.layers) - 1:
x = block(x)
else:
x = torch.cat([sos_token, x], dim=0)
for i, block in enumerate(self.layers):
x = block(x, attn_masks=[attn_biases[i]])
sos_token = x[:self.sos_token_num]
sos_token = sos_token.permute(1, 0, 2) # LND -> NLD
sos_token = self.ln_post(sos_token)
sos_token = self.proj(sos_token)
if self.final_norm:
sos_token = F.normalize(sos_token, dim=-1)
return sos_token
@MODELS.register_module()
class SideAdapterCLIPHead(BaseDecodeHead):
"""Side Adapter Network (SAN) for open-vocabulary semantic segmentation
with pre-trained vision-language model.
This decode head is the implementation of `Side Adapter Network
for Open-Vocabulary Semantic Segmentation`
<https://arxiv.org/abs/2302.12242>.
Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501
Copyright (c) 2023 MendelXu.
Licensed under the MIT License
Args:
num_classes (int): the number of classes.
san_cfg (ConfigType): Configs for SideAdapterNetwork module
maskgen_cfg (ConfigType): Configs for RecWithAttnbias module
"""
def __init__(self, num_classes: int, san_cfg: ConfigType,
maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
train_cfg: ConfigType, **kwargs):
super().__init__(
in_channels=san_cfg.in_channels,
channels=san_cfg.embed_dims,
num_classes=num_classes,
**kwargs)
assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
'num_queries in san_cfg should be equal to sos_token_num ' \
'in maskgen_cfg'
del self.conv_seg
self.side_adapter_network = SideAdapterNetwork(**san_cfg)
self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
self.deep_supervision_idxs = deep_supervision_idxs
self.train_cfg = train_cfg
if train_cfg:
self.match_masks = MatchMasks(
num_points=train_cfg.num_points,
num_queries=san_cfg.num_queries,
num_classes=num_classes,
assigner=train_cfg.assigner)
def init_weights(self):
rec_state_dict = None
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') == 'Pretrained_Part':
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
rec_state_dict = checkpoint.copy()
para_prefix = 'decode_head.rec_with_attnbias'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
rec_state_dict.pop(k)
if para_prefix in k:
rec_state_dict[k[prefix_len:]] = v
self.side_adapter_network.init_weights()
self.rec_with_attnbias.init_weights(rec_state_dict)
def forward(self, inputs: Tuple[Tensor],
deep_supervision_idxs) -> Tuple[List]:
"""Forward function.
Args:
inputs (Tuple[Tensor]): A triplet including images,
list of multi-level visual features from image encoder and
class embeddings from text_encoder.
Returns:
mask_props (List[Tensor]): Mask proposals predicted by SAN.
mask_logits (List[Tensor]): Class logits of mask proposals.
"""
imgs, clip_feature, class_embeds = inputs
# predict mask proposals and attention bias
mask_props, attn_biases = self.side_adapter_network(
imgs, clip_feature, deep_supervision_idxs)
# mask recognition with attention bias
mask_embeds = [
self.rec_with_attnbias(att_bias, clip_feature[-1])
for att_bias in attn_biases
]
# Obtain class prediction of masks by comparing the similarity
# between the image token and the text embedding of class names.
mask_logits = [
torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
for mask_embed in mask_embeds
]
return mask_props, mask_logits
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for prediction.
Args:
inputs (Tuple[Tensor]): Images, visual features from image encoder
and class embedding from text encoder.
batch_img_metas (dict): List Image info where each dict may also
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
test_cfg (dict): The testing config.
Returns:
Tensor: Outputs segmentation logits map.
"""
mask_props, mask_logits = self.forward(inputs, [])
return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
batch_img_metas)
def predict_by_feat(self, seg_logits: List[Tensor],
batch_img_metas: List[dict]) -> Tensor:
"""1. Transform a batch of mask proposals to the input shape.
2. Generate segmentation map with mask proposals and class logits.
"""
mask_pred = seg_logits[0]
cls_score = seg_logits[1]
if 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape']
else:
size = batch_img_metas[0]['img_shape']
# upsample mask
mask_pred = F.interpolate(
mask_pred, size=size, mode='bilinear', align_corners=False)
mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
return seg_logits
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType) -> dict:
"""Perform forward propagation and loss calculation of the decoder head
on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the upstream
network, each is a 4D-tensor.
batch_data_samples (List[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
train_cfg (ConfigType): Training config.
Returns:
dict[str, Tensor]: a dictionary of loss components.
"""
# batch SegDataSample to InstanceDataSample
batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
batch_data_samples)
# forward
all_mask_props, all_mask_logits = self.forward(
x, self.deep_supervision_idxs)
# loss
losses = self.loss_by_feat(all_mask_logits, all_mask_props,
batch_gt_instances)
return losses
def loss_by_feat(
self, all_cls_scores: Tensor, all_mask_preds: Tensor,
batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
"""Loss function.
Args:
all_cls_scores (Tensor): Classification scores for all decoder
layers with shape (num_decoder, batch_size, num_queries,
cls_out_channels). Note `cls_out_channels` should include
background.
all_mask_preds (Tensor): Mask scores for all decoder layers with
shape (num_decoder, batch_size, num_queries, h, w).
batch_gt_instances (list[obj:`InstanceData`]): each contains
``labels`` and ``masks``.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_dec_layers = len(all_cls_scores)
batch_gt_instances_list = [
batch_gt_instances for _ in range(num_dec_layers)
]
losses = []
for i in range(num_dec_layers):
cls_scores = all_cls_scores[i]
mask_preds = all_mask_preds[i]
# matching N mask predictions to K category labels
(labels, mask_targets, mask_weights,
avg_factor) = self.match_masks.get_targets(
cls_scores, mask_preds, batch_gt_instances_list[i])
cls_scores = cls_scores.flatten(0, 1)
labels = labels.flatten(0, 1)
num_total_masks = cls_scores.new_tensor([avg_factor],
dtype=torch.float)
all_reduce(num_total_masks, op='mean')
num_total_masks = max(num_total_masks, 1)
# extract positive ones
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
mask_preds = mask_preds[mask_weights > 0]
if mask_targets.shape[0] != 0:
with torch.no_grad():
points_coords = get_uncertain_point_coords_with_randomness(
mask_preds.unsqueeze(1), None,
self.train_cfg.num_points,
self.train_cfg.oversample_ratio,
self.train_cfg.importance_sample_ratio)
# shape (num_total_gts, h, w)
# -> (num_total_gts, num_points)
mask_point_targets = point_sample(
mask_targets.unsqueeze(1).float(),
points_coords).squeeze(1)
# shape (num_queries, h, w) -> (num_queries, num_points)
mask_point_preds = point_sample(
mask_preds.unsqueeze(1), points_coords).squeeze(1)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
loss = dict()
for loss_decode in losses_decode:
if 'loss_cls' in loss_decode.loss_name:
if loss_decode.loss_name == 'loss_cls_ce':
loss[loss_decode.loss_name] = loss_decode(
cls_scores, labels)
else:
assert False, "Only support 'CrossEntropyLoss' in" \
' classification loss'
elif 'loss_mask' in loss_decode.loss_name:
if mask_targets.shape[0] == 0:
loss[loss_decode.loss_name] = mask_preds.sum()
elif loss_decode.loss_name == 'loss_mask_ce':
loss[loss_decode.loss_name] = loss_decode(
mask_point_preds,
mask_point_targets,
avg_factor=num_total_masks *
self.train_cfg.num_points)
elif loss_decode.loss_name == 'loss_mask_dice':
loss[loss_decode.loss_name] = loss_decode(
mask_point_preds,
mask_point_targets,
avg_factor=num_total_masks)
else:
assert False, "Only support 'CrossEntropyLoss' and" \
" 'DiceLoss' in mask loss"
else:
assert False, "Only support for 'loss_cls' and 'loss_mask'"
losses.append(loss)
loss_dict = dict()
# loss from the last decoder layer
loss_dict.update(losses[-1])
# loss from other decoder layers
for i, loss in enumerate(losses[:-1]):
for k, v in loss.items():
loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
return loss_dict
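For reference, the dictionary returned above keys the last decoder layer's terms directly and prefixes terms from the deep-supervision layers with `d{idx}.`. A minimal sketch of that naming pattern (the layer index and loss values are illustrative, not taken from a real run):

```python
# Illustrative only: mimics the key layout produced by loss_by_feat above.
deep_supervision_idxs = [7]          # hypothetical intermediate layer index
per_layer_losses = [
    {'loss_cls_ce': 0.9, 'loss_mask_ce': 1.2, 'loss_mask_dice': 0.8},  # layer 7
    {'loss_cls_ce': 0.7, 'loss_mask_ce': 1.0, 'loss_mask_dice': 0.6},  # last layer
]

loss_dict = dict(per_layer_losses[-1])
for i, loss in enumerate(per_layer_losses[:-1]):
    for k, v in loss.items():
        loss_dict[f'd{deep_supervision_idxs[i]}.{k}'] = v

# -> {'loss_cls_ce': 0.7, 'loss_mask_ce': 1.0, 'loss_mask_dice': 0.6,
#     'd7.loss_cls_ce': 0.9, 'd7.loss_mask_ce': 1.2, 'd7.loss_mask_dice': 0.8}
```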

View File

@ -53,8 +53,22 @@ def cross_entropy(pred,
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and reduction == 'mean':
if class_weight is None:
if avg_non_ignore:
avg_factor = label.numel() - (label
== ignore_index).sum().item()
else:
avg_factor = label.numel()
else:
# the average factor should take the class weights into account
label_weights = torch.tensor([class_weight[cls] for cls in label],
device=class_weight.device)
if avg_non_ignore:
label_weights[label == ignore_index] = 0
avg_factor = label_weights.sum()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(

View File

@ -66,10 +66,11 @@ def dice_loss(pred: torch.Tensor,
ignore_index (int, optional): The label index to be ignored.
Defaults to 255.
"""
if ignore_index is not None:
num_classes = pred.shape[1]
pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
target = target[:, torch.arange(num_classes) != ignore_index, :, :]
assert pred.shape[1] != 0 # if the ignored index is the only class
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)

View File

@ -3,9 +3,10 @@ from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder
from .depth_estimator import DepthEstimator
from .encoder_decoder import EncoderDecoder
from .multimodal_encoder_decoder import MultimodalEncoderDecoder
from .seg_tta import SegTTAModel
__all__ = [
'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel',
'MultimodalEncoderDecoder', 'DepthEstimator'
]

View File

@ -0,0 +1,350 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch.nn.functional as F
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor
@MODELS.register_module()
class MultimodalEncoderDecoder(BaseSegmentor):
"""Multimodal Encoder-Decoder segmentors.
Multimodal segmentation architecture is used for open-vocabulary
semantic segmentation by combining visual and language pretrained
models. It consists of an image_encoder (backbone) to extract visual
features, a text_encoder to extract text features, and a decode head
to generate semantic maps.
Note that the deep supervision during training is implemented in decode head.
1. The ``loss`` method is used to calculate the loss of model,
which includes two steps: (1) Extracts features to obtain the feature maps
(2) Call the decode head loss function to forward decode head model and
calculate losses.
.. code:: text
loss(): extract_feat() -> _decode_head_forward_train()
_decode_head_forward_train(): decode_head.loss()
2. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) Run inference function to obtain the list of
seg_logits (2) Call post-processing function to obtain list of
``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
.. code:: text
predict(): inference() -> postprocess_result()
inference(): whole_inference()/slide_inference()
whole_inference()/slide_inference(): encode_decode()
encode_decode(): extract_feat() -> decode_head.predict()
3. The ``_forward`` method is used to output the tensor by running the model,
which includes two steps: (1) Extracts features to obtain the feature maps
(2) Call the decode head forward function to forward the decode head model.
.. code:: text
_forward(): extract_feat() -> _decode_head.forward()
Args:
image_encoder (ConfigType): The config for the visual encoder of segmentor.
text_encoder (ConfigType): The config for the text encoder of segmentor.
decode_head (ConfigType): The config for the decode head of segmentor.
train_cfg (OptConfigType): The config for training. Defaults to None.
test_cfg (OptConfigType): The config for testing. Defaults to None.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
pretrained (str, optional): The path for pretrained model.
Defaults to None.
asymetric_input (bool): Whether to use inputs of different sizes for the
image encoder and the decode head. Defaults to True.
encoder_resolution (float): Resize scale of input images for the image encoder.
Defaults to None.
init_cfg (dict, optional): The weight initialized config for
:class:`BaseModule`.
""" # noqa: E501
def __init__(self,
image_encoder: ConfigType,
text_encoder: ConfigType,
decode_head: ConfigType,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
pretrained: Optional[str] = None,
asymetric_input: bool = True,
encoder_resolution: float = None,
init_cfg: OptMultiConfig = None):
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
if pretrained is not None:
image_encoder.init_cfg = dict(
type='Pretrained_Part', checkpoint=pretrained)
text_encoder.init_cfg = dict(
type='Pretrained_Part', checkpoint=pretrained)
decode_head.init_cfg = dict(
type='Pretrained_Part', checkpoint=pretrained)
if asymetric_input:
assert encoder_resolution is not None, \
'if asymetric_input is set to True, ' \
'encoder_resolution must not be None'
self.asymetric_input = asymetric_input
self.encoder_resolution = encoder_resolution
self.image_encoder = MODELS.build(image_encoder)
self.text_encoder = MODELS.build(text_encoder)
self._init_decode_head(decode_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
assert self.with_decode_head
def _init_decode_head(self, decode_head: ConfigType) -> None:
"""Initialize ``decode_head``"""
self.decode_head = MODELS.build(decode_head)
self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes
self.out_channels = self.decode_head.out_channels
def extract_feat(self, inputs: Tensor) -> List[Tensor]:
"""Extract visual features from images."""
x = self.image_encoder(inputs)
return x
def encode_decode(self, inputs: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Encode the name of classes with text_encoder and encode images with
image_encoder.
Then decode the class embedding and visual feature into a semantic
segmentation map of the same size as input.
"""
classifier_embeds = self.text_encoder()
clip_inputs = inputs
if self.asymetric_input:
clip_inputs = F.interpolate(
inputs, scale_factor=self.encoder_resolution, mode='bilinear')
x = self.image_encoder(clip_inputs)
seg_logits = self.decode_head.predict([inputs, x, classifier_embeds],
batch_img_metas, self.test_cfg)
return seg_logits
def _decode_head_forward_train(self, inputs: List[Tensor],
data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.loss(inputs, data_samples,
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode'))
return losses
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (Tensor): Input images.
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
It usually includes information such as `metainfo` and
`gt_sem_seg`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
classifier_embeds = self.text_encoder()
clip_inputs = inputs
if self.asymetric_input:
clip_inputs = F.interpolate(
inputs, scale_factor=self.encoder_resolution, mode='bilinear')
x = self.image_encoder(clip_inputs)
losses = dict()
loss_decode = self._decode_head_forward_train(
[inputs, x, classifier_embeds], data_samples)
losses.update(loss_decode)
return losses
def predict(self,
inputs: Tensor,
data_samples: OptSampleList = None) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`SegDataSample`], optional): The seg data
samples. It usually includes information such as `metainfo`
and `gt_sem_seg`.
Returns:
list[:obj:`SegDataSample`]: Segmentation results of the
input images. Each SegDataSample usually contain:
- ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
- ``seg_logits``(PixelData): Predicted logits of semantic
segmentation before normalization.
"""
if data_samples is not None:
batch_img_metas = [
data_sample.metainfo for data_sample in data_samples
]
else:
batch_img_metas = [
dict(
ori_shape=inputs.shape[2:],
img_shape=inputs.shape[2:],
pad_shape=inputs.shape[2:],
padding_size=[0, 0, 0, 0])
] * inputs.shape[0]
seg_logits = self.inference(inputs, batch_img_metas)
return self.postprocess_result(seg_logits, data_samples)
def _forward(self,
inputs: Tensor,
data_samples: OptSampleList = None) -> Tensor:
"""Network forward process.
Args:
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
Returns:
Tensor: Forward output of model without any post-processes.
"""
x = self.extract_feat(inputs)
return self.decode_head.forward(x)
def slide_inference(self, inputs: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
Args:
inputs (tensor): the tensor should have a shape NxCxHxW,
which contains all images in the batch.
batch_img_metas (List[dict]): List of image metainfo where each may
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
Returns:
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = inputs.size()
out_channels = self.out_channels
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
x1 = w_idx * w_stride
y2 = min(y1 + h_crop, h_img)
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = inputs[:, :, y1:y2, x1:x2]
# change the image shape to patch shape
batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
# the output of encode_decode is seg logits tensor map
# with shape [N, C, H, W]
crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
preds += F.pad(crop_seg_logit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
seg_logits = preds / count_mat
return seg_logits
def whole_inference(self, inputs: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Inference with full image.
Args:
inputs (Tensor): The tensor should have a shape NxCxHxW, which
contains all images in the batch.
batch_img_metas (List[dict]): List of image metainfo where each may
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
Returns:
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
seg_logits = self.encode_decode(inputs, batch_img_metas)
return seg_logits
def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
"""Inference with slide/whole style.
Args:
inputs (Tensor): The input image of shape (N, 3, H, W).
batch_img_metas (List[dict]): List of image metainfo where each may
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', 'pad_shape', and 'padding_size'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
Returns:
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
assert self.test_cfg.mode in ['slide', 'whole']
ori_shape = batch_img_metas[0]['ori_shape']
assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
if self.test_cfg.mode == 'slide':
seg_logit = self.slide_inference(inputs, batch_img_metas)
else:
seg_logit = self.whole_inference(inputs, batch_img_metas)
return seg_logit
def aug_test(self, inputs, batch_img_metas, rescale=True):
"""Test with augmentations.
Only rescale=True is supported.
"""
# aug_test rescale all imgs back to ori_shape for now
assert rescale
# to save memory, we get augmented seg logit inplace
seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
for i in range(1, len(inputs)):
cur_seg_logit = self.inference(inputs[i], batch_img_metas[i],
rescale)
seg_logit += cur_seg_logit
seg_logit /= len(inputs)
seg_pred = seg_logit.argmax(dim=1)
# unravel batch dim
seg_pred = list(seg_pred)
return seg_pred
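For orientation, a trimmed sketch of how `MultimodalEncoderDecoder` might be wired up in a config. The component configs, dataset name and resolution value below are placeholders, not the shipped SAN settings (see `configs/san/` for those):

```python
# Hypothetical, abbreviated config sketch -- component settings are placeholders.
model = dict(
    type='MultimodalEncoderDecoder',
    pretrained='<CONVERTED_CHECKPOINT_PATH>',   # loaded into all three components
    asymetric_input=True,       # decode head sees the full-resolution input
    encoder_resolution=0.5,     # CLIP image encoder sees a downscaled copy
    image_encoder=dict(type='VisionTransformer'),     # placeholder backbone config
    text_encoder=dict(
        type='CLIPTextEncoder',
        dataset_name='voc',     # class names resolved through the dataset alias
        templates='vild'),
    decode_head=dict(type='SideAdapterCLIPHead'),     # placeholder head config
    test_cfg=dict(mode='whole'))
```

With `asymetric_input=True`, `loss()` and `encode_decode()` above interpolate the input by `encoder_resolution` before the image encoder, while the decode head still receives the original tensor.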

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .clip_text_encoder import CLIPTextEncoder
__all__ = ['CLIPTextEncoder']

View File

@ -0,0 +1,229 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F
from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer
@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
"""A text encoder with transformer architecture to encode the label text.
Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
Copyright (c) 2023 MendelXu.
Licensed under the MIT License
Args:
dataset_name: (str|None): The name of the dataset to which
the data belongs.
vocabulary: (List[str]|None): The list of class names. Default: None.
templates: (List[str]|None): The prompt template used for labels.
Default: None.
total_vocab_size: (int): Number of all words used by the pre-trained
model. Default: 49408 (CLIP).
context_length: (int): The max length of prompt text.
Default: 77 (CLIP).
embed_dims: (int): Width of transformer model. Default: 512.
num_layers: (int): Depth of transformer. Default: 12,
num_heads: (int): Number of attention heads in transformer.
Default: 8,
mlp_ratio: (int) Ratio of mlp hidden dim to embedding dim in
transformer. Default: 4,
output_dims: (int) Dim of output text embeddings. Default: 512,
cache_feature: (bool) Whether to save class embeddings in cache.
Default: True,
cat_bg: (bool) Whether to add background embedding. Default: True.
norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN')
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
dataset_name: str = None,
vocabulary: List[str] = None,
templates: str = 'vild',
total_vocab_size: int = 49408,
context_length: int = 77,
embed_dims: int = 512,
num_layers: int = 12,
num_heads: int = 8,
mlp_ratio: int = 4,
output_dims: int = 512,
cache_feature: bool = True,
cat_bg: bool = True,
norm_cfg: dict = dict(type='LN'),
init_cfg: dict = None):
super().__init__(init_cfg)
if isinstance(templates, List):
self.templates = templates
else:
self.templates = get_predefined_templates(templates)
assert dataset_name is not None or vocabulary is not None, \
"text_encoder required either 'dataset_name' or 'vocabulary'"
assert dataset_name is None or vocabulary is None, \
"there is conflict between 'dataset_name' and 'vocabulary'"
self.dataset_name = dataset_name
self.vocabulary = vocabulary
self.num_pos = context_length
self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
self.positional_embedding = nn.Parameter(
torch.empty(context_length, embed_dims))
self.text_projection = nn.Parameter(
torch.empty(embed_dims, output_dims))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.transformer = ModuleList()
self.register_buffer(
'attn_mask', self.build_attention_mask(), persistent=False)
for i in range(num_layers):
self.transformer.append(
BaseTransformerLayer(
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=embed_dims,
num_heads=num_heads,
batch_first=False,
bias=True),
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
feedforward_channels=mlp_ratio * embed_dims,
act_cfg=dict(type='QuickGELU')),
operation_order=('norm', 'self_attn', 'norm', 'ffn')))
self.ln_final = build_norm_layer(
norm_cfg, embed_dims, postfix='_final')[1]
self.cache_feature = cache_feature
if self.cache_feature:
self.cache = {}
self._freeze()
self.cat_bg = cat_bg
if self.cat_bg:
self.bg_embed = nn.Parameter(
torch.randn(1, self.text_projection.shape[1]))
@property
def ln_final(self):
return getattr(self, self.final_name)
def build_attention_mask(self):
"""lazily create causal attention mask, with full attention between the
tokens.
pytorch uses additive attention mask; fill with -inf
"""
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask
def _freeze(self):
for param in self.parameters():
param.requires_grad = False
def init_weights(self):
if self.cat_bg:
nn.init.normal_(
self.bg_embed,
std=self.bg_embed.shape[1]**-0.5,
)
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') == 'Pretrained_Part':
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
state_dict = checkpoint.copy()
para_prefix = 'text_encoder'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
state_dict.pop(k)
if para_prefix in k:
state_dict[k[prefix_len:]] = v
load_state_dict(self, state_dict, strict=False, logger=None)
else:
super().init_weights()
@torch.no_grad()
def encode_text(self, text, normalize=False):
"""encode class token."""
embed_device = self.token_embedding.weight.device
x = self.token_embedding(
text.to(embed_device)) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
for block in self.transformer:
x = block(query=x, attn_masks=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding
# (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def template_encode(self, vocabulary):
"""Prompt engineering."""
text_embed_bucket = []
for template in self.templates:
text_inputs = tokenizer.tokenize(
[template.format(noun) for noun in vocabulary])
text_embed = self.encode_text(text_inputs, normalize=True)
text_embed_bucket.append(text_embed)
text_embed = torch.stack(text_embed_bucket).mean(dim=0)
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
return text_embed
def forward(self):
"""Forward function."""
if self.dataset_name is None: # encoding vocabulary directly
class_names = self.vocabulary
if self.cache_feature:
new_classes = [
word for word in class_names if word not in self.cache
]
if len(new_classes) > 0:
class_embeds = self.template_encode(new_classes)
self.cache.update(dict(zip(new_classes, class_embeds)))
class_embeds = torch.stack(
[self.cache[word] for word in class_names])
else:
class_embeds = self.template_encode(class_names)
else: # encoding the classes of the dataset
class_names = get_classes(self.dataset_name)
if class_names[0] == 'background':
class_names = class_names[1:]
if self.cache_feature:
if self.dataset_name not in self.cache:
class_embeds = self.template_encode(class_names)
self.cache[self.dataset_name] = class_embeds
else:
class_embeds = self.cache[self.dataset_name]
else:
class_embeds = self.template_encode(class_names)
if self.cat_bg:
class_embeds = torch.cat([class_embeds, self.bg_embed])
class_embeds = F.normalize(class_embeds, p=2, dim=-1)
return self.logit_scale.exp() * class_embeds
@MODELS.register_module()
class QuickGELU(nn.Module):
# From https://github.com/openai/CLIP/blob/main/clip/model.py
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
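A small usage sketch of `CLIPTextEncoder` on a custom vocabulary; the class names are arbitrary examples, and converted CLIP weights would need to be loaded via `init_cfg` for the embeddings to be meaningful:

```python
import torch

from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()

# Hypothetical vocabulary; without pretrained weights the embeddings are random.
text_encoder = MODELS.build(
    dict(
        type='CLIPTextEncoder',
        vocabulary=['person', 'car', 'tree'],
        templates='vild',
        cat_bg=True))
text_encoder.init_weights()

with torch.no_grad():
    class_embeds = text_encoder()   # prompt-averaged, L2-normalized embeddings

# One row per class plus the learned background embedding: (3 + 1, 512)
print(class_embeds.shape)
```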

View File

@ -4,6 +4,7 @@ from .embed import PatchEmbed
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .point_sample import get_uncertain_point_coords_with_randomness
from .ppm import DAPPM, PAPPM
from .res_layer import ResLayer
from .se_layer import SELayer
@ -11,11 +12,16 @@ from .self_attention_block import SelfAttentionBlock
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
nlc_to_nchw)
from .up_conv_block import UpConvBlock
# isort: off
from .wrappers import Upsample, resize
from .san_layers import MLP, LayerNorm2d, cross_attn_layer
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck',
'cross_attn_layer', 'LayerNorm2d', 'MLP',
'get_uncertain_point_coords_with_randomness'
]

View File

@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import point_sample
from torch import Tensor
def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor:
"""Estimate uncertainty based on pred logits.
We estimate uncertainty as L1 distance between 0.0 and the logits
prediction in 'mask_preds' for the foreground class in `classes`.
Args:
mask_preds (Tensor): mask prediction logits, shape (num_rois,
num_classes, mask_height, mask_width).
labels (Tensor): Either predicted or ground truth label for
each predicted mask, of length num_rois.
Returns:
scores (Tensor): Uncertainty scores with the most uncertain
locations having the highest uncertainty score,
shape (num_rois, 1, mask_height, mask_width)
"""
if mask_preds.shape[1] == 1:
gt_class_logits = mask_preds.clone()
else:
inds = torch.arange(mask_preds.shape[0], device=mask_preds.device)
gt_class_logits = mask_preds[inds, labels].unsqueeze(1)
return -torch.abs(gt_class_logits)
def get_uncertain_point_coords_with_randomness(
mask_preds: Tensor, labels: Tensor, num_points: int,
oversample_ratio: float, importance_sample_ratio: float) -> Tensor:
"""Get ``num_points`` most uncertain points with random points during
train.
Sample points in [0, 1] x [0, 1] coordinate space based on their
uncertainty. The uncertainties are calculated for each point using
'get_uncertainty()' function that takes point's logit prediction as
input.
Args:
mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
mask_height, mask_width) for class-specific or class-agnostic
prediction.
labels (Tensor): The ground truth class for each instance.
num_points (int): The number of points to sample.
oversample_ratio (float): Oversampling parameter.
importance_sample_ratio (float): Ratio of points that are sampled
via importance sampling.
Returns:
point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
that contains the coordinates of the sampled points.
"""
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = mask_preds.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(
batch_size, num_sampled, 2, device=mask_preds.device)
point_logits = point_sample(mask_preds, point_coords)
# It is crucial to calculate uncertainty based on the sampled
# prediction value for the points. Calculating uncertainties of the
# coarse predictions first and sampling them for points leads to
# incorrect results. To illustrate this: assume uncertainty func(
# logits)=-abs(logits), a sampled point between two coarse
# predictions with -1 and 1 logits has 0 logits, and therefore 0
# uncertainty value. However, if we calculate uncertainties for the
# coarse predictions first, both will have -1 uncertainty,
# and sampled point will get -1 uncertainty.
point_uncertainties = get_uncertainty(point_logits, labels)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(
batch_size, dtype=torch.long, device=mask_preds.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
batch_size, num_uncertain_points, 2)
if num_random_points > 0:
rand_roi_coords = torch.rand(
batch_size, num_random_points, 2, device=mask_preds.device)
point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
return point_coords
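A minimal sketch of how this sampling utility pairs with `point_sample` in the mask loss above; the shapes and hyper-parameters are arbitrary examples, and `labels=None` mirrors the class-agnostic call in `loss_by_feat`:

```python
import torch
from mmcv.ops import point_sample

from mmseg.models.utils import get_uncertain_point_coords_with_randomness

# Hypothetical shapes: 5 matched masks, 1 (class-agnostic) channel, 64x64 logits.
mask_preds = torch.randn(5, 1, 64, 64)
mask_targets = torch.randint(0, 2, (5, 1, 64, 64)).float()

with torch.no_grad():
    coords = get_uncertain_point_coords_with_randomness(
        mask_preds, None, num_points=112,
        oversample_ratio=3.0, importance_sample_ratio=0.75)

point_preds = point_sample(mask_preds, coords).squeeze(1)      # (5, 112)
point_targets = point_sample(mask_targets, coords).squeeze(1)  # (5, 112)
```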

View File

@ -0,0 +1,418 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501
# Copyright (c) 2023 MendelXu.
# Licensed under the MIT License
import warnings
from typing import Optional
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from torch import Tensor, nn
from torch.nn import functional as F
def cross_attn_with_self_bias(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
):
"""Forward function of multi-head attention. Modified from
multi_head_attention_forward in
https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py.
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
Default: `True`
Note: `need_weights` defaults to `True`, but should be set to `False`
for best performance when attention weights are not needed.
*Setting `need_weights` to `True`
leads to a significant performance degradation.*
attn_mask: 2D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
use_separate_proj_weight: the function accepts the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
""" # noqa: E501
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, \
'embed_dim must be divisible by num_heads'
scaling = float(head_dim)**-0.5
if not use_separate_proj_weight:
if (query is key or torch.equal(
query, key)) and (key is value or torch.equal(key, value)):
# self-attention
raise NotImplementedError('self-attention is not implemented')
elif key is value or torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function
# with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
if key is None:
assert value is None
k = None
v = None
q_k = None
q_v = None
else:
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = F.linear(key, _w, _b)
q_k = F.linear(query, _w, _b)
# This is inline in_proj function with
# in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = F.linear(value, _w, _b)
q_v = F.linear(query, _w, _b)
else:
q_proj_weight_non_opt = \
torch.jit._unwrap_optional(q_proj_weight)
len1, len2 = q_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == query.size(-1)
k_proj_weight_non_opt = \
torch.jit._unwrap_optional(k_proj_weight)
len1, len2 = k_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == key.size(-1)
v_proj_weight_non_opt = \
torch.jit._unwrap_optional(v_proj_weight)
len1, len2 = v_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == value.size(-1)
if in_proj_bias is not None:
q = F.linear(query, q_proj_weight_non_opt,
in_proj_bias[0:embed_dim])
k = F.linear(key, k_proj_weight_non_opt,
in_proj_bias[embed_dim:(embed_dim * 2)])
v = F.linear(value, v_proj_weight_non_opt,
in_proj_bias[(embed_dim * 2):])
else:
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool
), 'Only float, byte, and bool types are supported for ' \
'attn_mask, not {}'.format(attn_mask.dtype)
if attn_mask.dtype == torch.uint8:
warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention '
'is deprecated. Use bool tensor instead.')
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
'The size of the 2D attn_mask is not correct.')
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0), key.size(0)
]:
raise RuntimeError(
'The size of the 3D attn_mask is not correct.')
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()))
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
'Byte tensor for key_padding_mask in nn.MultiheadAttention '
'is deprecated. Use bool tensor instead.')
key_padding_mask = key_padding_mask.to(torch.bool)
if bias_k is not None and bias_v is not None:
if static_k is None and static_v is None:
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
else:
assert static_k is None, 'bias cannot be added to static key.'
assert static_v is None, 'bias cannot be added to static value.'
else:
assert bias_k is None
assert bias_v is None
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
q_k = q_k.contiguous().view(tgt_len, bsz * num_heads,
head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
q_v = q_v.contiguous().view(tgt_len, bsz * num_heads,
head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat(
[
k,
torch.zeros(
(k.size(0), 1) + k.size()[2:],
dtype=k.dtype,
device=k.device),
],
dim=1,
)
v = torch.cat(
[
v,
torch.zeros(
(v.size(0), 1) + v.size()[2:],
dtype=v.dtype,
device=v.device),
],
dim=1,
)
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(
attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads,
tgt_len, src_len)
# attn_out_weights: [bsz * num_heads, tgt_len, src_len]
# ->[bsz * num_heads, tgt_len, src_len+1]
self_weight = (q * q_k).sum(
dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1]
total_attn_output_weights = torch.cat([attn_output_weights, self_weight],
dim=-1)
total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1)
total_attn_output_weights = F.dropout(
total_attn_output_weights, p=dropout_p, training=training)
attn_output_weights = \
total_attn_output_weights[:, :, : -1]
# [bsz * num_heads, tgt_len, src_len]
self_weight = \
total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1]
attn_output = torch.bmm(attn_output_weights,
v) # [bsz * num_heads, tgt_len, head_dim]
attn_output = (attn_output + self_weight * q_v
) # [bsz * num_heads, tgt_len, head_dim]
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(
tgt_len, bsz, embed_dim)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
src_len)
return attn_output, attn_output_weights # .sum(dim=1) / num_heads
else:
return attn_output, None
def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias):
"""Implementation of transformer layer with cross attention. The cross
attention shares the embedding weights with self-attention of tf_layer.
Args:
tf_layer: (TransformerEncoderLayer): The Module of transformer layer.
x (Tensor): query [K,N,C]
mem (Tensor): key and value [L,N,C]
attn_bias (Tensor): attention bias [N*num_head,K,L]
Return:
x (Tensor): cross attention output [K,N,C]
"""
self_attn_layer = tf_layer.attentions[0].attn
attn_layer_paras = {
'embed_dim_to_check': self_attn_layer.embed_dim,
'num_heads': self_attn_layer.num_heads,
'in_proj_weight': self_attn_layer.in_proj_weight,
'in_proj_bias': self_attn_layer.in_proj_bias,
'bias_k': self_attn_layer.bias_k,
'bias_v': self_attn_layer.bias_v,
'add_zero_attn': self_attn_layer.add_zero_attn,
'dropout_p': self_attn_layer.dropout,
'out_proj_weight': self_attn_layer.out_proj.weight,
'out_proj_bias': self_attn_layer.out_proj.bias,
'training': self_attn_layer.training
}
q_x = tf_layer.norms[0](x)
k_x = v_x = tf_layer.norms[0](mem)
x = x + cross_attn_with_self_bias(
q_x,
k_x,
v_x,
attn_mask=attn_bias,
need_weights=False,
**attn_layer_paras)[0]
x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x)
return x
class LayerNorm2d(nn.Module):
"""A LayerNorm variant, popularized by Transformers, that performs point-
wise mean and variance normalization over the channel dimension for inputs
that have shape (batch_size, channels, height, width).
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.normalized_shape = (normalized_shape, )
def forward(self, x: torch.Tensor):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self,
input_dim,
hidden_dim,
output_dim,
num_layers,
affine_func=nn.Linear):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
affine_func(n, k)
for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x: torch.Tensor):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
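Two quick usage notes on the helpers added here, with arbitrary dimensions (a sketch, not taken from the SAN head config):

```python
import torch

from mmseg.models.utils import MLP, LayerNorm2d

# MLP: three linear layers with ReLU between them, 256-d in, 512-d out.
mlp = MLP(input_dim=256, hidden_dim=256, output_dim=512, num_layers=3)
print(mlp(torch.randn(100, 256)).shape)        # torch.Size([100, 512])

# LayerNorm2d: per-location LayerNorm over the channel dim of NCHW features.
norm = LayerNorm2d(64)
print(norm(torch.randn(2, 64, 32, 32)).shape)  # torch.Size([2, 64, 32, 32])
```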

View File

@ -11,23 +11,60 @@ from .class_names import (ade_classes, ade_palette, bdd100k_classes,
vaihingen_palette, voc_classes, voc_palette)
# yapf: enable
from .collect_env import collect_env
from .get_templates import get_predefined_templates
from .io import datafrombytes
from .misc import add_prefix, stack_batch
from .set_env import register_all_modules
from .tokenizer import tokenize
from .typing_utils import (ConfigType, ForwardResults, MultiConfig,
OptConfigType, OptMultiConfig, OptSampleList,
SampleList, TensorDict, TensorList)
# isort: off
from .mask_classification import MatchMasks, seg_data_to_instance_data
__all__ = [
'collect_env',
'register_all_modules',
'stack_batch',
'add_prefix',
'ConfigType',
'OptConfigType',
'MultiConfig',
'OptMultiConfig',
'SampleList',
'OptSampleList',
'TensorDict',
'TensorList',
'ForwardResults',
'cityscapes_classes',
'ade_classes',
'voc_classes',
'cocostuff_classes',
'loveda_classes',
'potsdam_classes',
'vaihingen_classes',
'isaid_classes',
'stare_classes',
'cityscapes_palette',
'ade_palette',
'voc_palette',
'cocostuff_palette',
'loveda_palette',
'potsdam_palette',
'vaihingen_palette',
'isaid_palette',
'stare_palette',
'dataset_aliases',
'get_classes',
'get_palette',
'datafrombytes',
'synapse_palette',
'synapse_classes',
'get_predefined_templates',
'tokenize',
'seg_data_to_instance_data',
'MatchMasks',
'bdd100k_classes',
'bdd100k_palette',
]

Binary file not shown.

View File

@ -52,6 +52,21 @@ def voc_classes():
]
def pcontext_classes():
"""Pascal Context class names for external use."""
return [
'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird',
'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat',
'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain',
'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track',
'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window',
'wood'
]
def cocostuff_classes():
"""CocoStuff class names for external use."""
return [
@ -306,6 +321,25 @@ def voc_palette():
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
def pcontext_palette():
"""Pascal Context palette for external use."""
return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
[120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
[4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
[120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
[204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
[61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
[255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
[112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
[10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
[102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
[0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
[235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
[250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
[255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
[0, 235, 255], [0, 173, 255], [31, 0, 255]]
def cocostuff_palette():
"""CocoStuff palette for external use."""
return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
@ -443,6 +477,7 @@ dataset_aliases = {
'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'],
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
'pcontext': ['pcontext', 'pascal_context', 'voc2010'],
'loveda': ['loveda'],
'potsdam': ['potsdam'],
'vaihingen': ['vaihingen'],

View File

@ -0,0 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
PREDEFINED_TEMPLATES = {
'imagenet': [
'a bad photo of a {}.',
'a photo of many {}.',
'a sculpture of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a tattoo of a {}.',
'the embroidered {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a plastic {}.',
'a photo of the dirty {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a {} in a video game.',
'a photo of one {}.',
'a doodle of a {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'the origami {}.',
'the {} in a video game.',
'a sketch of a {}.',
'a doodle of the {}.',
'a origami {}.',
'a low resolution photo of a {}.',
'the toy {}.',
'a rendition of the {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a rendition of a {}.',
'a photo of a nice {}.',
'a photo of a weird {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a sketch of the {}.',
'a embroidered {}.',
'a pixelated photo of a {}.',
'itap of the {}.',
'a jpeg corrupted photo of the {}.',
'a good photo of a {}.',
'a plushie {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'the cartoon {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a black and white photo of a {}.',
'the plushie {}.',
'a dark photo of a {}.',
'itap of a {}.',
'graffiti of the {}.',
'a toy {}.',
'itap of my {}.',
'a photo of a cool {}.',
'a photo of a small {}.',
'a tattoo of the {}.',
],
'vild': [
'a photo of a {}.',
'This is a photo of a {}',
'There is a {} in the scene',
'There is the {} in the scene',
'a photo of a {} in the scene',
'a photo of a small {}.',
'a photo of a medium {}.',
'a photo of a large {}.',
'This is a photo of a small {}.',
'This is a photo of a medium {}.',
'This is a photo of a large {}.',
'There is a small {} in the scene.',
'There is a medium {} in the scene.',
'There is a large {} in the scene.',
],
}
def get_predefined_templates(template_set_name: str) -> List[str]:
if template_set_name not in PREDEFINED_TEMPLATES:
raise ValueError(f'Template set {template_set_name} not found')
return PREDEFINED_TEMPLATES[template_set_name]
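Usage is straightforward; the class name below is an arbitrary example:

```python
from mmseg.utils import get_predefined_templates

templates = get_predefined_templates('vild')      # 14 ViLD-style prompts
prompts = [t.format('car') for t in templates]
# e.g. 'a photo of a car.', 'There is a small car in the scene.', ...
```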

View File

@ -0,0 +1,205 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList
def seg_data_to_instance_data(ignore_index: int,
batch_data_samples: SampleList):
"""Convert the paradigm of ground truth from semantic segmentation to
instance segmentation.
Args:
ignore_index (int): The label index to be ignored.
batch_data_samples (List[SegDataSample]): The Data
Samples. It usually includes information such as
`gt_sem_seg`.
Returns:
List[InstanceData]: Batch of gt_instance. Each ``InstanceData`` usually
includes ``labels``, the unique ground truth label ids of an image with
shape (num_gt, ), and ``masks``, the ground truth masks of each instance
with shape (num_gt, h, w).
"""
batch_gt_instances = []
for data_sample in batch_data_samples:
gt_sem_seg = data_sample.gt_sem_seg.data
classes = torch.unique(
gt_sem_seg,
sorted=False,
return_inverse=False,
return_counts=False)
# remove ignored region
gt_labels = classes[classes != ignore_index]
masks = []
for class_id in gt_labels:
masks.append(gt_sem_seg == class_id)
if len(masks) == 0:
gt_masks = torch.zeros(
(0, gt_sem_seg.shape[-2],
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
else:
gt_masks = torch.stack(masks).squeeze(1).long()
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
batch_gt_instances.append(instance_data)
return batch_gt_instances
class MatchMasks:
"""Match the predictions to category labels.
Args:
num_points (int): the number of sampled points to compute cost.
num_queries (int): the number of prediction masks.
num_classes (int): the number of classes.
assigner (BaseAssigner): the assigner to compute matching.
"""
def __init__(self,
num_points: int,
num_queries: int,
num_classes: int,
assigner: ConfigType = None):
assert assigner is not None, "\'assigner\' in decode_head.train_cfg" \
'cannot be None'
assert num_points > 0, 'num_points should be a positive integer.'
self.num_points = num_points
self.num_queries = num_queries
self.num_classes = num_classes
self.assigner = TASK_UTILS.build(assigner)
def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor],
batch_gt_instances: List[InstanceData]) -> Tuple:
"""Compute best mask matches for all images for a decoder layer.
Args:
cls_scores (List[Tensor]): Mask score logits from a single
decoder layer for all images. Each with shape (num_queries,
cls_out_channels).
mask_preds (List[Tensor]): Mask logits from a single decoder
layer for all images. Each with shape (num_queries, h, w).
batch_gt_instances (List[InstanceData]): each contains
``labels`` and ``masks``.
Returns:
tuple: a tuple containing the following targets.
- labels (List[Tensor]): Labels of all images.\
Each with shape (num_queries, ).
- mask_targets (List[Tensor]): Mask targets of\
all images. Each with shape (num_queries, h, w).
- mask_weights (List[Tensor]): Mask weights of\
all images. Each with shape (num_queries, ).
- avg_factor (int): Average factor that is used to
average the loss. `avg_factor` is usually equal
to the number of positive priors.
"""
batch_size = cls_scores.shape[0]
results = dict({
'labels': [],
'mask_targets': [],
'mask_weights': [],
})
for i in range(batch_size):
labels, mask_targets, mask_weights\
= self._get_targets_single(cls_scores[i],
mask_preds[i],
batch_gt_instances[i])
results['labels'].append(labels)
results['mask_targets'].append(mask_targets)
results['mask_weights'].append(mask_weights)
# shape (batch_size, num_queries)
labels = torch.stack(results['labels'], dim=0)
# shape (batch_size, num_gts, h, w)
mask_targets = torch.cat(results['mask_targets'], dim=0)
# shape (batch_size, num_queries)
mask_weights = torch.stack(results['mask_weights'], dim=0)
avg_factor = sum(
[len(gt_instances.labels) for gt_instances in batch_gt_instances])
res = (labels, mask_targets, mask_weights, avg_factor)
return res
def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
gt_instances: InstanceData) \
-> Tuple[Tensor, Tensor, Tensor]:
"""Compute a set of best mask matches for one image.
Args:
cls_score (Tensor): Mask score logits from a single decoder layer
for one image. Shape (num_queries, cls_out_channels).
mask_pred (Tensor): Mask logits for a single decoder layer for one
image. Shape (num_queries, h, w).
gt_instances (:obj:`InstanceData`): It contains ``labels`` and
``masks``.
Returns:
tuple[Tensor]: A tuple containing the following for one image.
- labels (Tensor): Labels of each image. \
shape (num_queries, ).
- mask_targets (Tensor): Mask targets of each image. \
shape (num_queries, h, w).
- mask_weights (Tensor): Mask weights of each image. \
shape (num_queries, ).
"""
gt_labels = gt_instances.labels
gt_masks = gt_instances.masks
# when "gt_labels" is empty, classify all queries to background
if len(gt_labels) == 0:
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
mask_targets = gt_labels
mask_weights = gt_labels.new_zeros((self.num_queries, ))
return labels, mask_targets, mask_weights
# sample points
num_queries = cls_score.shape[0]
num_gts = gt_labels.shape[0]
point_coords = torch.rand((1, self.num_points, 2),
device=cls_score.device)
# shape (num_queries, num_points)
mask_points_pred = point_sample(
mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
1)).squeeze(1)
# shape (num_gts, num_points)
gt_points_masks = point_sample(
gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
1)).squeeze(1)
sampled_gt_instances = InstanceData(
labels=gt_labels, masks=gt_points_masks)
sampled_pred_instances = InstanceData(
scores=cls_score, masks=mask_points_pred)
# assign and sample
matched_quiery_inds, matched_label_inds = self.assigner.assign(
pred_instances=sampled_pred_instances,
gt_instances=sampled_gt_instances)
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
labels[matched_quiery_inds] = gt_labels[matched_label_inds]
mask_weights = gt_labels.new_zeros((self.num_queries, ))
mask_weights[matched_quiery_inds] = 1
mask_targets = gt_masks[matched_label_inds]
return labels, mask_targets, mask_weights
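A small sketch of `seg_data_to_instance_data` on a toy ground-truth map (the 4x4 tensor and class ids are made up), showing how a semantic map becomes per-class binary masks:

```python
import torch
from mmengine.structures import PixelData

from mmseg.structures import SegDataSample
from mmseg.utils import seg_data_to_instance_data

# Hypothetical 4x4 ground truth with classes {0, 2} and ignored pixels (255).
sem_seg = torch.tensor([[0, 0, 2, 2],
                        [0, 0, 2, 2],
                        [255, 255, 2, 2],
                        [0, 0, 0, 0]]).unsqueeze(0)
sample = SegDataSample()
sample.gt_sem_seg = PixelData(data=sem_seg)

gt_instances = seg_data_to_instance_data(255, [sample])
print(gt_instances[0].labels)       # tensor([0, 2]) (order may vary)
print(gt_instances[0].masks.shape)  # torch.Size([2, 4, 4])
```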

View File

@ -0,0 +1,240 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""CLIP tokenizer.
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright
(c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import List, Union
import ftfy
import regex as re
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
@lru_cache()
def default_bpe():
return os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'bpe_simple_vocab_16e6.txt.gz')
@lru_cache()
def bytes_to_unicode():
"""Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings. This means you need a
large # of unicode characters in your vocab if you want to avoid UNKs. When
you're at something like a 10B token dataset you end up needing around 5K
for decent coverage. This is a significant percentage of your normal, say,
32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
unicode strings. And avoids mapping to whitespace/control characters the
bpe code barfs on.
"""
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length
strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer:
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
if not special_tokens:
special_tokens = ['<start_of_text>', '<end_of_text>']
else:
special_tokens = ['<start_of_text>', '<end_of_text>'
] + special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t: t for t in special_tokens}
special = '|'.join(special_tokens)
self.pat = re.compile(
special +
r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except: # noqa: E722, E261
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text
_tokenizer = SimpleTokenizer()
def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)
def tokenize(texts: Union[str, List[str]],
context_length: int = 77) -> torch.LongTensor:
"""Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens,
shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder['<start_of_text>']
eot_token = _tokenizer.encoder['<end_of_text>']
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token
result[i, :len(tokens)] = torch.tensor(tokens)
return result
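# Usage sketch (illustrative): each text is wrapped with the start/end-of-text
# ids and zero-padded to context_length, e.g.
# ids = tokenize(['a photo of a cat.'])
# ids.shape  # -> torch.Size([1, 77])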
class HFTokenizer:
"""HuggingFace tokenizer wrapper."""
def __init__(self, tokenizer_name: str):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self,
texts: Union[str, List[str]],
context_length: int = 77) -> torch.Tensor:
# Same cleaning as the default tokenizer, except without lowercasing;
# adding lower() (for case-sensitive tokenizers) would make it more
# robust but less sensitive to nuance.
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
).input_ids
return input_ids
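# Usage sketch (illustrative; 'bert-base-uncased' is only an example of a
# HuggingFace tokenizer id, not one this PR requires):
# hf_tok = HFTokenizer('bert-base-uncased')
# ids = hf_tok(['a photo of a cat.'], context_length=77)
# ids.shape  # -> torch.Size([1, 77])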

View File

@ -0,0 +1,2 @@
ftfy
regex

View File

@ -1,6 +1,8 @@
codecov
flake8
ftfy
interrogate
pytest
regex
xdoctest>=0.10.0
yapf

View File

@ -194,6 +194,7 @@ if __name__ == '__main__':
'tests': parse_requirements('requirements/tests.txt'),
'optional': parse_requirements('requirements/optional.txt'),
'mim': parse_requirements('requirements/mminstall.txt'),
'multimodal': parse_requirements('requirements/multimodal.txt'),
},
ext_modules=[],
zip_safe=False)

View File

@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine.structures import InstanceData
from mmseg.models.assigners import HungarianAssigner
class TestHungarianAssigner(TestCase):
def test_init(self):
with self.assertRaises(AssertionError):
HungarianAssigner([])
def test_hungarian_match_assigner(self):
assigner = HungarianAssigner([
dict(type='ClassificationCost', weight=2.0),
dict(type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
])
num_classes = 3
num_masks = 10
num_points = 20
gt_instances = InstanceData()
gt_instances.labels = torch.randint(0, num_classes, (num_classes, ))
gt_instances.masks = torch.randint(0, 2, (num_classes, num_points))
pred_instances = InstanceData()
pred_instances.scores = torch.rand((num_masks, num_classes))
pred_instances.masks = torch.rand((num_masks, num_points))
matched_query_inds, matched_label_inds = \
assigner.assign(pred_instances, gt_instances)
unique_query_inds = torch.unique(matched_query_inds)
unique_label_inds = torch.unique(matched_label_inds)
self.assertTrue(len(unique_query_inds) == len(matched_query_inds))
self.assertTrue(
torch.equal(unique_label_inds, torch.arange(0, num_classes)))
def test_cls_match_cost(self):
num_classes = 3
num_masks = 10
gt_instances = InstanceData()
gt_instances.labels = torch.randint(0, num_classes, (num_classes, ))
pred_instances = InstanceData()
pred_instances.scores = torch.rand((num_masks, num_classes))
# test ClassificationCost
assigner = HungarianAssigner(dict(type='ClassificationCost'))
matched_query_inds, matched_label_inds = \
assigner.assign(pred_instances, gt_instances)
unique_query_inds = torch.unique(matched_query_inds)
unique_label_inds = torch.unique(matched_label_inds)
self.assertTrue(len(unique_query_inds) == len(matched_query_inds))
self.assertTrue(
torch.equal(unique_label_inds, torch.arange(0, num_classes)))
def test_mask_match_cost(self):
num_classes = 3
num_masks = 10
num_points = 20
gt_instances = InstanceData()
gt_instances.masks = torch.randint(0, 2, (num_classes, num_points))
pred_instances = InstanceData()
pred_instances.masks = torch.rand((num_masks, num_points))
# test DiceCost
assigner = HungarianAssigner(
dict(type='DiceCost', pred_act=True, eps=1.0))
assign_result = assigner.assign(pred_instances, gt_instances)
self.assertTrue(len(assign_result[0]) == len(assign_result[1]))
# test CrossEntropyLossCost
assigner = HungarianAssigner(
dict(type='CrossEntropyLossCost', use_sigmoid=True))
assign_result = assigner.assign(pred_instances, gt_instances)
self.assertTrue(len(assign_result[0]) == len(assign_result[1]))

View File

@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import Config
from mmengine.registry import init_default_scope
from mmseg.models.text_encoder import CLIPTextEncoder
from mmseg.utils import get_classes
def test_clip_text_encoder():
init_default_scope('mmseg')
# test vocabulary
output_dims = 8
embed_dims = 32
vocabulary = ['cat', 'dog', 'bird', 'car', 'bike']
cfg = dict(
vocabulary=vocabulary,
templates=['a photo of a {}.'],
embed_dims=embed_dims,
output_dims=output_dims)
cfg = Config(cfg)
text_encoder = CLIPTextEncoder(**cfg)
if torch.cuda.is_available():
text_encoder = text_encoder.cuda()
with torch.no_grad():
class_embeds = text_encoder()
assert class_embeds.shape == (len(vocabulary) + 1, output_dims)
# test dataset name
cfg = dict(
dataset_name='vaihingen',
templates=['a photo of a {}.'],
embed_dims=embed_dims,
output_dims=output_dims)
cfg = Config(cfg)
text_encoder = CLIPTextEncoder(**cfg)
with torch.no_grad():
class_embeds = text_encoder()
class_nums = len(get_classes('vaihingen'))
assert class_embeds.shape == (class_nums + 1, output_dims)

View File

@ -0,0 +1,126 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import Config
from mmengine.structures import PixelData
from mmseg.models.decode_heads import SideAdapterCLIPHead
from mmseg.structures import SegDataSample
from .utils import list_to_cuda
def test_san_head():
H, W = (64, 64)
clip_channels = 64
img_channels = 4
num_queries = 40
out_dims = 64
num_classes = 19
cfg = dict(
num_classes=num_classes,
deep_supervision_idxs=[4],
san_cfg=dict(
in_channels=img_channels,
embed_dims=128,
clip_channels=clip_channels,
num_queries=num_queries,
cfg_encoder=dict(num_encode_layer=4, mlp_ratio=2, num_heads=2),
cfg_decoder=dict(
num_heads=4,
num_layers=1,
embed_channels=32,
mlp_channels=32,
num_mlp=2,
rescale=True)),
maskgen_cfg=dict(
sos_token_num=num_queries,
embed_dims=clip_channels,
out_dims=out_dims,
num_heads=4,
mlp_ratio=2),
train_cfg=dict(
num_points=100,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='ClassificationCost', weight=2.0),
dict(
type='CrossEntropyLossCost',
weight=5.0,
use_sigmoid=True),
dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
])),
loss_decode=[
dict(
type='CrossEntropyLoss',
loss_name='loss_cls_ce',
loss_weight=2.0,
class_weight=[1.0] * num_classes + [0.1]),
dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_name='loss_mask_ce',
loss_weight=5.0),
dict(
type='DiceLoss',
ignore_index=None,
naive_dice=True,
eps=1,
loss_name='loss_mask_dice',
loss_weight=5.0)
])
cfg = Config(cfg)
head = SideAdapterCLIPHead(**cfg)
inputs = torch.rand((2, img_channels, H, W))
clip_feature = [[
torch.rand((2, clip_channels, H // 2, W // 2)),
torch.rand((2, clip_channels))
],
[
torch.rand((2, clip_channels, H // 2, W // 2)),
torch.rand((2, clip_channels))
],
[
torch.rand((2, clip_channels, H // 2, W // 2)),
torch.rand((2, clip_channels))
],
[
torch.rand((2, clip_channels, H // 2, W // 2)),
torch.rand((2, clip_channels))
]]
class_embed = torch.rand((num_classes + 1, out_dims))
data_samples = []
for i in range(2):
data_sample = SegDataSample()
img_meta = {}
img_meta['img_shape'] = (H, W)
img_meta['ori_shape'] = (H, W)
data_sample.gt_sem_seg = PixelData(
data=torch.randint(0, num_classes, (1, H, W)))
data_sample.set_metainfo(img_meta)
data_samples.append(data_sample)
batch_img_metas = []
for data_sample in data_samples:
batch_img_metas.append(data_sample.metainfo)
if torch.cuda.is_available():
head = head.cuda()
data = list_to_cuda([inputs, clip_feature, class_embed])
for data_sample in data_samples:
data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
else:
data = [inputs, clip_feature, class_embed]
# loss test
loss_dict = head.loss(data, data_samples, None)
assert isinstance(loss_dict, dict)
# prediction test
with torch.no_grad():
seg_logits = head.predict(data, batch_img_metas, None)
assert seg_logits.shape == torch.Size((2, num_classes, H, W))

View File

@ -20,3 +20,12 @@ def to_cuda(module, data):
for i in range(len(data)):
data[i] = data[i].cuda()
return module, data
def list_to_cuda(data):
if isinstance(data, list):
for i in range(len(data)):
data[i] = list_to_cuda(data[i])
return data
else:
return data.cuda()

View File

@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import ConfigDict
from mmseg.models import build_segmentor
from tests.test_models.test_segmentors.utils import \
_segmentor_forward_train_test
def test_multimodal_encoder_decoder():
cfg = ConfigDict(
type='MultimodalEncoderDecoder',
asymetric_input=False,
image_encoder=dict(type='ExampleBackbone', out_indices=[1, 2, 3, 4]),
text_encoder=dict(
type='ExampleTextEncoder',
vocabulary=['A', 'B', 'C'],
output_dims=3),
decode_head=dict(
type='ExampleDecodeHead', out_channels=1, num_classes=2),
train_cfg=None,
test_cfg=dict(mode='whole'))
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)

View File

@ -52,15 +52,22 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
@MODELS.register_module()
class ExampleBackbone(nn.Module):
def __init__(self, out_indices=None):
super().__init__()
self.conv = nn.Conv2d(3, 3, 3)
self.out_indices = out_indices
def init_weights(self, pretrained=None):
pass
def forward(self, x):
if self.out_indices is None:
return [self.conv(x)]
else:
outs = []
for i in self.out_indices:
outs.append(self.conv(x))
return outs
@MODELS.register_module()
@ -74,6 +81,18 @@ class ExampleDecodeHead(BaseDecodeHead):
return self.cls_seg(inputs[0])
@MODELS.register_module()
class ExampleTextEncoder(nn.Module):
def __init__(self, vocabulary=None, output_dims=None):
super().__init__()
self.vocabulary = vocabulary
self.output_dims = output_dims
def forward(self):
return torch.randn((len(self.vocabulary), self.output_dims))
@MODELS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):
@ -132,3 +151,32 @@ def _segmentor_forward_train_test(segmentor):
data_batch = dict(inputs=imgs, data_samples=data_samples)
results = segmentor.forward(imgs, data_samples, mode='tensor')
assert isinstance(results, torch.Tensor)
def _segmentor_predict(segmentor):
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
else:
num_classes = segmentor.decode_head.num_classes
# batch_size=2 for BatchNorm
mm_inputs = _demo_mm_inputs(num_classes=num_classes)
# convert to cuda Tensor if applicable
if torch.cuda.is_available():
segmentor = segmentor.cuda()
# check data preprocessor
if not hasattr(segmentor,
'data_preprocessor') or segmentor.data_preprocessor is None:
segmentor.data_preprocessor = SegDataPreProcessor()
mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
imgs = mm_inputs.pop('imgs')
data_samples = mm_inputs.pop('data_samples')
# Test predict
with torch.no_grad():
segmentor.eval()
data_batch = dict(inputs=imgs, data_samples=data_samples)
outputs = segmentor.predict(**data_batch)
assert isinstance(outputs, list)

View File

@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_vitlayer(paras):
new_para_name = ''
if paras[0] == 'ln_1':
new_para_name = '.'.join(['ln1'] + paras[1:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attn.attn'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['ln2'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
else:
new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
else:
print(f'Wrong for {paras}')
return new_para_name
def convert_translayer(paras):
new_para_name = ''
if paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
else:
print(f'Wrong for {paras}')
else:
print(f'Wrong for {paras}')
return new_para_name
def convert_key_name(ckpt, visual_split):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
key_list = k.split('.')
if key_list[0] == 'visual':
new_transform_name = 'image_encoder'
if key_list[1] == 'class_embedding':
new_name = '.'.join([new_transform_name, 'cls_token'])
elif key_list[1] == 'positional_embedding':
new_name = '.'.join([new_transform_name, 'pos_embed'])
elif key_list[1] == 'conv1':
new_name = '.'.join([
new_transform_name, 'patch_embed.projection', key_list[2]
])
elif key_list[1] == 'ln_pre':
new_name = '.'.join(
[new_transform_name, key_list[1], key_list[2]])
elif key_list[1] == 'transformer':
new_layer_name = 'layers'
layer_index = key_list[3]
paras = key_list[4:]
if int(layer_index) < visual_split:
new_para_name = convert_vitlayer(paras)
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
else:
new_para_name = convert_translayer(paras)
new_transform_name = 'decode_head.rec_with_attnbias'
new_layer_name = 'layers'
layer_index = str(int(layer_index) - visual_split)
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
elif key_list[1] == 'proj':
new_name = 'decode_head.rec_with_attnbias.proj.weight'
elif key_list[1] == 'ln_post':
new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
else:
print(f'pop parameter: {k}')
continue
else:
text_encoder_name = 'text_encoder'
if key_list[0] == 'transformer':
layer_name = 'transformer'
layer_index = key_list[2]
paras = key_list[3:]
new_para_name = convert_translayer(paras)
new_name = '.'.join([
text_encoder_name, layer_name, layer_index, new_para_name
])
elif key_list[0] in [
'positional_embedding', 'text_projection', 'bg_embed',
'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
]:
new_name = 'text_encoder.' + k
else:
print(f'pop parameter: {k}')
continue
new_ckpt[new_name] = v
return new_ckpt
def convert_tensor(ckpt):
cls_token = ckpt['image_encoder.cls_token']
new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
ckpt['image_encoder.cls_token'] = new_cls_token
pos_embed = ckpt['image_encoder.pos_embed']
new_pos_embed = pos_embed.unsqueeze(0)
ckpt['image_encoder.pos_embed'] = new_pos_embed
proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
new_proj_weight = proj_weight.transpose(1, 0)
ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
return ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in the official CLIP pretrained ViT models '
'to MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
visual_split = 9
elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
visual_split = 18
else:
print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
visual_split = -1
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
state_dict = checkpoint.state_dict()
else:
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_key_name(state_dict, visual_split)
weight = convert_tensor(weight)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_key_name(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
key_list = k.split('.')
if key_list[0] == 'clip_visual_extractor':
new_transform_name = 'image_encoder'
if key_list[1] == 'class_embedding':
new_name = '.'.join([new_transform_name, 'cls_token'])
elif key_list[1] == 'positional_embedding':
new_name = '.'.join([new_transform_name, 'pos_embed'])
elif key_list[1] == 'conv1':
new_name = '.'.join([
new_transform_name, 'patch_embed.projection', key_list[2]
])
elif key_list[1] == 'ln_pre':
new_name = '.'.join(
[new_transform_name, key_list[1], key_list[2]])
elif key_list[1] == 'resblocks':
new_layer_name = 'layers'
layer_index = key_list[2]
paras = key_list[3:]
if paras[0] == 'ln_1':
new_para_name = '.'.join(['ln1'] + key_list[4:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attn.attn'] + key_list[4:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['ln2'] + key_list[4:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffn.layers.0.0'] +
key_list[-1:])
else:
new_para_name = '.'.join(['ffn.layers.1'] +
key_list[-1:])
new_name = '.'.join([
new_transform_name, new_layer_name, layer_index,
new_para_name
])
elif key_list[0] == 'side_adapter_network':
decode_head_name = 'decode_head'
module_name = 'side_adapter_network'
if key_list[1] == 'vit_model':
if key_list[2] == 'blocks':
layer_name = 'encode_layers'
layer_index = key_list[3]
paras = key_list[4:]
if paras[0] == 'norm1':
new_para_name = '.'.join(['ln1'] + key_list[5:])
elif paras[0] == 'attn':
new_para_name = '.'.join(key_list[4:])
new_para_name = new_para_name.replace(
'attn.qkv.', 'attn.attn.in_proj_')
new_para_name = new_para_name.replace(
'attn.proj', 'attn.attn.out_proj')
elif paras[0] == 'norm2':
new_para_name = '.'.join(['ln2'] + key_list[5:])
elif paras[0] == 'mlp':
new_para_name = '.'.join(['ffn'] + key_list[5:])
new_para_name = new_para_name.replace(
'fc1', 'layers.0.0')
new_para_name = new_para_name.replace(
'fc2', 'layers.1')
else:
print(f'Wrong for {k}')
new_name = '.'.join([
decode_head_name, module_name, layer_name, layer_index,
new_para_name
])
elif key_list[2] == 'pos_embed':
new_name = '.'.join(
[decode_head_name, module_name, 'pos_embed'])
elif key_list[2] == 'patch_embed':
new_name = '.'.join([
decode_head_name, module_name, 'patch_embed',
'projection', key_list[4]
])
else:
print(f'Wrong for {k}')
elif key_list[1] == 'query_embed' or key_list[
1] == 'query_pos_embed':
new_name = '.'.join(
[decode_head_name, module_name, key_list[1]])
elif key_list[1] == 'fusion_layers':
layer_name = 'conv_clips'
layer_index = key_list[2][-1]
paras = '.'.join(key_list[3:])
new_para_name = paras.replace('input_proj.0', '0')
new_para_name = new_para_name.replace('input_proj.1', '1.conv')
new_name = '.'.join([
decode_head_name, module_name, layer_name, layer_index,
new_para_name
])
elif key_list[1] == 'mask_decoder':
new_name = 'decode_head.' + k
else:
print(f'Wrong for {k}')
elif key_list[0] == 'clip_rec_head':
# define the prefix here so this branch does not depend on the
# 'side_adapter_network' branch having been visited first
decode_head_name = 'decode_head'
module_name = 'rec_with_attnbias'
if key_list[1] == 'proj':
new_name = '.'.join(
[decode_head_name, module_name, 'proj.weight'])
elif key_list[1] == 'ln_post':
new_name = '.'.join(
[decode_head_name, module_name, 'ln_post', key_list[2]])
elif key_list[1] == 'resblocks':
new_layer_name = 'layers'
layer_index = key_list[2]
paras = key_list[3:]
if paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] +
paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] +
paras[2:])
else:
print(f'Wrong for {k}')
new_name = '.'.join([
decode_head_name, module_name, new_layer_name, layer_index,
new_para_name
])
else:
print(f'Wrong for {k}')
elif key_list[0] == 'ov_classifier':
text_encoder_name = 'text_encoder'
if key_list[1] == 'transformer':
layer_name = 'transformer'
layer_index = key_list[3]
paras = key_list[4:]
if paras[0] == 'attn':
new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
elif paras[0] == 'ln_1':
new_para_name = '.'.join(['norms.0'] + paras[1:])
elif paras[0] == 'ln_2':
new_para_name = '.'.join(['norms.1'] + paras[1:])
elif paras[0] == 'mlp':
if paras[1] == 'c_fc':
new_para_name = '.'.join(['ffns.0.layers.0.0'] +
paras[2:])
elif paras[1] == 'c_proj':
new_para_name = '.'.join(['ffns.0.layers.1'] +
paras[2:])
else:
print(f'Wrong for {k}')
else:
print(f'Wrong for {k}')
new_name = '.'.join([
text_encoder_name, layer_name, layer_index, new_para_name
])
elif key_list[1] in [
'positional_embedding', 'text_projection', 'bg_embed',
'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
]:
new_name = k.replace('ov_classifier', 'text_encoder')
else:
print(f'Wrong for {k}')
elif key_list[0] == 'criterion':
new_name = k
else:
print(f'Wrong for {k}')
new_ckpt[new_name] = v
return new_ckpt
def convert_tensor(ckpt):
cls_token = ckpt['image_encoder.cls_token']
new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
ckpt['image_encoder.cls_token'] = new_cls_token
pos_embed = ckpt['image_encoder.pos_embed']
new_pos_embed = pos_embed.unsqueeze(0)
ckpt['image_encoder.pos_embed'] = new_pos_embed
proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
new_proj_weight = proj_weight.transpose(1, 0)
ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
return ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in the official SAN pretrained models '
'to MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_key_name(state_dict)
weight = convert_tensor(weight)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()