mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
## Motivation

Support SAN for Open-Vocabulary Semantic Segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)

Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the parameters of the ViT backbone for implementing the image encoder of CLIP.
- Added text encoder code.
- Added multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added a loss implementation for mask classification models such as SAN and MaskFormer, removing the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

## Use cases

### Convert Models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth). Use tools/model_converters/san2mmseg.py to convert the official model into mmseg style:

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

Use the CLIP model provided by OpenAI to train SAN. The CLIP models can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt). Use tools/model_converters/clip2mmseg.py to convert the model into mmseg style:

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the coco-stuff164k dataset:

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the coco-stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison Results

### Train on COCO-Stuff164k

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>
241 lines
7.7 KiB
Python
241 lines
7.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
"""CLIP tokenizer.
|
|
|
|
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright
|
|
(c) 2021 OpenAI.
|
|
"""
|
|
import gzip
|
|
import html
|
|
import os
|
|
from functools import lru_cache
|
|
from typing import List, Union
|
|
|
|
import ftfy
|
|
import regex as re
|
|
import torch
|
|
|
|
# Side effect of importing this module: unconditionally overwrites the
# process-wide TOKENIZERS_PARALLELISM environment variable (any value the
# user set beforehand is replaced).
# NOTE(review): presumably this disables HuggingFace `tokenizers` thread
# parallelism (and its fork warning) for HFTokenizer below - confirm intent.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|
|
|
|
|
@lru_cache()
def default_bpe():
    """Return the absolute path of the default BPE merges file.

    The file ``bpe_simple_vocab_16e6.txt.gz`` is expected to live in the
    same directory as this module. The path is computed once and cached.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, 'bpe_simple_vocab_16e6.txt.gz')
|
|
|
|
|
|
@lru_cache()
def bytes_to_unicode():
    """Build a reversible table from byte values to printable characters.

    Byte-level BPE operates on unicode strings, so every possible byte needs
    a stable single-character stand-in that is neither whitespace nor a
    control character (the BPE code mis-handles those). Bytes that are
    already printable Latin-1 characters map to themselves; every other byte
    is assigned, in ascending byte order, a fresh code point starting at 256.

    Returns:
        dict[int, str]: Mapping for all 256 byte values to distinct
        one-character strings. Printable bytes come first in iteration order,
        followed by the remapped ones.
    """
    printable = [
        *range(ord('!'), ord('~') + 1),
        *range(ord('¡'), ord('¬') + 1),
        *range(ord('®'), ord('ÿ') + 1),
    ]
    already_printable = set(printable)

    byte_values = list(printable)
    chars = [chr(code) for code in printable]
    next_code = 2**8
    for byte in range(2**8):
        if byte in already_printable:
            continue
        # Shift non-printable bytes past the Latin-1 range.
        byte_values.append(byte)
        chars.append(chr(next_code))
        next_code += 1
    return dict(zip(byte_values, chars))
|
|
|
|
|
|
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word (Sequence[str]): The word as a sequence of symbols (symbols are
            variable-length strings, e.g. a tuple produced during BPE).

    Returns:
        set[tuple[str, str]]: Every ``(word[i], word[i + 1])`` pair. Empty
        when ``word`` has fewer than two symbols, including an empty word.
    """
    # zip stops at the shorter input, so short or empty words yield no pairs
    # instead of failing on word[0].
    return set(zip(word, word[1:]))
