[Project] Support CAT-Seg from CVPR2023 (#3098)

Thanks for your contribution; we appreciate it a lot. The following
instructions will help keep your pull request healthy and make it easier
to get feedback. If you do not understand some items, don't worry, just
open the pull request and ask the maintainers for help.

## Motivation

Support CAT-Seg open-vocabulary semantic segmentation (CVPR2023).

## Modification

Support CAT-Seg open-vocabulary semantic segmentation (CVPR2023).
- [x] Support CAT-Seg model training.
- [x] CLIP-model-based `backbone` (R101 & Swin-B), aggregation-layer-based
`neck`, and `decoder` head.
  - [x] Provide customized coco-stuff164k_384x384 training configs.
- [x] Language model support for `open vocabulary` (OV) tasks.
  - [x] Support CLIP-based pretrained language model (LM) inference.
  - [x] Add commonly used prompt templates.
- [x] Add README tutorials.
- [x] Add zero-shot testing scripts.

**Working on the following tasks.**
- [x] Add unit test.

## BC-breaking (Optional)

Does the modification introduce changes that break the
backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the
downstream projects should modify their code to keep compatibility with
this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases
here, and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit tests to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
Xu CAO 2023-08-09 23:57:30 +08:00 committed by GitHub
parent 1e937961b3
commit e458a467d6
25 changed files with 3305 additions and 0 deletions

View File

@@ -0,0 +1,92 @@
# CAT-Seg
> [CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2303.11797)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/KU-CVLAB/CAT-Seg">Official Repo</a>
<a href="https://github.com/SheffieldCao/mmsegmentation/blob/support-cat-seg/mmseg/models/necks/cat_aggregator.py">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
Existing works on open-vocabulary semantic segmentation have utilized large-scale vision-language models, such as CLIP, to leverage their exceptional open-vocabulary recognition capabilities. However, the problem of transferring these capabilities learned from image-level supervision to the pixel-level task of segmentation and addressing arbitrary unseen categories at inference makes this task challenging. To address these issues, we aim to attentively relate objects within an image to given categories by leveraging relational information among class categories and visual semantics through aggregation, while also adapting the CLIP representations to the pixel-level task. However, we observe that direct optimization of the CLIP embeddings can harm its open-vocabulary capabilities. In this regard, we propose an alternative approach to optimize the imagetext similarity map, i.e. the cost map, using a novel cost aggregation-based method. Our framework, namely CATSeg, achieves state-of-the-art performance across all benchmarks. We provide extensive ablation studies to validate our choices. [Project page](https://ku-cvlab.github.io/CAT-Seg).
<!-- [IMAGE] -->
<div align=center >
<img alt="CAT-Seg" src="https://github.com/open-mmlab/mmsegmentation/assets/49406546/d54674bb-52ae-4a20-a168-e25d041111e8"/>
CAT-Seg model structure
</div>
## Usage
CAT-Seg model training requires a pretrained `CLIP` model. We have implemented the `ViT-B` and `ViT-L` based `CLIP` models. To further use the `ViT-bigG` or `ViT-H` ones, you need additional dependencies: please install [open_clip](https://github.com/mlfoundations/open_clip) first. The pretrained `CLIP` state dicts are loaded from [Huggingface-OpenCLIP](https://huggingface.co/models?library=open_clip). **If you run into a `ConnectionError` when downloading the CLIP weights**, you can manually download them from the given repo and set `custom_clip_weights=/path/to/your/folder` in the `backbone` section of the config file. The related tools are listed in [requirements/optional.txt](requirements/optional.txt):
```shell
pip install ftfy==6.0.1
pip install huggingface-hub
pip install regex
```
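If you need `custom_clip_weights`, the override below is a minimal, hypothetical sketch of what the `backbone` section of a config could look like; everything except `custom_clip_weights` follows the defaults of `CLIPOVCATSeg` and the provided configs.
```python
# Hypothetical config override: point the open_clip based backbone to a
# manually downloaded checkpoint instead of fetching it from Huggingface.
model = dict(
    backbone=dict(
        # only relevant for the open_clip variants ('ViT-H' / 'ViT-G')
        clip_pretrained='ViT-H',
        # directory that contains `open_clip_pytorch_model.bin`
        custom_clip_weights='/path/to/your/folder'))
```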
In addition to the necessary [data preparation](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md), you also need the class texts for the CLIP text encoder. Please download the class text json files ([cls_texts](https://github.com/open-mmlab/mmsegmentation/files/11714914/cls_texts.zip)) first and arrange the folder as follows:
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── VOCdevkit
│ │ ├── VOC2012
│ │ ├── VOC2010
│ │ ├── VOCaug
│ ├── ade
│ ├── coco_stuff164k
│ ├── coco.json
│ ├── pc59.json
│ ├── pc459.json
│ ├── ade150.json
│ ├── ade847.json
│ ├── voc20b.json
│ ├── voc20.json
```
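As a quick check of the downloaded class-text files: judging from `CLIPOVCATSeg`, each json file is simply a list of class-name strings (an entry may contain comma-separated synonyms, which the text encoder ensembles), and the configs pass these files to the backbone via `train_class_json` and `test_class_json`. A rough sketch:
```python
import json
# Inspect one of the downloaded class-text files. Each file is expected to
# be a json list of class names, possibly with comma-separated synonyms.
with open('data/coco.json') as f:
    class_texts = json.load(f)
print(len(class_texts), class_texts[:3])
```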
```shell
# setup PYTHONPATH
export PYTHONPATH=`pwd`:$PYTHONPATH
# run evaluation
mim test mmsegmentation ${CONFIG} --checkpoint ${CHECKPOINT} --launcher pytorch --gpus=8
```
## Results and models
### ADE20K-150-ZeroShot
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ------- | ------------- | --------- | ------- | -------: | -------------- | ------- | ---- | ------------: | ------------------------------------------------------------------------------------------: | --------------------------------------------------------------------------------------------------------------------------------------------- |
| CAT-Seg | R-101 & ViT-B | 384x384 | 80000 | - | - | RTX3090 | 27.2 | - | [config](./configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384-54194d72.pth) |
Note:
- All experiments of CAT-Seg are run on 4 RTX3090 GPUs, except the last one with the pretrained ViT-bigG CLIP model (the GPU memory is insufficient; you may need an A100).
- Due to the feature size bottleneck of the CLIP image encoder, inference and testing can only be done in `slide` mode, and the inference time is longer because the test size is much bigger than the training size of `(384, 384)` (see the `test_cfg` sketch after these notes).
- The ResNet backbones utilized in CAT-Seg models are standard `ResNet` rather than `ResNetV1c`.
- The zero-shot segmentation results on PASCAL VOC and ADE20K are from the original paper. Our results are coming soon. We appreciate your contribution!
- In addition to the zero-shot segmentation results, we also provide the evaluation results on the `val2017` set of **COCO-stuff164k** (the training dataset of CAT-Seg) for reference. The testing was done **without TTA**.
- The number after the dataset name is the number of categories used for segmentation evaluation (except for the training data **COCO-stuff 164k**). **PASCAL VOC-20b** defines the "background" as the classes present in **PASCAL-Context-59** but not in **PASCAL VOC-20**.
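Since testing has to run in `slide` mode, the sliding-window settings live in `test_cfg`. The snippet below is only an illustrative sketch: the crop size follows the `(384, 384)` training size mentioned above, while the stride value is an assumption; check the provided configs for the actual numbers.
```python
# Hedged sketch of slide-mode inference settings for CAT-Seg.
# crop_size follows the (384, 384) training size; the stride below is an
# illustrative assumption rather than the value in the shipped configs.
model = dict(
    test_cfg=dict(mode='slide', crop_size=(384, 384), stride=(256, 256)))
```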
## Citation
```bibtex
@inproceedings{cho2023catseg,
title={CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation},
author={Seokju Cho and Heeseong Shin and Sunghwan Hong and Seungjun An and Seungjun Lee and Anurag Arnab and Paul Hongsuck Seo and Seungryong Kim},
booktitle={CVPR},
year={2023}
}
```

View File

@@ -0,0 +1,2 @@
from .models import * # noqa: F401,F403
from .utils import * # noqa: F401,F403

View File

@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cat_aggregator import (AggregatorLayer, CATSegAggregator,
ClassAggregateLayer, SpatialAggregateLayer)
from .cat_head import CATSegHead
from .clip_ovseg import CLIPOVCATSeg
__all__ = [
'AggregatorLayer', 'CATSegAggregator', 'ClassAggregateLayer',
'SpatialAggregateLayer', 'CATSegHead', 'CLIPOVCATSeg'
]

View File

@@ -0,0 +1,763 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.model import BaseModule
from mmengine.utils import to_2tuple
from mmseg.registry import MODELS
from ..utils import FullAttention, LinearAttention
class AGWindowMSA(BaseModule):
"""Appearance Guidance Window based multi-head self-attention (W-MSA)
module with relative position bias.
Args:
embed_dims (int): Number of input channels.
appearance_dims (int): Number of appearance guidance feature channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
appearance_dims,
num_heads,
window_size,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.appearance_dims = appearance_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qk = nn.Linear(
embed_dims + appearance_dims, embed_dims * 2, bias=qkv_bias)
self.v = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C),
C = embed_dims + appearance_dims.
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, _ = x.shape
qk = self.qk(x).reshape(B, N, 2, self.num_heads,
self.embed_dims // self.num_heads).permute(
2, 0, 3, 1,
4) # 2 B NUM_HEADS N embed_dims//NUM_HEADS
v = self.v(x[:, :, :self.embed_dims]).reshape(
B, N, self.num_heads, self.embed_dims // self.num_heads).permute(
0, 2, 1, 3) # B NUM_HEADS N embed_dims//NUM_HEADS
# make torchscript happy (cannot use tensor as tuple)
q, k = qk[0], qk[1]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
"""Double step sequence."""
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
class AGShiftWindowMSA(BaseModule):
"""Appearance Guidance Shifted Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
appearance_dims (int): Number of appearance guidance channels
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Defaults: 0.
proj_drop_rate (float, optional): Dropout ratio of output.
Defaults: 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults: dict(type='DropPath', drop_prob=0.).
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
appearance_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0,
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < self.window_size
self.w_msa = AGWindowMSA(
embed_dims=embed_dims,
appearance_dims=appearance_dims,
num_heads=num_heads,
window_size=to_2tuple(window_size),
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate,
init_cfg=None)
self.drop = build_dropout(dropout_layer)
def forward(self, query, hw_shape):
"""
Args:
query: The input query.
hw_shape: The shape of the feature height and width.
"""
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size,
self.w_msa.embed_dims)
# B H' W' self.w_msa.embed_dims
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, self.w_msa.embed_dims)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
"""
Args:
x: (B, H, W, C)
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
class AGSwinBlock(BaseModule):
"""Appearance Guidance Swin Transformer Block.
Args:
embed_dims (int): The feature dimension.
appearance_dims (int): The appearance guidance dimension.
num_heads (int): Parallel attention heads.
mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
for FFNs.
window_size (int, optional): The local window scale.
Default: 7.
shift (bool, optional): whether to shift window or not.
Default False.
qkv_bias (bool, optional): enable bias for qkv if True.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate.
Default: 0.
drop_path_rate (float, optional): Stochastic depth rate.
Default: 0.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
appearance_dims,
num_heads,
mlp_ratios=4,
window_size=7,
shift=False,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = AGShiftWindowMSA(
embed_dims=embed_dims,
appearance_dims=appearance_dims,
num_heads=num_heads,
window_size=window_size,
shift_size=window_size // 2 if shift else 0,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
init_cfg=None)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=embed_dims * mlp_ratios,
num_fcs=2,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=True,
init_cfg=None)
def forward(self, inputs, hw_shape):
"""
Args:
inputs (list[Tensor]): appearance_guidance (B, H, W, C);
x (B, L, C)
hw_shape (tuple[int]): shape of feature.
"""
x, appearance_guidance = inputs
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
identity = x
x = self.norm1(x)
# appearance guidance
x = x.view(B, H, W, C)
if appearance_guidance is not None:
x = torch.cat([x, appearance_guidance], dim=-1).flatten(1, 2)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
@MODELS.register_module()
class SpatialAggregateLayer(BaseModule):
"""Spatial aggregation layer of CAT-Seg.
Args:
embed_dims (int): The feature dimension.
appearance_dims (int): The appearance guidance dimension.
num_heads (int): Parallel attention heads.
mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
for FFNs.
window_size (int, optional): The local window scale. Default: 7.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
appearance_dims,
num_heads,
mlp_ratios,
window_size=7,
qk_scale=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.block_1 = AGSwinBlock(
embed_dims,
appearance_dims,
num_heads,
mlp_ratios,
window_size=window_size,
shift=False,
qk_scale=qk_scale)
self.block_2 = AGSwinBlock(
embed_dims,
appearance_dims,
num_heads,
mlp_ratios,
window_size=window_size,
shift=True,
qk_scale=qk_scale)
self.appearance_dims = appearance_dims
self.guidance_norm = nn.LayerNorm(
appearance_dims) if appearance_dims > 0 else None
def forward(self, x, appearance_guidance):
"""
Args:
x (torch.Tensor): B C T H W.
appearance_guidance (torch.Tensor): B C H W.
"""
B, C, T, H, W = x.shape
x = x.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1, 2) # BT, HW, C
if appearance_guidance is not None:
appearance_guidance = appearance_guidance.repeat(
T, 1, 1, 1).permute(0, 2, 3, 1) # BT, HW, C
appearance_guidance = self.guidance_norm(appearance_guidance)
else:
assert self.appearance_dims == 0
x = self.block_1((x, appearance_guidance), (H, W))
x = self.block_2((x, appearance_guidance), (H, W))
x = x.transpose(1, 2).reshape(B, T, C, -1)
x = x.transpose(1, 2).reshape(B, C, T, H, W)
return x
class AttentionLayer(nn.Module):
"""Attention layer for ClassAggregration of CAT-Seg.
Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L310 # noqa
"""
def __init__(self,
hidden_dim,
guidance_dim,
nheads=8,
attention_type='linear'):
super().__init__()
self.nheads = nheads
self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
self.v = nn.Linear(hidden_dim, hidden_dim)
if attention_type == 'linear':
self.attention = LinearAttention()
elif attention_type == 'full':
self.attention = FullAttention()
else:
raise NotImplementedError
def forward(self, x, guidance=None):
"""
Args:
x: B*H_p*W_p, T, C
guidance: B*H_p*W_p, T, C
"""
B, L, _ = x.shape
q = self.q(torch.cat([x, guidance],
dim=-1)) if guidance is not None else self.q(x)
k = self.k(torch.cat([x, guidance],
dim=-1)) if guidance is not None else self.k(x)
v = self.v(x)
q = q.reshape(B, L, self.nheads, -1)
k = k.reshape(B, L, self.nheads, -1)
v = v.reshape(B, L, self.nheads, -1)
out = self.attention(q, k, v)
out = out.reshape(B, L, -1)
return out
@MODELS.register_module()
class ClassAggregateLayer(BaseModule):
"""Class aggregation layer of CAT-Seg.
Args:
hidden_dims (int): The feature dimension.
guidance_dims (int): The appearance guidance dimension.
num_heads (int): Parallel attention heads.
attention_type (str): Type of attention layer. Default: 'linear'.
pooling_size (tuple[int] | list[int]): Pooling size.
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(
self,
hidden_dims=64,
guidance_dims=64,
num_heads=8,
attention_type='linear',
pooling_size=(4, 4),
init_cfg=None,
):
super().__init__(init_cfg=init_cfg)
self.pool = nn.AvgPool2d(pooling_size)
self.attention = AttentionLayer(
hidden_dims,
guidance_dims,
nheads=num_heads,
attention_type=attention_type)
self.MLP = FFN(
embed_dims=hidden_dims,
feedforward_channels=hidden_dims * 4,
num_fcs=2)
self.norm1 = nn.LayerNorm(hidden_dims)
self.norm2 = nn.LayerNorm(hidden_dims)
def pool_features(self, x):
"""Intermediate pooling layer for computational efficiency.
Args:
x: B, C, T, H, W
"""
B, C, T, H, W = x.shape
x = x.transpose(1, 2).reshape(-1, C, H, W)
x = self.pool(x)
*_, H_, W_ = x.shape
x = x.reshape(B, T, C, H_, W_).transpose(1, 2)
return x
def forward(self, x, guidance):
"""
Args:
x: B, C, T, H, W
guidance: B, T, C
"""
B, C, T, H, W = x.size()
x_pool = self.pool_features(x)
*_, H_pool, W_pool = x_pool.size()
x_pool = x_pool.permute(0, 3, 4, 2, 1).reshape(-1, T, C)
# B*H_p*W_p T C
if guidance is not None:
guidance = guidance.repeat(H_pool * W_pool, 1, 1)
x_pool = x_pool + self.attention(self.norm1(x_pool),
guidance) # Attention
x_pool = x_pool + self.MLP(self.norm2(x_pool)) # MLP
x_pool = x_pool.reshape(B, H_pool * W_pool, T,
C).permute(0, 2, 3, 1).reshape(
B, T, C, H_pool,
W_pool).flatten(0, 1) # BT C H_p W_p
x_pool = F.interpolate(
x_pool, size=(H, W), mode='bilinear', align_corners=True)
x_pool = x_pool.reshape(B, T, C, H, W).transpose(1, 2) # B C T H W
x = x + x_pool # Residual
return x
@MODELS.register_module()
class AggregatorLayer(BaseModule):
"""Single Aggregator Layer of CAT-Seg."""
def __init__(self,
embed_dims=64,
text_guidance_dims=512,
appearance_guidance_dims=512,
num_heads=4,
mlp_ratios=4,
window_size=7,
attention_type='linear',
pooling_size=(2, 2),
init_cfg=None) -> None:
super().__init__(init_cfg=init_cfg)
self.spatial_agg = SpatialAggregateLayer(
embed_dims,
appearance_guidance_dims,
num_heads=num_heads,
mlp_ratios=mlp_ratios,
window_size=window_size)
self.class_agg = ClassAggregateLayer(
embed_dims,
text_guidance_dims,
num_heads=num_heads,
attention_type=attention_type,
pooling_size=pooling_size)
def forward(self, x, appearance_guidance, text_guidance):
"""
Args:
x: B C T H W
"""
x = self.spatial_agg(x, appearance_guidance)
x = self.class_agg(x, text_guidance)
return x
@MODELS.register_module()
class CATSegAggregator(BaseModule):
"""CATSeg Aggregator.
This Aggregator is the mmseg implementation of
`CAT-Seg <https://arxiv.org/abs/2303.11797>`_.
Args:
text_guidance_dim (int): Text guidance dimensions. Default: 512.
text_guidance_proj_dim (int): Text guidance projection dimensions.
Default: 128.
appearance_guidance_dim (int): Appearance guidance dimensions.
Default: 512.
appearance_guidance_proj_dim (int): Appearance guidance projection
dimensions. Default: 128.
num_layers (int): Aggregator layer number. Default: 4.
num_heads (int): Attention layer head number. Default: 4.
embed_dims (int): Input feature dimensions. Default: 128.
pooling_size (tuple | list): Pooling size of the class aggregator
layer. Default: (6, 6).
mlp_ratios (int): The hidden dimension ratio w.r.t. input dimension.
Default: 4.
window_size (int): Swin block window size. Default: 12.
attention_type (str): Attention type of class aggregator layer.
Default: 'linear'.
prompt_channel (int): Prompt channels. Default: 80.
"""
def __init__(self,
text_guidance_dim=512,
text_guidance_proj_dim=128,
appearance_guidance_dim=512,
appearance_guidance_proj_dim=128,
num_layers=4,
num_heads=4,
embed_dims=128,
pooling_size=(6, 6),
mlp_ratios=4,
window_size=12,
attention_type='linear',
prompt_channel=80,
**kwargs):
super().__init__(**kwargs)
self.num_layers = num_layers
self.embed_dims = embed_dims
self.layers = nn.ModuleList([
AggregatorLayer(
embed_dims=embed_dims,
text_guidance_dims=text_guidance_proj_dim,
appearance_guidance_dims=appearance_guidance_proj_dim,
num_heads=num_heads,
mlp_ratios=mlp_ratios,
window_size=window_size,
attention_type=attention_type,
pooling_size=pooling_size) for _ in range(num_layers)
])
self.conv1 = nn.Conv2d(
prompt_channel, embed_dims, kernel_size=7, stride=1, padding=3)
self.guidance_projection = nn.Sequential(
nn.Conv2d(
appearance_guidance_dim,
appearance_guidance_proj_dim,
kernel_size=3,
stride=1,
padding=1),
nn.ReLU(),
) if appearance_guidance_dim > 0 else None
self.text_guidance_projection = nn.Sequential(
nn.Linear(text_guidance_dim, text_guidance_proj_dim),
nn.ReLU(),
) if text_guidance_dim > 0 else None
def feature_map(self, img_feats, text_feats):
"""Concatenation type cost volume.
For ablation study of cost volume type.
"""
img_feats = F.normalize(img_feats, dim=1) # B C H W
img_feats = img_feats.unsqueeze(2).repeat(1, 1, text_feats.shape[1], 1,
1)
text_feats = F.normalize(text_feats, dim=-1) # B T P C
text_feats = text_feats.mean(dim=-2)
text_feats = F.normalize(text_feats, dim=-1) # B T C
text_feats = text_feats.unsqueeze(-1).unsqueeze(-1).repeat(
1, 1, 1, img_feats.shape[-2], img_feats.shape[-1]).transpose(1, 2)
return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W
def correlation(self, img_feats, text_feats):
"""Correlation of image features and text features."""
img_feats = F.normalize(img_feats, dim=1) # B C H W
text_feats = F.normalize(text_feats, dim=-1) # B T P C
corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
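# corr: (B, P, T, H, W) cost volume between every pixel and each of the
# P prompt-ensembled text embeddings of the T classes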
return corr
def corr_embed(self, x):
"""Correlation embeddings encoding."""
B = x.shape[0]
corr_embed = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
corr_embed = self.conv1(corr_embed)
corr_embed = corr_embed.reshape(B, -1, self.embed_dims, x.shape[-2],
x.shape[-1]).transpose(1, 2)
return corr_embed
def forward(self, inputs):
"""
Args:
inputs (dict): including the following keys,
'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
`self.feature_extractor`.
'clip_text_feat': the text feature extracted by clip text
encoder.
'clip_text_feat_test': the text feature extracted by clip text
encoder for testing.
'clip_img_feat': the image feature extracted by the clip image
encoder.
"""
img_feats = inputs['clip_img_feat']
B = img_feats.size(0)
appearance_guidance = inputs[
'appearance_feat'][::-1] # order (out_indices) 2, 1, 0
text_feats = inputs['clip_text_feat'] if self.training else inputs[
'clip_text_feat_test']
text_feats = text_feats.repeat(B, 1, 1, 1)
corr = self.correlation(img_feats, text_feats)
# corr = self.feature_map(img_feats, text_feats)
corr_embed = self.corr_embed(corr)
projected_guidance, projected_text_guidance = None, None
if self.guidance_projection is not None:
projected_guidance = self.guidance_projection(
appearance_guidance[0])
if self.text_guidance_projection is not None:
text_feats = text_feats.mean(dim=-2)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
projected_text_guidance = self.text_guidance_projection(text_feats)
for layer in self.layers:
corr_embed = layer(corr_embed, projected_guidance,
projected_text_guidance)
return dict(
corr_embed=corr_embed, appearance_feats=appearance_guidance[1:])

View File

@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
class UpBlock(nn.Module):
"""Upsample Block with two consecutive convolution layers."""
def __init__(self, in_channels, out_channels, guidance_channels):
super().__init__()
self.up = nn.ConvTranspose2d(
in_channels,
in_channels - guidance_channels,
kernel_size=2,
stride=2)
self.conv1 = ConvModule(
in_channels,
out_channels,
3,
padding=1,
bias=False,
norm_cfg=dict(type='GN', num_groups=out_channels // 16))
self.conv2 = ConvModule(
out_channels,
out_channels,
3,
padding=1,
bias=False,
norm_cfg=dict(type='GN', num_groups=out_channels // 16))
def forward(self, x, guidance=None):
"""Forward function with visual guidance."""
x = self.up(x)
if guidance is not None:
T = x.size(0) // guidance.size(0)
# guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
guidance = guidance.repeat(T, 1, 1, 1)
x = torch.cat([x, guidance], dim=1)
x = self.conv1(x)
return self.conv2(x)
@MODELS.register_module()
class CATSegHead(BaseDecodeHead):
"""CATSeg Head.
This segmentation head is the mmseg implementation of
`CAT-Seg <https://arxiv.org/abs/2303.11797>`_.
Args:
embed_dims (int): The number of input dimensions.
decoder_dims (list): The number of decoder dimensions.
decoder_guidance_dims (list): The number of appearance guidance
dimensions.
decoder_guidance_proj_dims (list): The number of projected appearance
guidance dimensions.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self,
embed_dims=128,
decoder_dims=(64, 32),
decoder_guidance_dims=(256, 128),
decoder_guidance_proj_dims=(32, 16),
**kwargs):
super().__init__(**kwargs)
self.decoder_guidance_projection = nn.ModuleList([
nn.Sequential(
nn.Conv2d(
dec_dims,
dec_dims_proj,
kernel_size=3,
stride=1,
padding=1),
nn.ReLU(),
) for dec_dims, dec_dims_proj in zip(decoder_guidance_dims,
decoder_guidance_proj_dims)
]) if decoder_guidance_dims[0] > 0 else None
self.decoder1 = UpBlock(embed_dims, decoder_dims[0],
decoder_guidance_proj_dims[0])
self.decoder2 = UpBlock(decoder_dims[0], decoder_dims[1],
decoder_guidance_proj_dims[1])
self.conv_seg = nn.Conv2d(
decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
def forward(self, inputs):
"""Forward function.
Args:
inputs (dict): Input features including the following features,
corr_embed: aggregated correlation embeddings.
appearance_feats: decoder appearance feature guidance.
"""
# decoder guidance projection
# default to no guidance when the projection is disabled
projected_decoder_guidance = [None, None]
if self.decoder_guidance_projection is not None:
projected_decoder_guidance = [
proj(g) for proj, g in zip(self.decoder_guidance_projection,
inputs['appearance_feats'])
]
# decoder layers
B = inputs['corr_embed'].size(0)
corr_embed = inputs['corr_embed'].transpose(1, 2).flatten(0, 1)
corr_embed = self.decoder1(corr_embed, projected_decoder_guidance[0])
corr_embed = self.decoder2(corr_embed, projected_decoder_guidance[1])
output = self.cls_seg(corr_embed)
# rearrange the output to (B, T, H, W)
H_ori, W_ori = output.shape[-2:]
output = output.reshape(B, -1, H_ori, W_ori)
return output

View File

@@ -0,0 +1,293 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
from typing import List
import torch
import torch.nn.functional as F
from huggingface_hub.utils._errors import LocalEntryNotFoundError
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from mmseg.utils import ConfigType
from ..utils import clip_wrapper
from ..utils.clip_templates import (IMAGENET_TEMPLATES,
IMAGENET_TEMPLATES_SELECT)
@MODELS.register_module()
class CLIPOVCATSeg(BaseModule):
"""CLIP based Open Vocabulary CAT-Seg model backbone.
This backbone is the modified implementation of `CAT-Seg Backbone
<https://arxiv.org/abs/2303.11797>`_. It combines the CLIP model and
another feature extractor, a.k.a. the appearance guidance extractor
in the original `CAT-Seg`.
Args:
feature_extractor (ConfigType): Appearance guidance extractor
config dict.
train_class_json (str): The training class json file.
test_class_json (str): The path to test class json file.
clip_pretrained (str): The pre-trained clip type.
clip_finetune (str): The finetuning settings of clip model.
custom_clip_weights (str): The customized clip weights directory. When
encountering huggingface model download errors, you can manually
download the pretrained weights.
backbone_multiplier (float): The learning rate multiplier.
Default: 0.01.
prompt_depth (int): The prompt depth. Default: 0.
prompt_length (int): The prompt length. Default: 0.
prompt_ensemble_type (str): The prompt ensemble type.
Default: "imagenet".
pixel_mean (List[float]): The pixel mean for feature extractor.
pixel_std (List[float]): The pixel std for feature extractor.
clip_pixel_mean (List[float]): The pixel mean for clip model.
clip_pixel_std (List[float]): The pixel std for clip model.
clip_img_feat_size (List[int]): Clip image embedding size from
image encoder.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(
self,
feature_extractor: ConfigType,
train_class_json: str,
test_class_json: str,
clip_pretrained: str,
clip_finetune: str,
custom_clip_weights: str = None,
backbone_multiplier=0.01,
prompt_depth: int = 0,
prompt_length: int = 0,
prompt_ensemble_type: str = 'imagenet',
pixel_mean: List[float] = [123.675, 116.280, 103.530],
pixel_std: List[float] = [58.395, 57.120, 57.375],
clip_pixel_mean: List[float] = [
122.7709383, 116.7460125, 104.09373615
],
clip_pixel_std: List[float] = [68.5005327, 66.6321579, 70.3231630],
clip_img_feat_size: List[int] = [24, 24],
init_cfg=None):
super().__init__(init_cfg=init_cfg)
# normalization parameters
self.register_buffer('pixel_mean',
torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
self.register_buffer('pixel_std',
torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
self.register_buffer('clip_pixel_mean',
torch.Tensor(clip_pixel_mean).view(1, -1, 1, 1),
False)
self.register_buffer('clip_pixel_std',
torch.Tensor(clip_pixel_std).view(1, -1, 1, 1),
False)
self.clip_resolution = (
384, 384) if clip_pretrained == 'ViT-B/16' else (336, 336)
# modified clip image encoder with fixed size dense output
self.clip_img_feat_size = clip_img_feat_size
# prepare clip templates
self.prompt_ensemble_type = prompt_ensemble_type
if self.prompt_ensemble_type == 'imagenet_select':
prompt_templates = IMAGENET_TEMPLATES_SELECT
elif self.prompt_ensemble_type == 'imagenet':
prompt_templates = IMAGENET_TEMPLATES
elif self.prompt_ensemble_type == 'single':
prompt_templates = [
'A photo of a {} in the scene',
]
else:
raise NotImplementedError
self.prompt_templates = prompt_templates
# build the feature extractor
self.feature_extractor = MODELS.build(feature_extractor)
# build CLIP model
with open(train_class_json) as f_in:
self.class_texts = json.load(f_in)
with open(test_class_json) as f_in:
self.test_class_texts = json.load(f_in)
assert self.class_texts is not None
if self.test_class_texts is None:
self.test_class_texts = self.class_texts
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.tokenizer = None
if clip_pretrained == 'ViT-G' or clip_pretrained == 'ViT-H':
# for OpenCLIP models
import open_clip
name, pretrain = (
'ViT-H-14',
'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else (
'ViT-bigG-14', 'laion2b_s39b_b160k')
try:
open_clip_model = open_clip.create_model_and_transforms(
name,
pretrained=pretrain,
device=device,
force_image_size=336,
)
clip_model, _, clip_preprocess = open_clip_model
except (ConnectionError, LocalEntryNotFoundError) as e:
print(f'Encountered {e} when loading weights from huggingface!')
print(
f'Will load {pretrain} weights from {custom_clip_weights}.'
)
assert custom_clip_weights is not None, 'Please specify custom weights directory.' # noqa
assert os.path.exists(
os.path.join(custom_clip_weights,
'open_clip_pytorch_model.bin')
), 'Please provide a valid directory for manually downloaded model.' # noqa
open_clip_model = open_clip.create_model_and_transforms(
name,
pretrained=None,
device='cpu',
force_image_size=336,
)
clip_model, _, clip_preprocess = open_clip_model
open_clip.load_checkpoint(
clip_model,
os.path.expanduser(
os.path.join(custom_clip_weights,
'open_clip_pytorch_model.bin')))
clip_model.to(torch.device(device))
self.tokenizer = open_clip.get_tokenizer(name)
else:
# for OpenAI models
clip_model, clip_preprocess = clip_wrapper.load(
clip_pretrained,
device=device,
jit=False,
prompt_depth=prompt_depth,
prompt_length=prompt_length)
# pre-encode classes text prompts
text_features = self.class_embeddings(self.class_texts,
prompt_templates, clip_model,
device).permute(1, 0, 2).float()
text_features_test = self.class_embeddings(self.test_class_texts,
prompt_templates,
clip_model,
device).permute(1, 0,
2).float()
self.register_buffer('text_features', text_features, False)
self.register_buffer('text_features_test', text_features_test, False)
# prepare CLIP model finetune
self.clip_finetune = clip_finetune
self.clip_model = clip_model.float()
self.clip_preprocess = clip_preprocess
for name, params in self.clip_model.named_parameters():
if 'visual' in name:
if clip_finetune == 'prompt':
params.requires_grad = True if 'prompt' in name else False
elif clip_finetune == 'attention':
if 'attn' in name or 'position' in name:
params.requires_grad = True
else:
params.requires_grad = False
elif clip_finetune == 'full':
params.requires_grad = True
else:
params.requires_grad = False
else:
params.requires_grad = False
finetune_backbone = backbone_multiplier > 0.
for name, params in self.feature_extractor.named_parameters():
if 'norm0' in name:
params.requires_grad = False
else:
params.requires_grad = finetune_backbone
@torch.no_grad()
def class_embeddings(self,
classnames,
templates,
clip_model,
device='cpu'):
"""Convert class names to text embeddings by clip model.
Args:
classnames (list): loaded from json file.
templates (list | tuple): the text templates.
clip_model (nn.Module): prepared clip model.
device (str | torch.device): loading device of text
encoder results.
"""
zeroshot_weights = []
for classname in classnames:
if ', ' in classname:
classname_splits = classname.split(', ')
texts = []
for template in templates:
for cls_split in classname_splits:
texts.append(template.format(cls_split))
else:
texts = [template.format(classname)
for template in templates] # format with class
if self.tokenizer is not None:
texts = self.tokenizer(texts).to(device)
else:
texts = clip_wrapper.tokenize(texts).to(device)
class_embeddings = clip_model.encode_text(texts)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
if len(templates) != class_embeddings.shape[0]:
class_embeddings = class_embeddings.reshape(
len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
return zeroshot_weights
def custom_normalize(self, inputs):
"""Input normalization for clip model and feature extractor
respectively.
Args:
inputs: batched input images.
"""
# clip images
batched_clip = (inputs - self.clip_pixel_mean) / self.clip_pixel_std
batched_clip = F.interpolate(
batched_clip,
size=self.clip_resolution,
mode='bilinear',
align_corners=False)
# feature extractor images
batched = (inputs - self.pixel_mean) / self.pixel_std
return batched, batched_clip
def forward(self, inputs):
"""
Args:
inputs: minibatch image. (B, 3, H, W)
Returns:
outputs (dict):
'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
`self.feature_extractor`.
'clip_text_feat': the text feature extracted by clip text encoder.
'clip_text_feat_test': the text feature extracted by clip text
encoder for testing.
'clip_img_feat': the image feature extracted by the clip image encoder.
"""
inputs, clip_inputs = self.custom_normalize(inputs)
outputs = dict()
# extract appearance guidance feature
outputs['appearance_feat'] = self.feature_extractor(inputs)
# extract clip features
outputs['clip_text_feat'] = self.text_features
outputs['clip_text_feat_test'] = self.text_features_test
clip_features = self.clip_model.encode_image(
clip_inputs, dense=True) # B, 577(24x24+1), C
B = clip_features.size(0)
outputs['clip_img_feat'] = clip_features[:, 1:, :].permute(
0, 2, 1).reshape(B, -1, *self.clip_img_feat_size)
return outputs

View File

@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
from .self_attention_block import FullAttention, LinearAttention
__all__ = [
'FullAttention', 'LinearAttention', 'IMAGENET_TEMPLATES',
'IMAGENET_TEMPLATES_SELECT', 'IMAGENET_TEMPLATES_SELECT_CLIP',
'ViLD_templates'
]

View File

@@ -0,0 +1,651 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
"""Custom implementation of Bottleneck in ResNet."""
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1.
# an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool,
# and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): the input feature.
"""
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
"""Attention Pool2d."""
def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
"""
Args:
x (torch.Tensor): the input feature.
"""
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)
return x.squeeze(0)
class ModifiedResNet(nn.Module):
"""A ResNet class that is similar to torchvision's but contains the
following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average
pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is
prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(
width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
# this is a *mutable* variable used during construction
self._inplanes = width
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
"""Build resnet layers."""
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
"""
Args:
x (torch.Tensor): the input mini-batch images.
"""
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): the input feature.
"""
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
"""Wrapper of GELU activation layer."""
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): the input feature.
"""
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
"""Attention block with residual connection."""
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.mask_pre_mlp = True
def attention(self, x: torch.Tensor):
"""Calculate mask multi-head-attention."""
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): the input feature.
"""
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
def forward_dense(self, x: torch.Tensor):
"""Reinplementation of forward function for dense prediction of image
encoder in CLIP model.
Args:
x (torch.Tensor): the input feature.
"""
y = self.ln_1(x)
y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
L, N, D = y.shape # L N 3D
y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0,
3).reshape(3 * N, L, D // 3)
y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
q, k, v = y.tensor_split(3, dim=0)
v = v.transpose(1, 0) + x # L N D
v = v + self.mlp(self.ln_2(v))
return v
class Transformer(nn.Module):
"""General Transformer Architecture for both image and text encoder."""
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None,
prompt_length=0,
prompt_depth=0):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])
self.prompt_length = prompt_length
self.prompt_depth = prompt_depth
self.prompt_tokens = nn.Parameter(
torch.zeros(prompt_depth, prompt_length,
width)) if prompt_length > 0 else None
if self.prompt_tokens is not None:
nn.init.xavier_uniform_(self.prompt_tokens)
def forward(self, x: torch.Tensor, dense=False):
"""
Args:
x (torch.Tensor): input features.
dense (bool): whether use reimplemented dense forward
function in the last layer.
"""
for i, resblock in enumerate(self.resblocks):
if self.prompt_length > 0 and i < self.prompt_depth:
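# insert (i == 0) or replace (i > 0) the learnable prompt tokens
# right after the cls token before running this block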
length = self.prompt_length + 1 if i > 0 else 1
x = torch.cat((x[0:1, :, :], self.prompt_tokens[i].repeat(
x.shape[1], 1, 1).permute(1, 0, 2), x[length:, :, :]))
if i == self.layers - 1 and dense:
x = resblock.forward_dense(x)
x = torch.cat((x[0:1, :, :], x[self.prompt_length + 1::, :]),
dim=0)
else:
x = resblock(x)
return x
class VisualTransformer(nn.Module):
"""Visual encoder for CLIP model."""
def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int, prompt_depth: int,
prompt_length: int):
super().__init__()
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(
width,
layers,
heads,
prompt_depth=prompt_depth,
prompt_length=prompt_length)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.patch_size = patch_size
self.input_resolution = input_resolution
def forward(self, x: torch.Tensor, dense=False):
"""
Args:
x (torch.Tensor): input features.
dense (bool): whether use reimplemented dense forward
function in the last layer.
"""
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([
self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
],
dim=1) # shape = [*, grid ** 2 + 1, width]
if dense and (x.shape[1] != self.positional_embedding.shape[0]):
x = x + self.resized_pos_embed(self.input_resolution,
x.shape[1]).to(x.dtype)
else:
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, dense)
x = x.permute(1, 0, 2) # LND -> NLD
if dense:
x = self.ln_post(x[:, :, :])
else:
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
def resized_pos_embed(self, in_res, tgt_res, mode='bicubic'):
"""Resize the position embedding."""
# assert L == (input_resolution // self.patch_size) ** 2 + 1
L, D = self.positional_embedding.shape
in_side = in_res // self.patch_size
# tgt_side = tgt_res // self.patch_size
tgt_side = int((tgt_res - 1)**0.5)
cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
pos_embed = self.positional_embedding[1:].reshape(
1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
resized_pos_embed = F.interpolate(
pos_embed,
size=(tgt_side, tgt_side),
mode=mode,
align_corners=False,
) # 1 D S S -> 1 D S' S'
resized_pos_embed = resized_pos_embed.squeeze(0).reshape(
D, -1).T # L'-1 D
return torch.cat((cls_pos, resized_pos_embed), dim=0)
class CLIP(nn.Module):
"""Custom implementation of CLIP model.
Refer to: https://github.com/openai/CLIP
"""
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
# prompt
prompt_depth: int = 0,
prompt_length: int = 0,
):
super().__init__()
self.context_length = context_length
self.image_resolution = image_resolution
if isinstance(vision_layers, (tuple, list)):
assert prompt_length == 0 and prompt_depth == 0
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
prompt_depth=prompt_depth,
prompt_length=prompt_length,
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]))
def build_attention_mask(self):
"""Create causal attention mask."""
# lazily create causal attention mask, with full attention between
# the vision tokens. pytorch uses additive attention mask; fill with
# -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
"""Return the dtype of the model."""
return self.visual.conv1.weight.dtype
def encode_image(self, image, masks=None, pool_mask=None, dense=False):
"""Image encoding."""
if pool_mask is not None:
return self.visual(
image.type(self.dtype), mask=pool_mask, dense=dense)
if masks is None:
return self.visual(image.type(self.dtype), dense=dense)
else:
return self.visual(image.type(self.dtype), masks.type(self.dtype))
def encode_text(self, text):
"""Texts encoding."""
x = self.token_embedding(text).type(
self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number
# in each sequence)
x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
"""
Args:
image (torch.Tensor): input images.
text (torch.Tensor): input text.
"""
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
# image_features shape: [1, 1024]
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16."""
def _convert_weights_to_fp16(layer):
if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
layer.weight.data = layer.weight.data.half()
if layer.bias is not None:
layer.bias.data = layer.bias.data.half()
if isinstance(layer, nn.MultiheadAttention):
for attr in [
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
'in_proj_bias', 'bias_k', 'bias_v'
]:
tensor = getattr(layer, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ['text_projection', 'proj']:
if hasattr(layer, name):
attr = getattr(layer, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict, prompt_depth=0, prompt_length=0):
"""Build a CLIP model from given pretrained weights."""
vit = 'visual.proj' in state_dict
if vit:
vision_width = state_dict['visual.conv1.weight'].shape[0]
vision_layers = len([
k for k in state_dict.keys()
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
grid_size = round(
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [
len({
k.split('.')[2]
for k in state_dict if k.startswith(f'visual.layer{b}')
}) for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
output_width = round(
(state_dict['visual.attnpool.positional_embedding'].shape[0] -
1)**0.5)
vision_patch_size = None
assert output_width**2 + 1 == state_dict[
'visual.attnpool.positional_embedding'].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict['text_projection'].shape[1]
context_length = state_dict['positional_embedding'].shape[0]
vocab_size = state_dict['token_embedding.weight'].shape[0]
transformer_width = state_dict['ln_final.weight'].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len({
k.split('.')[2]
for k in state_dict if k.startswith('transformer.resblocks')
})
model = CLIP(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
prompt_depth=prompt_depth,
prompt_length=prompt_length,
)
for key in ['input_resolution', 'context_length', 'vocab_size']:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict, strict=False)
return model.eval()
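
Below is a minimal usage sketch (not part of the PR) showing how `build_model` can wrap an OpenAI CLIP checkpoint. The checkpoint filename is a placeholder, and the `model.float()` call is only there so the fp16-converted weights run on CPU.

```python
import torch

# `build_model` is the function defined above; 'ViT-B-16.pt' is a placeholder
# path to a downloaded OpenAI CLIP jit archive.
jit_model = torch.jit.load('ViT-B-16.pt', map_location='cpu').eval()
model = build_model(jit_model.state_dict(), prompt_depth=0, prompt_length=0)
model.float()  # convert_weights() casts to fp16; keep fp32 for CPU use

# hyper-parameters inferred from the state dict
print(model.image_resolution, model.context_length, model.vocab_size)

# random token ids only for a shape check; real ids come from tokenize()
tokens = torch.randint(0, model.vocab_size, (2, model.context_length))
with torch.no_grad():
    text_features = model.encode_text(tokens)
print(text_features.shape)  # (2, embed_dim)
```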

View File

@ -0,0 +1,204 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Source: https://github.com/openai/CLIP.
IMAGENET_TEMPLATES = [
'a bad photo of a {}.',
'a photo of many {}.',
'a sculpture of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a tattoo of a {}.',
'the embroidered {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a plastic {}.',
'a photo of the dirty {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a {} in a video game.',
'a photo of one {}.',
'a doodle of a {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'the origami {}.',
'the {} in a video game.',
'a sketch of a {}.',
'a doodle of the {}.',
'a origami {}.',
'a low resolution photo of a {}.',
'the toy {}.',
'a rendition of the {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a rendition of a {}.',
'a photo of a nice {}.',
'a photo of a weird {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a sketch of the {}.',
'a embroidered {}.',
'a pixelated photo of a {}.',
'itap of the {}.',
'a jpeg corrupted photo of the {}.',
'a good photo of a {}.',
'a plushie {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'the cartoon {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a black and white photo of a {}.',
'the plushie {}.',
'a dark photo of a {}.',
'itap of a {}.',
'graffiti of the {}.',
'a toy {}.',
'itap of my {}.',
'a photo of a cool {}.',
'a photo of a small {}.',
'a tattoo of the {}.',
# 'A photo of a {} in the scene.',
]
# v1: 59.0875
IMAGENET_TEMPLATES_SELECT = [
'itap of a {}.',
'a bad photo of the {}.',
'a origami {}.',
'a photo of the large {}.',
'a {} in a video game.',
'art of the {}.',
'a photo of the small {}.',
'A photo of a {} in the scene',
]
# v9
IMAGENET_TEMPLATES_SELECT_CLIP = [
'a bad photo of the {}.',
'a photo of the large {}.',
'a photo of the small {}.',
'a cropped photo of a {}.',
'This is a photo of a {}',
'This is a photo of a small {}',
'This is a photo of a medium {}',
'This is a photo of a large {}',
'This is a masked photo of a {}',
'This is a masked photo of a small {}',
'This is a masked photo of a medium {}',
'This is a masked photo of a large {}',
'This is a cropped photo of a {}',
'This is a cropped photo of a small {}',
'This is a cropped photo of a medium {}',
'This is a cropped photo of a large {}',
'A photo of a {} in the scene',
'a bad photo of the {} in the scene',
'a photo of the large {} in the scene',
'a photo of the small {} in the scene',
'a cropped photo of a {} in the scene',
'a photo of a masked {} in the scene',
'There is a {} in the scene',
'There is the {} in the scene',
'This is a {} in the scene',
'This is the {} in the scene',
'This is one {} in the scene',
'There is a masked {} in the scene',
'There is the masked {} in the scene',
'This is a masked {} in the scene',
'This is the masked {} in the scene',
'This is one masked {} in the scene',
]
# v10, for comparison
# IMAGENET_TEMPLATES_SELECT_CLIP = [
# 'a photo of a {}.',
#
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
#
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
#
# 'a photo of a {} in the scene',
# 'a photo of a {} in the scene',
#
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
# ]
ViLD_templates = [
'There is {article} {category} in the scene.',
'There is the {category} in the scene.',
'a photo of {article} {category} in the scene.',
'a photo of the {category} in the scene.',
'a photo of one {category} in the scene.', 'itap of {article} {category}.',
'itap of my {category}.', 'itap of the {category}.',
'a photo of {article} {category}.', 'a photo of my {category}.',
'a photo of the {category}.', 'a photo of one {category}.',
'a photo of many {category}.', 'a good photo of {article} {category}.',
'a good photo of the {category}.', 'a bad photo of {article} {category}.',
'a bad photo of the {category}.', 'a photo of a nice {category}.',
'a photo of the nice {category}.', 'a photo of a cool {category}.',
'a photo of the cool {category}.', 'a photo of a weird {category}.',
'a photo of the weird {category}.', 'a photo of a small {category}.',
'a photo of the small {category}.', 'a photo of a large {category}.',
'a photo of the large {category}.', 'a photo of a clean {category}.',
'a photo of the clean {category}.', 'a photo of a dirty {category}.',
'a photo of the dirty {category}.',
'a bright photo of {article} {category}.',
'a bright photo of the {category}.',
'a dark photo of {article} {category}.', 'a dark photo of the {category}.',
'a photo of a hard to see {category}.',
'a photo of the hard to see {category}.',
'a low resolution photo of {article} {category}.',
'a low resolution photo of the {category}.',
'a cropped photo of {article} {category}.',
'a cropped photo of the {category}.',
'a close-up photo of {article} {category}.',
'a close-up photo of the {category}.',
'a jpeg corrupted photo of {article} {category}.',
'a jpeg corrupted photo of the {category}.',
'a blurry photo of {article} {category}.',
'a blurry photo of the {category}.',
'a pixelated photo of {article} {category}.',
'a pixelated photo of the {category}.',
'a black and white photo of the {category}.',
'a black and white photo of {article} {category}.',
'a plastic {category}.', 'the plastic {category}.', 'a toy {category}.',
'the toy {category}.', 'a plushie {category}.', 'the plushie {category}.',
'a cartoon {category}.', 'the cartoon {category}.',
'an embroidered {category}.', 'the embroidered {category}.',
'a painting of the {category}.', 'a painting of a {category}.'
]
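
A short illustrative sketch (not taken from the PR) of how these template lists are typically consumed: each class name is substituted into every template before the prompts are tokenized and encoded. The class names below are arbitrary examples.

```python
# Hypothetical class names; any open-vocabulary label set works the same way.
class_names = ['person', 'dog', 'traffic light']

prompts = {
    name: [t.format(name) for t in IMAGENET_TEMPLATES_SELECT_CLIP]
    for name in class_names
}
print(len(prompts['dog']), prompts['dog'][0])  # one prompt per template
```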

View File

@ -0,0 +1,275 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Referred to: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/third_party/clip.py # noqa
import hashlib
import os
import urllib
import warnings
from typing import List, Union
import torch
from PIL import Image
from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
ToTensor)
from tqdm import tqdm
from .clip_model import build_model
from .tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ['available_models', 'load', 'tokenize']
_tokenizer = _Tokenizer()
_MODELS = {
'RN50':
'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt', # noqa
'RN101':
'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt', # noqa
'RN50x4':
'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt', # noqa
'RN50x16':
'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt', # noqa
'RN50x64':
'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt', # noqa
'ViT-B/32':
'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt', # noqa
'ViT-B/16':
'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt', # noqa
'ViT-L/14':
'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt', # noqa
'ViT-L/14@336px':
'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt', # noqa
}
def _download(url: str, root: str = os.path.expanduser('~/.cache/clip')):
"""Download clip pretrained weights."""
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split('/')[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(
f'{download_target} exists and is not a regular file')
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target,
'rb').read()).hexdigest() == expected_sha256:
return download_target
else:
            warnings.warn(
                f'{download_target} exists, but the SHA256 checksum does '
                'not match; re-downloading the file')
with urllib.request.urlopen(url) as source, open(download_target,
'wb') as output:
with tqdm(
total=int(source.info().get('Content-Length')),
ncols=80) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target,
'rb').read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            'Model has been downloaded but the SHA256 checksum does not '
            'match')
return download_target
def available_models():
"""Returns a list of available models."""
return list(_MODELS.keys())
def load(name: str,
device: Union[str, torch.device] = 'cuda'
if torch.cuda.is_available() else 'cpu',
jit=True,
prompt_depth=0,
prompt_length=0):
"""Load target clip model."""
if name not in _MODELS:
raise RuntimeError(
f'Model {name} not found; available models = {available_models()}')
model_path = _download(_MODELS[name])
model = torch.jit.load(
model_path, map_location=device if jit else 'cpu').eval()
n_px = model.input_resolution.item()
transform = Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert('RGB'),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
if not jit:
model = build_model(model.state_dict(), prompt_depth,
prompt_length).to(device)
return model, transform
# patch the device names
device_holder = torch.jit.trace(
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [
n for n in device_holder.graph.findAllNodes('prim::Constant')
if 'Device' in repr(n)
][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, 'graph') else []
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes('prim::Constant'):
if 'value' in node.attributeNames() and str(
node['value']).startswith('cuda'):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if device == 'cpu':
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, 'graph') else []
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes('aten::to'):
inputs = list(node.inputs())
for i in [1, 2]:
# dtype can be the second or third argument to
# aten::to()
if inputs[i].node()['value'] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, transform
def load_custom(name: str,
device: Union[str, torch.device] = 'cuda'
if torch.cuda.is_available() else 'cpu',
jit=True,
n_px=224):
"""Load a customized clip model."""
if name not in _MODELS:
raise RuntimeError(
f'Model {name} not found; available models = {available_models()}')
model_path = _download(_MODELS[name])
model = torch.jit.load(
model_path, map_location=device if jit else 'cpu').eval()
# n_px = model.input_resolution.item()
transform = Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert('RGB'),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
if not jit:
model = build_model(model.state_dict()).to(device)
return model, transform
# patch the device names
device_holder = torch.jit.trace(
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [
n for n in device_holder.graph.findAllNodes('prim::Constant')
if 'Device' in repr(n)
][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, 'graph') else []
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes('prim::Constant'):
if 'value' in node.attributeNames() and str(
node['value']).startswith('cuda'):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if device == 'cpu':
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, 'graph') else []
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes('aten::to'):
inputs = list(node.inputs())
                    for i in [1, 2]:
                        # dtype can be the second or third argument
                        # to aten::to()
                        if inputs[i].node()['value'] == 5:
                            inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, transform
def tokenize(texts: Union[str, List[str]], context_length: int = 77):
"""Convert texts to tokens."""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder['<|startoftext|>']
eot_token = _tokenizer.encoder['<|endoftext|>']
# encode each template text phrase
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
            raise RuntimeError(
                f'Input {texts[i]} is too long for context length '
                f'{context_length}')
result[i, :len(tokens)] = torch.tensor(tokens)
return result
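
A usage sketch (illustrative only) tying `load` and `tokenize` together to get normalized text embeddings. The extra `model.float()` call is an assumption made so the fp16-converted weights run on CPU, and the prompt strings are arbitrary.

```python
import torch

# Downloads the ViT-B/16 weights to ~/.cache/clip on first use.
model, preprocess = load('ViT-B/16', device='cpu', jit=False)
model.float()  # build_model() returns fp16 weights; keep fp32 on CPU

tokens = tokenize(['a photo of a cat.', 'a photo of a dog.'])
with torch.no_grad():
    text_features = model.encode_text(tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
print(text_features.shape)  # (2, 512) for ViT-B/16
```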

View File

@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn
from torch.nn import functional as F
class LinearAttention(nn.Module):
"""Multi-Head linear attention proposed in "Transformers are RNNs".
Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L247 # noqa
"""
def __init__(self, eps=1e-6):
super().__init__()
self.eps = eps
def forward(self, queries, keys, values):
"""
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
Q = F.elu(queries) + 1
K = F.elu(keys) + 1
v_length = values.size(1)
values = values / v_length # prevent fp16 overflow
KV = torch.einsum('nshd,nshv->nhdv', K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV,
Z) * v_length
return queried_values.contiguous()
class FullAttention(nn.Module):
"""Multi-head scaled dot-product attention, a.k.a full attention.
Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L276 # noqa
"""
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
"""
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
# Compute the unnormalized attention and apply the masks
QK = torch.einsum('nlhd,nshd->nlsh', queries, keys)
if kv_mask is not None:
QK.masked_fill_(
~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]),
float('-inf'))
# Compute the attention and the weighted average
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
A = torch.softmax(softmax_temp * QK, dim=2)
if self.use_dropout:
A = self.dropout(A)
queried_values = torch.einsum('nlsh,nshd->nlhd', A, values)
return queried_values.contiguous()
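
A quick shape sanity check (illustrative, not part of the PR) for the two attention variants; the dimensions follow the [N, L/S, H, D] convention documented above.

```python
import torch

N, L, S, H, D = 2, 16, 32, 4, 32
queries = torch.randn(N, L, H, D)
keys = torch.randn(N, S, H, D)
values = torch.randn(N, S, H, D)

out_linear = LinearAttention()(queries, keys, values)
out_full = FullAttention()(queries, keys, values)
print(out_linear.shape, out_full.shape)  # both torch.Size([2, 16, 4, 32])
```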

View File

@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
"""Return default BPE vocabulary path."""
return os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'bpe_vocab/bpe_simple_vocab_16e6.txt.gz')
@lru_cache()
def bytes_to_unicode():
"""Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings. This means you need a
large # of unicode characters in your vocab if you want to avoid UNKs. When
you're at something like a 10B token dataset you end up needing around 5K
for decent coverage. This is a significant percentage of your normal, say,
32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
unicode strings. And avoids mapping to whitespace/control characters the
bpe code barfs on.
"""
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length
strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
"""Clean string."""
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
"""Clean whitespace in string."""
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer:
"""Customized Tokenizer implementation."""
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d"""
            r"""|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
"""Refer to bpe vocabulary dictionary."""
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
"""Encode text strings."""
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
"""Decoder tokens to strings."""
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text
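
An encode/decode roundtrip sketch (illustrative); it assumes the bundled `bpe_simple_vocab_16e6.txt.gz` file is present at the `default_bpe()` path.

```python
tokenizer = SimpleTokenizer()
ids = tokenizer.encode('A photo of a traffic light.')
print(ids)                    # BPE token ids
print(tokenizer.decode(ids))  # roughly 'a photo of a traffic light . '
```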

View File

@ -0,0 +1,68 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
crop_size = (384, 384)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=(2048, 512),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/validation',
seg_map_path='annotations/validation'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@ -0,0 +1,62 @@
# dataset settings
dataset_type = 'COCOStuffDataset'
data_root = 'data/coco_stuff164k'
crop_size = (384, 384)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/train2017', seg_map_path='annotations/train2017'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/val2017', seg_map_path='annotations/val2017'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@ -0,0 +1,72 @@
# dataset settings
dataset_type = 'PascalContextDataset59'
data_root = 'data/VOCdevkit/VOC2010/'
img_scale = (520, 520)
crop_size = (384, 384)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=img_scale,
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=img_scale, keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
ann_file='ImageSets/SegmentationContext/train.txt',
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
ann_file='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')

View File

@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=80000,
by_epoch=False)
]
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

View File

@ -0,0 +1,103 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
'../_base_/datasets/ade20k_384x384.py'
]
custom_imports = dict(imports=['cat_seg'])
norm_cfg = dict(type='SyncBN', requires_grad=True)
crop_size = (384, 384)
data_preprocessor = dict(
type='SegDataPreProcessor',
size=crop_size,
# due to the clip model, we do normalization in backbone forward()
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
# model_cfg
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='CLIPOVCATSeg',
feature_extractor=dict(
type='ResNet',
depth=101,
# only use the first three layers
num_stages=3,
out_indices=(0, 1, 2),
dilations=(1, 1, 1),
strides=(1, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True,
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet101'),
),
train_class_json='data/ade150.json',
test_class_json='data/ade150.json',
clip_pretrained='ViT-B/16',
clip_finetune='attention',
),
neck=dict(
type='CATSegAggregator',
appearance_guidance_dim=1024,
num_layers=2,
pooling_size=(1, 1),
),
decode_head=dict(
type='CATSegHead',
in_channels=128,
channels=128,
num_classes=150,
embed_dims=128,
decoder_dims=(64, 32),
decoder_guidance_dims=(512, 256),
decoder_guidance_proj_dims=(32, 16),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
avg_non_ignore=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
# dataset settings
train_dataloader = dict(
batch_size=2,
num_workers=4,
)
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
# optimizer
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'backbone.feature_extractor': dict(lr_mult=0.01),
'backbone.clip_model.visual': dict(lr_mult=0.01)
}))
# learning policy
param_scheduler = [
    # Use a linear warm-up at [0, 500) iterations
    dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
    # Use a cosine learning rate at [500, 80000) iterations
dict(
type='CosineAnnealingLR',
T_max=79500,
by_epoch=False,
begin=500,
end=80000),
]

View File

@ -0,0 +1,103 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
'../_base_/datasets/pascal_context_59_384x384.py'
]
custom_imports = dict(imports=['cat_seg'])
norm_cfg = dict(type='SyncBN', requires_grad=True)
crop_size = (384, 384)
data_preprocessor = dict(
type='SegDataPreProcessor',
size=crop_size,
# due to the clip model, we do normalization in backbone forward()
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
# model_cfg
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='CLIPOVCATSeg',
feature_extractor=dict(
type='ResNet',
depth=101,
# only use the first three layers
num_stages=3,
out_indices=(0, 1, 2),
dilations=(1, 1, 1),
strides=(1, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True,
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet101'),
),
train_class_json='data/pc59.json',
test_class_json='data/pc59.json',
clip_pretrained='ViT-B/16',
clip_finetune='attention',
),
neck=dict(
type='CATSegAggregator',
appearance_guidance_dim=1024,
num_layers=2,
pooling_size=(1, 1),
),
decode_head=dict(
type='CATSegHead',
in_channels=128,
channels=128,
num_classes=59,
embed_dims=128,
decoder_dims=(64, 32),
decoder_guidance_dims=(512, 256),
decoder_guidance_proj_dims=(32, 16),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
avg_non_ignore=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
# dataset settings
train_dataloader = dict(
batch_size=2,
num_workers=4,
)
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
# optimizer
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'backbone.feature_extractor': dict(lr_mult=0.01),
'backbone.clip_model.visual': dict(lr_mult=0.01)
}))
# learning policy
param_scheduler = [
    # Use a linear warm-up at [0, 500) iterations
    dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
    # Use a cosine learning rate at [500, 80000) iterations
dict(
type='CosineAnnealingLR',
T_max=79500,
by_epoch=False,
begin=500,
end=80000),
]

View File

@ -0,0 +1,102 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
'../_base_/datasets/coco-stuff164k_384x384.py'
]
custom_imports = dict(imports=['cat_seg'])
norm_cfg = dict(type='SyncBN', requires_grad=True)
crop_size = (384, 384)
data_preprocessor = dict(
type='SegDataPreProcessor',
size=crop_size,
# due to the clip model, we do normalization in backbone forward()
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
# model_cfg
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='CLIPOVCATSeg',
feature_extractor=dict(
type='ResNet',
depth=101,
# only use the first three layers
num_stages=3,
out_indices=(0, 1, 2),
dilations=(1, 1, 1),
strides=(1, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True,
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet101'),
),
train_class_json='data/coco.json',
test_class_json='data/coco.json',
clip_pretrained='ViT-B/16',
clip_finetune='attention',
),
neck=dict(
type='CATSegAggregator',
appearance_guidance_dim=1024,
num_layers=2,
),
decode_head=dict(
type='CATSegHead',
in_channels=128,
channels=128,
num_classes=171,
embed_dims=128,
decoder_dims=(64, 32),
decoder_guidance_dims=(512, 256),
decoder_guidance_proj_dims=(32, 16),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
avg_non_ignore=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
# dataset settings
train_dataloader = dict(
batch_size=2,
num_workers=4,
)
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
# optimizer
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'backbone.feature_extractor': dict(lr_mult=0.01),
'backbone.clip_model.visual': dict(lr_mult=0.01)
}))
# learning policy
param_scheduler = [
    # Use a linear warm-up at [0, 500) iterations
    dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
    # Use a cosine learning rate at [500, 80000) iterations
dict(
type='CosineAnnealingLR',
T_max=79500,
by_epoch=False,
begin=500,
end=80000),
]
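
A launch sketch with MMEngine (illustrative): the config filename is the one added above, while the `configs/` prefix and the work directory are assumptions about where the file lives in your checkout.

```python
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py'  # noqa
)
cfg.work_dir = './work_dirs/catseg_vitb-r101_coco-stuff164k'  # assumed path

runner = Runner.from_cfg(cfg)
runner.train()
```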

View File

@ -0,0 +1,11 @@
_base_ = './catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py' # noqa
model = dict(
backbone=dict(
type='CLIPOVCATSeg',
clip_pretrained='ViT-G',
custom_clip_weights='~/CLIP-ViT-bigG-14-laion2B-39B-b160k'),
neck=dict(
text_guidance_dim=1280,
appearance_guidance_dim=512,
))

View File

@ -0,0 +1,11 @@
_base_ = './catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py' # noqa
model = dict(
backbone=dict(
type='CLIPOVCATSeg',
clip_pretrained='ViT-H',
custom_clip_weights='~/CLIP-ViT-H-14-laion2B-s32B-b79K'),
neck=dict(
text_guidance_dim=1024,
appearance_guidance_dim=512,
))

View File

@ -0,0 +1,72 @@
_base_ = './catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py' # noqa
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
crop_size = (384, 384)
data_preprocessor = dict(size=crop_size)
model = dict(
backbone=dict(
type='CLIPOVCATSeg',
feature_extractor=dict(
_delete_=True,
type='SwinTransformer',
pretrain_img_size=384,
embed_dims=128,
depths=[2, 2, 18],
num_heads=[4, 8, 16],
window_size=12,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3,
patch_norm=True,
out_indices=(0, 1, 2),
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
clip_pretrained='ViT-L/14@336px',
),
neck=dict(
text_guidance_dim=768,
appearance_guidance_dim=512,
),
decode_head=dict(
embed_dims=128,
decoder_guidance_dims=(256, 128),
))
# dataset settings
train_dataloader = dict(
batch_size=1,
num_workers=2,
)
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
default_hooks = dict(
visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
# optimizer
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'backbone.feature_extractor': dict(lr_mult=0.01),
'backbone.clip_model.visual': dict(lr_mult=0.01)
}))
# learning policy
param_scheduler = [
    # Use a linear warm-up at [0, 500) iterations
    dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
    # Use a cosine learning rate at [500, 80000) iterations
dict(
type='CosineAnnealingLR',
T_max=79500,
by_epoch=False,
begin=500,
end=80000),
]

View File

@ -0,0 +1,7 @@
from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
__all__ = [
'IMAGENET_TEMPLATES', 'IMAGENET_TEMPLATES_SELECT',
'IMAGENET_TEMPLATES_SELECT_CLIP', 'ViLD_templates'
]