Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
Thanks for your contribution; we appreciate it a lot. The following instructions help keep your pull request healthy and make it easier to get feedback. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

## Motivation

Support CAT-Seg open-vocabulary semantic segmentation (CVPR 2023).

## Modification

Support CAT-Seg open-vocabulary semantic segmentation (CVPR 2023).

- [x] Support CAT-Seg model training.
- [x] CLIP-model-based `backbone` (R101 & Swin-B), aggregation-layers-based `neck`, and `decoder` head.
- [x] Provide customized coco-stuff164k_384x384 training configs.
- [x] Language model support for `open vocabulary` (OV) tasks.
- [x] Support CLIP-based pretrained language model (LM) inference.
- [x] Add commonly used prompt templates.
- [x] Add README tutorials.
- [x] Add zero-shot testing scripts.

**Working on the following tasks.**

- [x] Add unit tests.

## BC-breaking (Optional)

Does the modification introduce changes that break the backward compatibility of downstream repos? If so, please describe how it breaks compatibility and how downstream projects should modify their code to stay compatible with this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix potential lint issues.
2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D.
4. The documentation has been modified accordingly, e.g. docstrings or example tutorials.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
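For orientation, here is a minimal sketch of how the pieces listed in the modification section could be wired into an MMSegmentation-style model config. Only the backbone type `CLIPOVCATSeg` and its argument names come from the file below; the segmentor type, file paths, and other values are hypothetical placeholders, not the shipped coco-stuff164k_384x384 config.

```python
# Hypothetical config sketch: only `type='CLIPOVCATSeg'` and its argument
# names are taken from the file below; everything else is a placeholder.
model = dict(
    type='EncoderDecoder',  # placeholder segmentor type
    backbone=dict(
        type='CLIPOVCATSeg',
        feature_extractor=dict(type='ResNet', depth=101),  # appearance guidance
        train_class_json='path/to/train_classes.json',  # placeholder path
        test_class_json='path/to/test_classes.json',  # placeholder path
        clip_pretrained='ViT-B/16',
        clip_finetune='attention',
        prompt_ensemble_type='imagenet'),
    # The CAT-Seg aggregation `neck` and `decoder` head from the list above
    # would be configured here as well.
)
```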
294 lines
12 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
from typing import List

import torch
import torch.nn.functional as F
from huggingface_hub.utils._errors import LocalEntryNotFoundError
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from mmseg.utils import ConfigType
from ..utils import clip_wrapper
from ..utils.clip_templates import (IMAGENET_TEMPLATES,
                                    IMAGENET_TEMPLATES_SELECT)


@MODELS.register_module()
class CLIPOVCATSeg(BaseModule):
    """CLIP based Open Vocabulary CAT-Seg model backbone.

    This backbone is the modified implementation of `CAT-Seg Backbone
    <https://arxiv.org/abs/2303.11797>`_. It combines the CLIP model and
    another feature extractor, a.k.a. the appearance guidance extractor
    in the original `CAT-Seg`.

    Args:
        feature_extractor (ConfigType): Appearance guidance extractor
            config dict.
        train_class_json (str): The path to the training class json file.
        test_class_json (str): The path to the test class json file.
        clip_pretrained (str): The pre-trained clip type.
        clip_finetune (str): The finetuning settings of the clip model.
        custom_clip_weights (str): The customized clip weights directory.
            When encountering huggingface model download errors, you can
            manually download the pretrained weights to this directory.
        backbone_multiplier (float): The learning rate multiplier.
            Default: 0.01.
        prompt_depth (int): The prompt depth. Default: 0.
        prompt_length (int): The prompt length. Default: 0.
        prompt_ensemble_type (str): The prompt ensemble type.
            Default: "imagenet".
        pixel_mean (List[float]): The pixel mean for the feature extractor.
        pixel_std (List[float]): The pixel std for the feature extractor.
        clip_pixel_mean (List[float]): The pixel mean for the clip model.
        clip_pixel_std (List[float]): The pixel std for the clip model.
        clip_img_feat_size (List[int]): Clip image embedding size from
            the image encoder.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(
            self,
            feature_extractor: ConfigType,
            train_class_json: str,
            test_class_json: str,
            clip_pretrained: str,
            clip_finetune: str,
            custom_clip_weights: str = None,
            backbone_multiplier=0.01,
            prompt_depth: int = 0,
            prompt_length: int = 0,
            prompt_ensemble_type: str = 'imagenet',
            pixel_mean: List[float] = [123.675, 116.280, 103.530],
            pixel_std: List[float] = [58.395, 57.120, 57.375],
            clip_pixel_mean: List[float] = [
                122.7709383, 116.7460125, 104.09373615
            ],
            clip_pixel_std: List[float] = [68.5005327, 66.6321579, 70.3231630],
            clip_img_feat_size: List[int] = [24, 24],
            init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # normalization parameters
        self.register_buffer('pixel_mean',
                             torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std',
                             torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
        self.register_buffer('clip_pixel_mean',
                             torch.Tensor(clip_pixel_mean).view(1, -1, 1, 1),
                             False)
        self.register_buffer('clip_pixel_std',
                             torch.Tensor(clip_pixel_std).view(1, -1, 1, 1),
                             False)
        self.clip_resolution = (
            384, 384) if clip_pretrained == 'ViT-B/16' else (336, 336)
        # modified clip image encoder with fixed size dense output
        self.clip_img_feat_size = clip_img_feat_size

        # prepare clip templates
        self.prompt_ensemble_type = prompt_ensemble_type
        if self.prompt_ensemble_type == 'imagenet_select':
            prompt_templates = IMAGENET_TEMPLATES_SELECT
        elif self.prompt_ensemble_type == 'imagenet':
            prompt_templates = IMAGENET_TEMPLATES
        elif self.prompt_ensemble_type == 'single':
            prompt_templates = [
                'A photo of a {} in the scene',
            ]
        else:
            raise NotImplementedError
        self.prompt_templates = prompt_templates

        # build the feature extractor
        self.feature_extractor = MODELS.build(feature_extractor)

        # build CLIP model
        with open(train_class_json) as f_in:
            self.class_texts = json.load(f_in)
        with open(test_class_json) as f_in:
            self.test_class_texts = json.load(f_in)
        assert self.class_texts is not None
        if self.test_class_texts is None:
            self.test_class_texts = self.class_texts
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = None
        if clip_pretrained == 'ViT-G' or clip_pretrained == 'ViT-H':
            # for OpenCLIP models
            import open_clip
            name, pretrain = (
                'ViT-H-14',
                'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else (
                    'ViT-bigG-14', 'laion2b_s39b_b160k')
            try:
                open_clip_model = open_clip.create_model_and_transforms(
                    name,
                    pretrained=pretrain,
                    device=device,
                    force_image_size=336,
                )
                clip_model, _, clip_preprocess = open_clip_model
            except (ConnectionError, LocalEntryNotFoundError) as e:
                print(f'Encountered {e} when loading weights from '
                      'huggingface!')
                print(
                    f'Will load {pretrain} weights from {custom_clip_weights}.'
                )
                assert custom_clip_weights is not None, 'Please specify custom weights directory.'  # noqa
                assert os.path.exists(
                    os.path.join(custom_clip_weights,
                                 'open_clip_pytorch_model.bin')
                ), 'Please provide a valid directory for manually downloaded model.'  # noqa
                open_clip_model = open_clip.create_model_and_transforms(
                    name,
                    pretrained=None,
                    device='cpu',
                    force_image_size=336,
                )
                clip_model, _, clip_preprocess = open_clip_model

                open_clip.load_checkpoint(
                    clip_model,
                    os.path.expanduser(
                        os.path.join(custom_clip_weights,
                                     'open_clip_pytorch_model.bin')))
                clip_model.to(torch.device(device))

            self.tokenizer = open_clip.get_tokenizer(name)
        else:
            # for OpenAI models
            clip_model, clip_preprocess = clip_wrapper.load(
                clip_pretrained,
                device=device,
                jit=False,
                prompt_depth=prompt_depth,
                prompt_length=prompt_length)

        # pre-encode classes text prompts
        text_features = self.class_embeddings(self.class_texts,
                                              prompt_templates, clip_model,
                                              device).permute(1, 0, 2).float()
        text_features_test = self.class_embeddings(self.test_class_texts,
                                                   prompt_templates,
                                                   clip_model,
                                                   device).permute(1, 0,
                                                                   2).float()
        self.register_buffer('text_features', text_features, False)
        self.register_buffer('text_features_test', text_features_test, False)

        # prepare CLIP model finetune
        self.clip_finetune = clip_finetune
        self.clip_model = clip_model.float()
        self.clip_preprocess = clip_preprocess

        for name, params in self.clip_model.named_parameters():
            if 'visual' in name:
                if clip_finetune == 'prompt':
                    params.requires_grad = True if 'prompt' in name else False
                elif clip_finetune == 'attention':
                    if 'attn' in name or 'position' in name:
                        params.requires_grad = True
                    else:
                        params.requires_grad = False
                elif clip_finetune == 'full':
                    params.requires_grad = True
                else:
                    params.requires_grad = False
            else:
                params.requires_grad = False

        finetune_backbone = backbone_multiplier > 0.
        for name, params in self.feature_extractor.named_parameters():
            if 'norm0' in name:
                params.requires_grad = False
            else:
                params.requires_grad = finetune_backbone

    @torch.no_grad()
    def class_embeddings(self,
                         classnames,
                         templates,
                         clip_model,
                         device='cpu'):
        """Convert class names to text embeddings by the clip model.

        Args:
            classnames (list[str]): class names loaded from the json file.
            templates (list[str]): text prompt templates.
            clip_model (nn.Module): prepared clip model.
            device (str | torch.device): loading device of text
                encoder results.
        """
        zeroshot_weights = []
        for classname in classnames:
            if ', ' in classname:
                classname_splits = classname.split(', ')
                texts = []
                for template in templates:
                    for cls_split in classname_splits:
                        texts.append(template.format(cls_split))
            else:
                texts = [template.format(classname)
                         for template in templates]  # format with class
            if self.tokenizer is not None:
                texts = self.tokenizer(texts).to(device)
            else:
                texts = clip_wrapper.tokenize(texts).to(device)
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            if len(templates) != class_embeddings.shape[0]:
                class_embeddings = class_embeddings.reshape(
                    len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
                class_embeddings /= class_embeddings.norm(
                    dim=-1, keepdim=True)
            class_embedding = class_embeddings
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
        return zeroshot_weights

    def custom_normalize(self, inputs):
"""Input normalization for clip model and feature extractor
|
|
respectively.
|
|
|
|
Args:
|
|
inputs: batched input images.
|
|
"""
|
|
# clip images
|
|
batched_clip = (inputs - self.clip_pixel_mean) / self.clip_pixel_std
|
|
batched_clip = F.interpolate(
|
|
batched_clip,
|
|
size=self.clip_resolution,
|
|
mode='bilinear',
|
|
align_corners=False)
|
|
# feature extractor images
|
|
batched = (inputs - self.pixel_mean) / self.pixel_std
|
|
return batched, batched_clip
|
|
|
|
def forward(self, inputs):
|
|
"""
|
|
Args:
|
|
inputs: minibatch image. (B, 3, H, W)
|
|
Returns:
|
|
outputs (dict):
|
|
'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
|
|
`self.feature_extractor`.
|
|
'clip_text_feat': the text feature extracted by clip text encoder.
|
|
'clip_text_feat_test': the text feature extracted by clip text
|
|
encoder for testing.
|
|
'clip_img_feat': the image feature extracted clip image encoder.
|
|
"""
|
|
inputs, clip_inputs = self.custom_normalize(inputs)
|
|
outputs = dict()
|
|
# extract appearance guidance feature
|
|
outputs['appearance_feat'] = self.feature_extractor(inputs)
|
|
|
|
# extract clip features
|
|
outputs['clip_text_feat'] = self.text_features
|
|
outputs['clip_text_feat_test'] = self.text_features_test
|
|
clip_features = self.clip_model.encode_image(
|
|
clip_inputs, dense=True) # B, 577(24x24+1), C
|
|
B = clip_features.size(0)
|
|
outputs['clip_img_feat'] = clip_features[:, 1:, :].permute(
|
|
0, 2, 1).reshape(B, -1, *self.clip_img_feat_size)
|
|
|
|
return outputs
|
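As a rough illustration of the forward contract documented in `forward` above, a hypothetical smoke test might look like the following. It assumes `backbone` is an already-built `CLIPOVCATSeg` instance (for example via `MODELS.build`) and relies only on the documented output keys; the input shape is illustrative.

```python
import torch

# Assumes `backbone` is an already-built CLIPOVCATSeg instance.
images = torch.randn(2, 3, 384, 384)  # (B, 3, H, W) raw pixel values
with torch.no_grad():
    outputs = backbone(images)

# The dict keys follow the `forward` docstring:
#   'appearance_feat'     - features from the appearance guidance extractor
#   'clip_text_feat'      - pre-encoded training class text features
#   'clip_text_feat_test' - pre-encoded test class text features
#   'clip_img_feat'       - dense CLIP image features, (B, C, 24, 24) by default
print(sorted(outputs.keys()))
print(outputs['clip_img_feat'].shape)
```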