|
|
|
|
|
|
def basic_clean(text):
    """Normalize raw text before tokenization.

    Runs ``ftfy.fix_text``, then HTML-unescapes the result twice (so doubly
    escaped entities such as ``&amp;amp;`` are resolved), and finally strips
    leading/trailing whitespace.
    """
    cleaned = ftfy.fix_text(text)
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()
|
|
|
|
|
|
def whitespace_clean(text):
    """Collapse each run of whitespace into one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
|
|
class SimpleTokenizer:
    """Byte-level BPE tokenizer used by CLIP.

    Text is cleaned, lower-cased and split with a regex into word-like
    pieces. Each piece is mapped byte-by-byte to printable characters
    (``bytes_to_unicode``) and then merged with the BPE merge table read from
    ``bpe_path``.

    Args:
        bpe_path (str): Path to the gzip-compressed BPE merges file.
            Defaults to ``default_bpe()``.
        special_tokens (Iterable[str], optional): Extra special tokens that
            are appended after ``'<start_of_text>'`` and ``'<end_of_text>'``.
            Each special token is matched literally in the input text and
            mapped to a single id.
    """

    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the gzip handle as soon as the merges have been read.
        with gzip.open(bpe_path) as f:
            merges = f.read().decode('utf-8').split('\n')
        # Skip line 0 (presumably a header - confirm against the file) and
        # keep 49152 - 256 - 2 = 48894 merges. Together with the 512
        # byte-level symbols and the 2 default special tokens this yields a
        # 49408-entry vocabulary.
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Each byte symbol appears twice: word-internal and word-final.
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        special_tokens = (['<start_of_text>', '<end_of_text>'] +
                          list(special_tokens or []))
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = merge learned earlier = applied first.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens bypass BPE: they map to themselves.
        self.cache = {t: t for t in special_tokens}
        # Special tokens are matched literally, so escape them; otherwise a
        # token containing regex metacharacters (e.g. '[MASK]') would be
        # interpreted as a pattern.
        special = '|'.join(re.escape(t) for t in special_tokens)
        self.pat = re.compile(
            special +
            r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        """Apply the BPE merges to one byte-encoded token.

        Args:
            token (str): A non-empty, byte-encoded pre-token.

        Returns:
            str: The merged symbols joined by single spaces. The last symbol
            carries the ``'</w>'`` end-of-word suffix. Results of multi-symbol
            tokens are memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # Merge the lowest-ranked (earliest learned) pair first.
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the rest unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if (word[i] == first and i < len(word) - 1
                        and word[i + 1] == second):
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert ``text`` into a list of BPE token ids.

        The text is cleaned (``basic_clean`` then ``whitespace_clean``),
        lower-cased, split with ``self.pat``, byte-encoded and BPE-merged.
        Start/end tokens are not added here (``tokenize`` adds them).
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert an iterable of token ids back into text.

        Every ``'</w>'`` end-of-word marker becomes a space. Byte sequences
        that are not valid UTF-8 are replaced rather than raising.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors='replace').replace('</w>', ' ')
        return text
|
|
|
|
|
|
# Module-level default tokenizer shared by `decode` and `tokenize` below.
# It is constructed at import time with the default BPE path
# (`default_bpe()`), so importing this module reads that gzip file.
_tokenizer = SimpleTokenizer()
|
|
|
|
|
|
def decode(output_ids: torch.Tensor):
    """Decode a tensor of token ids into text with the default tokenizer.

    The ids are moved to the CPU and converted to a NumPy array. Each element
    is then looked up individually by ``SimpleTokenizer.decode``, so the
    tensor is presumably expected to be 1-D.
    """
    return _tokenizer.decode(output_ids.cpu().numpy())
|
|
|
|
|
|
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77) -> torch.LongTensor:
    """Tokenize one or more strings into a fixed-width tensor of token ids.

    Each text is encoded with the module-level default tokenizer and wrapped
    as ``<start_of_text> ... <end_of_text>``. Rows shorter than
    ``context_length`` are zero-padded on the right. Longer rows are
    truncated, and their last position is overwritten with the end token so
    every row still terminates with it.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings to tokenize.
    context_length : int
        Number of positions per row; all CLIP models use 77.

    Returns
    -------
    torch.LongTensor
        Tensor of shape ``[number of input strings, context_length]``.
    """
    text_list = [texts] if isinstance(texts, str) else list(texts)
    start_id = _tokenizer.encoder['<start_of_text>']
    end_id = _tokenizer.encoder['<end_of_text>']

    output = torch.zeros(len(text_list), context_length, dtype=torch.long)
    for row, text in enumerate(text_list):
        ids = [start_id, *_tokenizer.encode(text), end_id]
        if len(ids) > context_length:
            ids = ids[:context_length]
            ids[-1] = end_id
        output[row, :len(ids)] = torch.tensor(ids)
    return output
|
|
|
|
|
|
class HFTokenizer:
    """Thin wrapper around a HuggingFace ``AutoTokenizer``.

    Args:
        tokenizer_name (str): Name or path passed to
            ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, tokenizer_name: str):
        # Imported here so `transformers` is only needed when this wrapper
        # is actually instantiated.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def save_pretrained(self, dest):
        """Save the wrapped tokenizer to ``dest``."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self,
                 texts: Union[str, List[str]],
                 context_length: int = 77) -> torch.Tensor:
        """Clean and tokenize ``texts`` into a padded tensor of input ids.

        Args:
            texts (str | list[str]): A single string or a list of strings.
            context_length (int): Every row is padded or truncated to this
                many tokens. Defaults to 77.

        Returns:
            torch.Tensor: The ``input_ids`` produced by the wrapped tokenizer.
        """
        batch = [texts] if isinstance(texts, str) else texts
        # Same cleaning as the default tokenizer but without lower-casing.
        # Lower-casing would be more robust for case-sensitive tokenizers,
        # at the cost of losing case nuance.
        cleaned = [whitespace_clean(basic_clean(item)) for item in batch]
        encoded = self.tokenizer(
            cleaned,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return encoded.input_ids
|