[Feature] Support Chinese CLIP. ()

* support cn-clip

* update README

* Update progress bar

* update order of category

* fix lint

* update

* update readme and metafile

* update

* update docstring

* refactor tokenizer

* fix lint

* Update README and progress bar

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
pull/1586/head
Yixiao Fang 2023-05-22 15:46:13 +08:00 committed by GitHub
parent d04ef8a29e
commit 1e478462b8
17 changed files with 1469 additions and 11 deletions

@@ -0,0 +1,69 @@
# ChineseCLIP
> [Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese](https://arxiv.org/abs/2211.01335)
<!-- [ALGORITHM] -->
## Abstract
The tremendous success of CLIP (Radford et al., 2021) has promoted the research and application of contrastive learning for vision-language pretraining. In this work, we construct a large-scale dataset of image-text pairs in Chinese, where most data are retrieved from publicly available datasets, and we pretrain Chinese CLIP models on the new dataset. We develop 5 Chinese CLIP models of multiple sizes, spanning from 77 to 958 million parameters. Furthermore, we propose a two-stage pretraining method, where the model is first trained with the image encoder frozen and then trained with all parameters being optimized, to achieve enhanced model performance. Our comprehensive experiments demonstrate that Chinese CLIP can achieve the state-of-the-art performance on MUGE, Flickr30K-CN, and COCO-CN in the setups of zero-shot learning and finetuning, and it is able to achieve competitive performance in zero-shot image classification based on the evaluation on the ELEVATER benchmark (Li et al., 2022). We have released our codes, models, and demos in https://github.com/OFA-Sys/Chinese-CLIP
<div align=center>
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/4d05e51f-d834-4ef5-bbf0-0e2f80fea461" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
**Use the model for zero-shot classification**
```python
from mmpretrain import ImageClassificationInferencer
inferencer = ImageClassificationInferencer(
'cn-clip_resnet50_zeroshot-cls_cifar100',
pretrained=True,
classes=['鸟', '狗', '猫', '蛇'],
text_prototype=['鸟', '狗', '猫', '蛇'],
)
prediction = inferencer('./demo/bird.JPEG')[0]
print('Results:', prediction['pred_class'])
```
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Test:
```shell
python tools/test.py configs/chinese_clip/cn-clip_resnet50_zeroshot-cls_cifar100.py https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_resnet50_3rdparty_20230519-6a2b3eb2.pth
```
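
**Compute image-text similarity (sketch)**

The snippet below is an illustrative sketch rather than an official usage example: it drives the model directly through `get_model` and the `tokenize`/`compute_similarity` methods of `ChineseCLIP`, with a random tensor standing in for a preprocessed (resized and CLIP-normalized) image batch.

```python
import torch

from mmpretrain import get_model

model = get_model('cn-clip_vit-base-p16_zeroshot-cls_cifar100', pretrained=True)
model.eval()

# Tokenize a few Chinese captions; a random tensor stands in for preprocessed images.
texts = model.tokenize(['一张猫的照片', '一张狗的照片'])  # (2, 52) token ids
images = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    logits_per_image, logits_per_text = model.compute_similarity(images, texts)
print(logits_per_image.softmax(dim=-1))  # similarity of the image to each caption
```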
<!-- [TABS-END] -->
## Models and results
### Image Classification on CIFAR100
| Model | Params (M) | Top-1 (%) | Config | Download |
| :---------------------------------------------- | :--------: | :-------: | :------------------------------------------------------: | :----------------------------------------------------------------------------: |
| `cn-clip_resnet50_zeroshot-cls_cifar100`\* | 77.00 | 40.70 | [config](cn-clip_resnet50_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_resnet50_3rdparty_20230519-6a2b3eb2.pth) |
| `cn-clip_vit-base-p16_zeroshot-cls_cifar100`\* | 188.00 | 64.50 | [config](cn-clip_vit-base-p16_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-base-p16_3rdparty_20230519-37fbc59e.pth) |
| `cn-clip_vit-large-p14_zeroshot-cls_cifar100`\* | 406.00 | 74.80 | [config](cn-clip_vit-large-p14_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-large-p14_3rdparty_20230519-3f844503.pth) |
| `cn-clip_vit-huge-p14_zeroshot-cls_cifar100`\* | 958.00 | 79.10 | [config](cn-clip_vit-huge-p14_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-huge-p14_3rdparty_20230519-e4f49b00.pth) |
*Models with * are converted from the [official repo](https://github.com/OFA-Sys/Chinese-CLIP). The config files of these models are only for inference. We haven't reproduced the training results.*
## Citation
```bibtex
@article{chinese-clip,
title={Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese},
author={Yang, An and Pan, Junshu and Lin, Junyang and Men, Rui and Zhang, Yichang and Zhou, Jingren and Zhou, Chang},
journal={arXiv preprint arXiv:2211.01335},
year={2022}
}
```

@@ -0,0 +1,72 @@
_base_ = '../_base_/default_runtime.py'
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=False,
)
test_pipeline = [
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
dict(
type='PackInputs',
meta_keys=['image_id', 'scale_factor'],
),
]
train_dataloader = None
test_dataloader = dict(
batch_size=32,
num_workers=8,
dataset=dict(
type='CIFAR100',
data_root='data/cifar100',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='Accuracy', topk=(1, ))
# schedule settings
train_cfg = None
val_cfg = None
test_cfg = dict()
# model settings
model = dict(
type='ChineseCLIP',
vision_backbone=dict(
type='ModifiedResNet',
depth=50,
base_channels=64,
input_size=224,
num_attn_heads=32,
output_dim=1024,
),
text_backbone=dict(
type='BertModelCN',
config=dict(
vocab_size=21128,
pad_token_id=0,
add_type_embeddings=True,
attention_probs_dropout_prob=0.1,
hidden_act='gelu',
hidden_dropout_prob=0.1,
hidden_size=768,
initializer_range=0.02,
intermediate_size=3072,
max_position_embeddings=512,
num_attention_heads=12,
num_hidden_layers=3,
type_vocab_size=2,
layer_norm_eps=1e-12)),
tokenizer=dict(
type='FullTokenizer',
vocab_file= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
),
proj_dim=1024,
text_prototype='cifar100',
)

@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=False,
)
test_pipeline = [
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
dict(
type='PackInputs',
algorithm_keys=['text'],
meta_keys=['image_id', 'scale_factor'],
),
]
train_dataloader = None
test_dataloader = dict(
batch_size=32,
num_workers=8,
dataset=dict(
type='CIFAR100',
data_root='data/cifar100',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='Accuracy', topk=(1, ))
# schedule settings
train_cfg = None
val_cfg = None
test_cfg = dict()
# model settings
model = dict(
type='ChineseCLIP',
vision_backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
norm_cfg=dict(type='LN', eps=1e-5),
final_norm=True,
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
pre_norm=True,
out_type='cls_token',
),
text_backbone=dict(
type='BertModelCN',
config=dict(
vocab_size=21128,
pad_token_id=0,
add_type_embeddings=True,
attention_probs_dropout_prob=0.1,
hidden_act='gelu',
hidden_dropout_prob=0.1,
hidden_size=768,
initializer_range=0.02,
intermediate_size=3072,
max_position_embeddings=512,
num_attention_heads=12,
num_hidden_layers=12,
type_vocab_size=2,
layer_norm_eps=1e-12)),
tokenizer=dict(
type='FullTokenizer',
vocab_file= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
),
proj_dim=512,
text_prototype='cifar100',
)

@@ -0,0 +1,75 @@
_base_ = '../_base_/default_runtime.py'
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=False,
)
test_pipeline = [
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
dict(
type='PackInputs',
meta_keys=['image_id', 'scale_factor'],
),
]
train_dataloader = None
test_dataloader = dict(
batch_size=32,
num_workers=8,
dataset=dict(
type='CIFAR100',
data_root='data/cifar100',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='Accuracy', topk=(1, ))
# schedule settings
train_cfg = None
val_cfg = None
test_cfg = dict()
# model settings
model = dict(
type='ChineseCLIP',
vision_backbone=dict(
type='VisionTransformer',
arch='huge',
img_size=224,
patch_size=14,
norm_cfg=dict(type='LN', eps=1e-5),
final_norm=True,
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
pre_norm=True,
out_type='cls_token',
),
text_backbone=dict(
type='BertModelCN',
config=dict(
vocab_size=21128,
pad_token_id=0,
add_type_embeddings=True,
attention_probs_dropout_prob=0.1,
hidden_act='gelu',
hidden_dropout_prob=0.1,
hidden_size=1024,
initializer_range=0.02,
intermediate_size=4096,
max_position_embeddings=512,
num_attention_heads=16,
num_hidden_layers=24,
type_vocab_size=2,
layer_norm_eps=1e-12)),
tokenizer=dict(
type='FullTokenizer',
vocab_file= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
),
proj_dim=1024,
text_prototype='cifar100',
)

@@ -0,0 +1,75 @@
_base_ = '../_base_/default_runtime.py'
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=False,
)
test_pipeline = [
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
dict(
type='PackInputs',
meta_keys=['image_id', 'scale_factor'],
),
]
train_dataloader = None
test_dataloader = dict(
batch_size=32,
num_workers=8,
dataset=dict(
type='CIFAR100',
data_root='data/cifar100',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='Accuracy', topk=(1, ))
# schedule settings
train_cfg = None
val_cfg = None
test_cfg = dict()
# model settings
model = dict(
type='ChineseCLIP',
vision_backbone=dict(
type='VisionTransformer',
arch='large',
img_size=224,
patch_size=14,
norm_cfg=dict(type='LN', eps=1e-5),
final_norm=True,
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
pre_norm=True,
out_type='cls_token',
),
text_backbone=dict(
type='BertModelCN',
config=dict(
vocab_size=21128,
pad_token_id=0,
add_type_embeddings=True,
attention_probs_dropout_prob=0.1,
hidden_act='gelu',
hidden_dropout_prob=0.1,
hidden_size=768,
initializer_range=0.02,
intermediate_size=3072,
max_position_embeddings=512,
num_attention_heads=12,
num_hidden_layers=12,
type_vocab_size=2,
layer_norm_eps=1e-12)),
tokenizer=dict(
type='FullTokenizer',
vocab_file= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
),
proj_dim=768,
text_prototype='cifar100',
)

@@ -0,0 +1,79 @@
Collections:
- Name: ChineseCLIP
Metadata:
Training Data:
- LAION-5B
- WuKong
- VisualGenome
- MSCOCO
Architecture:
- Transformer
Paper:
Title: 'Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese'
URL: https://arxiv.org/abs/2211.01335
README: configs/chinese_clip/README.md
Models:
- Name: cn-clip_resnet50_zeroshot-cls_cifar100
Metadata:
FLOPs: null
Parameters: 77000000
In Collection: ChineseCLIP
Results:
- Task: Image Classification
Dataset: CIFAR100
Metrics:
Top 1 Accuracy: 40.7
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_resnet50_3rdparty_20230519-6a2b3eb2.pth
Config: configs/chinese_clip/cn-clip_resnet50_zeroshot-cls_cifar100.py
Converted From:
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_rn50.pt
Code: https://github.com/OFA-Sys/Chinese-CLIP
- Name: cn-clip_vit-base-p16_zeroshot-cls_cifar100
Metadata:
FLOPs: null
Parameters: 188000000
In Collection: ChineseCLIP
Results:
- Task: Image Classification
Dataset: CIFAR100
Metrics:
Top 1 Accuracy: 64.5
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-base-p16_3rdparty_20230519-37fbc59e.pth
Config: configs/chinese_clip/cn-clip_vit-base-p16_zeroshot-cls_cifar100.py
Converted From:
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt
Code: https://github.com/OFA-Sys/Chinese-CLIP
- Name: cn-clip_vit-large-p14_zeroshot-cls_cifar100
Metadata:
FLOPs: null
Parameters: 406000000
In Collection: ChineseCLIP
Results:
- Task: Image Classification
Dataset: CIFAR100
Metrics:
Top 1 Accuracy: 74.8
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-large-p14_3rdparty_20230519-3f844503.pth
Config: configs/chinese_clip/cn-clip_vit-large-p14_zeroshot-cls_cifar100.py
Converted From:
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14.pt
Code: https://github.com/OFA-Sys/Chinese-CLIP
- Name: cn-clip_vit-huge-p14_zeroshot-cls_cifar100
Metadata:
FLOPs: null
Parameters: 958000000
In Collection: ChineseCLIP
Results:
- Task: Image Classification
Dataset: CIFAR100
Metrics:
Top 1 Accuracy: 79.1
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-huge-p14_3rdparty_20230519-e4f49b00.pth
Config: configs/chinese_clip/cn-clip_vit-huge-p14_zeroshot-cls_cifar100.py
Converted From:
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-h-14.pt
Code: https://github.com/OFA-Sys/Chinese-CLIP

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from math import ceil
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
@@ -154,7 +155,8 @@ class BaseInferencer:
inputs = self.preprocess(
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
preds = []
for data in track(inputs, 'Inference'):
for data in track(
inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)):
preds.extend(self.forward(data, **forward_kwargs))
visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
results = self.postprocess(preds, visualization, return_datasamples,

@@ -1427,3 +1427,14 @@ FOOD101_CATEGORIES = (
'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos',
'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles')
CIFAR100_CATEGORIES_CN = (
    '苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩',
    '桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云',
'蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩',
'仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树',
    '摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树',
'田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海',
    '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
'桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
    '柳树', '狼', '女人', '蠕虫')

@@ -4,6 +4,7 @@ from mmpretrain.utils.dependency import WITH_MULTIMODAL
if WITH_MULTIMODAL:
from .blip import * # noqa: F401,F403
from .blip2 import * # noqa: F401,F403
from .chinese_clip import * # noqa: F401, F403
from .flamingo import * # noqa: F401, F403
from .ofa import * # noqa: F401, F403
else:
@@ -13,5 +14,5 @@ else:
register_multimodal_placeholder([
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
'OFA'
'OFA', 'ChineseCLIP'
], MODELS)

@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bert import BertModelCN
from .chinese_clip import ChineseCLIP, ModifiedResNet
__all__ = ['ChineseCLIP', 'ModifiedResNet', 'BertModelCN']

@@ -0,0 +1,263 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# flake8: noqa
import math
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
try:
from transformers.models.bert.configuration_bert import BertConfig
except:
BertConfig = None
from mmpretrain.registry import MODELS
from ..blip.language_model import BertAttention, BertIntermediate, BertOutput
def gelu(x):
"""Original Implementation of the gelu activation function in Google Bert
repo when initially created.
For information: OpenAI GPT's gelu is slightly different (and gives
slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
""" # noqa
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""Implementation of the gelu activation function currently in Google Bert
repo (identical to OpenAI GPT) https://arxiv.org/abs/1606.08415."""
return 0.5 * x * (1 + torch.tanh(
math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {
'gelu': gelu,
'relu': torch.nn.functional.relu,
'swish': swish,
'gelu_new': gelu_new
}
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type
embeddings."""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model
# variable name and be able to load any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
position_ids = torch.arange(
seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings \
+ token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None, head_mask=None):
attention_outputs = self.attention(hidden_states, attention_mask,
head_mask)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output, ) + attention_outputs[
1:] # add attentions if we output them
if len(outputs) == 1:
return outputs[0]
return outputs
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.grad_checkpointing = False
self.layer = nn.ModuleList(
[BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None, head_mask=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )
if self.grad_checkpointing and not torch.jit.is_scripting():
layer_outputs = checkpoint(layer_module, hidden_states,
attention_mask, head_mask[i])
else:
layer_outputs = layer_module(hidden_states, attention_mask,
head_mask[i])
if not isinstance(layer_outputs, tuple):
layer_outputs = (layer_outputs, )
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1], )
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )
outputs = (hidden_states, )
if self.output_hidden_states:
outputs = outputs + (all_hidden_states, )
if self.output_attentions:
outputs = outputs + (all_attentions, )
# last-layer hidden state, (all hidden states), (all attentions)
return outputs
class BertPreTrainedModel(nn.Module):
base_model_prefix = 'bert'
def __init__(self, config):
super(BertPreTrainedModel, self).__init__()
self.config = config
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version
# which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(
mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@MODELS.register_module()
class BertModelCN(BertPreTrainedModel):
"""The BERT model implementation for Chinese CLIP."""
def __init__(self, config):
config = BertConfig.from_dict(config)
super(BertModelCN, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.apply(self._init_weights)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
if enable:
assert not self.config.output_attentions, \
'Grad checkpointing is currently conflict with ' \
'output_attentions for BertEncoder, ' \
'please set it to False in BertConfig'
self.encoder.grad_checkpointing = enable
def forward(self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(
dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1,
-1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters(
                )).dtype)  # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(
input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
encoder_outputs = self.encoder(
embedding_output, extended_attention_mask, head_mask=head_mask)
sequence_output = encoder_outputs[0]
# pooled_output = self.pooler(sequence_output)
pooled_output = None
# add hidden_states and attentions if they are here
outputs = (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
# sequence_output, pooled_output, (hidden_states), (attentions)
return outputs
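
For illustration (not part of the file above), a minimal sketch that builds this text tower from the registry with the same config dict used in the RN50 config; it assumes `transformers` is installed and uses random token ids as stand-ins.

```python
import torch

from mmpretrain.registry import MODELS

bert = MODELS.build(
    dict(
        type='BertModelCN',
        config=dict(
            vocab_size=21128, pad_token_id=0, add_type_embeddings=True,
            attention_probs_dropout_prob=0.1, hidden_act='gelu',
            hidden_dropout_prob=0.1, hidden_size=768, initializer_range=0.02,
            intermediate_size=3072, max_position_embeddings=512,
            num_attention_heads=12, num_hidden_layers=3, type_vocab_size=2,
            layer_norm_eps=1e-12)))

token_ids = torch.randint(0, 21128, (2, 52))  # random stand-in for tokenized captions
sequence_output = bert(token_ids)[0]          # last-layer hidden states
print(sequence_output.shape)                  # torch.Size([2, 52, 768])
```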

@@ -0,0 +1,446 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule
from torch import nn
from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .utils import OPENAI_PROMPT
PROTOTYPE_MAP = {'cifar100': CIFAR100_CATEGORIES_CN}
PROMPT_MAP = {'openai': OPENAI_PROMPT}
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1],
x.shape[2] * x.shape[3]).permute(2, 0,
1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)
return x[0]
@MODELS.register_module()
class ModifiedResNet(BaseModule):
"""A modified ResNet contains the following changes:
- Apply deep stem with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is
prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
""" # noqa
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth: int = 50,
base_channels: int = 64,
input_size: int = 224,
num_attn_heads: int = 32,
output_dim: int = 1024,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg=init_cfg)
self.input_size = input_size
self.block, stage_blocks = self.arch_settings[depth]
# the 3-layer stem
self.conv1 = nn.Conv2d(
3,
base_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(base_channels // 2)
self.conv2 = nn.Conv2d(
base_channels // 2,
base_channels // 2,
kernel_size=3,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(base_channels // 2)
self.conv3 = nn.Conv2d(
base_channels // 2,
base_channels,
kernel_size=3,
padding=1,
bias=False)
self.bn3 = nn.BatchNorm2d(base_channels)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
# this is a *mutable* variable used during construction
self._inplanes = base_channels
self.layer1 = self._make_layer(base_channels, stage_blocks[0])
self.layer2 = self._make_layer(
base_channels * 2, stage_blocks[1], stride=2)
self.layer3 = self._make_layer(
base_channels * 4, stage_blocks[2], stride=2)
self.layer4 = self._make_layer(
base_channels * 8, stage_blocks[3], stride=2)
embed_dim = base_channels * 32
self.attnpool = AttentionPool2d(input_size // 32, embed_dim,
num_attn_heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
(self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
@MODELS.register_module()
class ChineseCLIP(BaseModel):
"""The implementation of `ChineseCLIP <https://arxiv.org/abs/2211.01335>`_.
Args:
vision_backbone (dict): Config dict for vision backbone.
text_backbone (dict): Config dict for text backbone.
tokenizer (dict): Config dict for text tokenizer.
proj_dim (int): Projection dimension for similarity computation.
text_prototype (str): Text prototype, which can be a key in
`PROTOTYPE_MAP` or list of text.
text_prompt (str): The prompt for text prototype. Defaults to 'openai'.
context_length (int): The context length to use. Defaults to 52.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None or no specified type, it will use
"MultiModalDataPreprocessor" as type.
See :class:`MultiModalDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): The config to control the initialization.
Defaults to None.
"""
def __init__(self,
vision_backbone: dict,
text_backbone: dict,
tokenizer: dict,
proj_dim: int,
text_prototype: Union[str, List[str]],
text_prompt: str = 'openai',
context_length: int = 52,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.vision_backbone = MODELS.build(vision_backbone)
self.text_backbone = MODELS.build(text_backbone)
if not isinstance(self.vision_backbone, ModifiedResNet):
self.vision_projection = nn.Parameter(
torch.empty(self.vision_backbone.embed_dims, proj_dim))
text_hidden_size = text_backbone['config']['hidden_size']
self.text_projection = nn.Parameter(
torch.empty(text_hidden_size, proj_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.tokenizer = TOKENIZER.build(tokenizer)
self.context_length = context_length
# for zero-shot classification
if isinstance(text_prototype,
str) and text_prototype in PROTOTYPE_MAP.keys():
self.prototype = PROTOTYPE_MAP[text_prototype]
else:
self.prototype = text_prototype
self.text_prototype_embeds = None
self.prompt = PROMPT_MAP[text_prompt]
def forward(
self,
images: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'predict',
**kwargs,
):
"""The unified entry for a forward process in both training and test.
The method accepts the following modes:
- "predict": Forward and return a list of data samples contain the
predict results.
Args:
images (torch.Tensor): the preprocessed image tensor of shape
``(N, C, H, W)``.
data_samples (List[DataSample], optional): The annotation data
of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'predict'.
"""
if mode == 'predict':
return self.predict(images, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
"""The function to extract image latent features."""
if isinstance(self.vision_backbone, ModifiedResNet):
return self.vision_backbone(images)
return self.vision_backbone(images)[-1] @ self.vision_projection
def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
"""The function to extract text latent features."""
pad_index = self.tokenizer.vocab['[PAD]']
attn_mask = texts.ne(pad_index)
# [batch_size, seq_length, hidden_size]
x = self.text_backbone(texts, attention_mask=attn_mask)[0]
return x[:, 0, :] @ self.text_projection
def extract_feat(
self, images: torch.Tensor,
texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
"""The function to extract image and text latent features, the input
image or text can not both be None."""
assert images is not None or texts is not None, \
'text and image cannot both be None!'
if images is None:
return self.extract_text_feat(texts)
elif texts is None:
return self.extract_image_feat(images)
image_features = self.extract_image_feat(images)
text_features = self.extract_text_feat(texts)
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
return image_features, text_features
def compute_similarity(self, images, texts):
"""Extract images and texts features and compute cosine similarity."""
image_features, text_features = self.extract_feat(
images=images, texts=texts)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape (N, N)
return logits_per_image, logits_per_text
def predict(self,
images: torch.Tensor,
data_samples: DataSample = None) -> DataSample:
"""Predict the classes of the input images.
The prediction is for zero-shot classification and the text prototypes
        will be prepared in this function.
Args:
images (torch.Tensor): The input images.
data_samples (DataSample): The data samples with information from
dataset.
Returns:
DataSample: The results of prediction.
"""
if self.text_prototype_embeds is None:
self.prepare_text_prototype(device=images.device)
image_features = self.extract_image_feat(images=images)
image_features /= image_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logits_per_image = image_features @ self.text_prototype_embeds.to(
image_features.device) * self.logit_scale.exp()
pred_scores = F.softmax(logits_per_image, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
out_data_samples = []
if data_samples is None:
data_samples = [None for _ in range(pred_scores.size(0))]
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
if data_sample is None:
data_sample = DataSample()
data_sample.set_pred_score(score).set_pred_label(label)
out_data_samples.append(data_sample)
return out_data_samples
def prepare_text_prototype(self, device) -> None:
"""The function to prepare text prototypes with prompt."""
class_embeddings = []
for classname in track_on_main_process(self.prototype,
'Prepare text prototype...'):
# format with class
texts = [prompt(classname) for prompt in self.prompt]
tokenized_texts = self.tokenize(texts)
class_features = self.extract_text_feat(tokenized_texts.to(device))
class_features /= class_features.norm(dim=-1, keepdim=True)
class_feature = class_features.mean(dim=0)
class_feature /= class_feature.norm()
class_embeddings.append(class_feature)
self.text_prototype_embeds = torch.stack(
class_embeddings, dim=1).to(device)
def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
"""Returns the tokenized representation of given input string(s)
Args:
texts (Union[str, List[str]]): An input string or a list of input
strings to tokenize
context_length (int): The context length to use. Defaults to 52.
Returns:
torch.Tensor: Resulting tokens.
"""
if isinstance(texts, str):
texts = [texts]
all_tokens = []
for text in texts:
# adapt the text to Chinese BERT vocab
            text = text.lower().replace('“', "\"").replace('”', "\"")
# add special tokens
all_tokens.append(
[self.tokenizer.vocab['[CLS]']] +
self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text))[:self.context_length - 2] +
[self.tokenizer.vocab['[SEP]']])
result = torch.zeros(
len(all_tokens), self.context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
assert len(tokens) <= self.context_length
result[i, :len(tokens)] = torch.tensor(tokens)
return result
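
For illustration (not part of the file above), a small shape-check sketch for the `ModifiedResNet` image tower with the RN50 settings from the config: the attention-pooling head maps a 224x224 input to one 1024-d embedding per image.

```python
import torch

from mmpretrain.registry import MODELS

rn50 = MODELS.build(
    dict(
        type='ModifiedResNet',
        depth=50,
        base_channels=64,
        input_size=224,
        num_attn_heads=32,
        output_dim=1024))

feat = rn50(torch.rand(2, 3, 224, 224))  # random stand-in for preprocessed images
print(feat.shape)  # torch.Size([2, 1024]), pooled by AttentionPool2d
```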

@@ -0,0 +1,186 @@
# Copyright (c) OpenMMLab. All rights reserved.
OPENAI_PROMPT = [
lambda c: f'{c}的照片',
lambda c: f'质量差的{c}的照片',
lambda c: f'许多{c}的照片',
lambda c: f'{c}的雕塑',
lambda c: f'难以看到{c}的照片',
lambda c: f'{c}的低分辨率照片',
lambda c: f'{c}的渲染',
lambda c: f'涂鸦{c}',
lambda c: f'{c}的糟糕照片',
lambda c: f'{c}的裁剪照片',
lambda c: f'{c}的纹身',
lambda c: f'{c}的刺绣照片',
lambda c: f'很难看到{c}的照片',
lambda c: f'{c}的明亮照片',
lambda c: f'一张干净的{c}的照片',
lambda c: f'一张包含{c}的照片',
lambda c: f'{c}的深色照片',
lambda c: f'{c}的手绘画',
lambda c: f'我的{c}的照片',
lambda c: f'不自然的{c}的照片',
lambda c: f'一张酷的{c}的照片',
lambda c: f'{c}的特写照片',
lambda c: f'{c}的黑白照片',
lambda c: f'一幅{c}的画',
lambda c: f'一幅{c}的绘画',
lambda c: f'一张{c}的像素照片',
lambda c: f'{c}的雕像',
lambda c: f'一张{c}的明亮照片',
lambda c: f'{c}的裁剪照片',
lambda c: f'人造的{c}的照片',
lambda c: f'一张关于{c}的照片',
lambda c: f'损坏的{c}的jpeg照片',
lambda c: f'{c}的模糊照片',
lambda c: f'{c}的相片',
lambda c: f'一张{c}的好照片',
lambda c: f'{c}的渲染照',
lambda c: f'视频游戏中的{c}',
lambda c: f'一张{c}的照片',
lambda c: f'{c}的涂鸦',
lambda c: f'{c}的近距离照片',
lambda c: f'{c}的折纸',
lambda c: f'{c}在视频游戏中',
lambda c: f'{c}的草图',
lambda c: f'{c}的涂鸦照',
lambda c: f'{c}的折纸形状',
lambda c: f'低分辨率的{c}的照片',
lambda c: f'玩具{c}',
lambda c: f'{c}的副本',
lambda c: f'{c}的干净的照片',
lambda c: f'一张大{c}的照片',
lambda c: f'{c}的重现',
lambda c: f'一张漂亮的{c}的照片',
lambda c: f'一张奇怪的{c}的照片',
lambda c: f'模糊的{c}的照片',
lambda c: f'卡通{c}',
lambda c: f'{c}的艺术作品',
lambda c: f'{c}的素描',
lambda c: f'刺绣{c}',
lambda c: f'{c}的像素照',
lambda c: f'{c}的拍照',
lambda c: f'{c}的损坏的照片',
lambda c: f'高质量的{c}的照片',
lambda c: f'毛绒玩具{c}',
lambda c: f'漂亮的{c}的照片',
lambda c: f'{c}的照片',
lambda c: f'照片是奇怪的{c}',
lambda c: f'漫画{c}',
lambda c: f'{c}的艺术照',
lambda c: f'{c}的图形',
lambda c: f'{c}的照片',
lambda c: f'黑白的{c}的照片',
lambda c: f'{c}毛绒玩具',
lambda c: f'一张{c}的深色照片',
lambda c: f'{c}的摄影图',
lambda c: f'{c}的涂鸦照',
lambda c: f'玩具形状的{c}',
lambda c: f'拍了{c}的照片',
lambda c: f'酷酷的{c}的照片',
lambda c: f'照片里的小{c}',
lambda c: f'{c}的刺青',
lambda c: f'{c}的可爱的照片',
lambda c: f'一张{c}可爱的照片',
lambda c: f'{c}可爱图片',
lambda c: f'{c}酷炫图片',
lambda c: f'一张{c}的酷炫的照片',
lambda c: f'一张{c}的酷炫图片',
lambda c: f'这是{c}',
lambda c: f'{c}的好看照片',
lambda c: f'一张{c}的好看的图片',
lambda c: f'{c}的好看图片',
lambda c: f'{c}的照片。',
lambda c: f'质量差的{c}的照片。',
lambda c: f'许多{c}的照片。',
lambda c: f'{c}的雕塑。',
lambda c: f'难以看到{c}的照片。',
lambda c: f'{c}的低分辨率照片。',
lambda c: f'{c}的渲染。',
lambda c: f'涂鸦{c}',
lambda c: f'{c}的糟糕照片。',
lambda c: f'{c}的裁剪照片。',
lambda c: f'{c}的纹身。',
lambda c: f'{c}的刺绣照片。',
lambda c: f'很难看到{c}的照片。',
lambda c: f'{c}的明亮照片。',
lambda c: f'一张干净的{c}的照片。',
lambda c: f'一张包含{c}的照片。',
lambda c: f'{c}的深色照片。',
lambda c: f'{c}的手绘画。',
lambda c: f'我的{c}的照片。',
lambda c: f'不自然的{c}的照片。',
lambda c: f'一张酷的{c}的照片。',
lambda c: f'{c}的特写照片。',
lambda c: f'{c}的黑白照片。',
lambda c: f'一幅{c}的画。',
lambda c: f'一幅{c}的绘画。',
lambda c: f'一张{c}的像素照片。',
lambda c: f'{c}的雕像。',
lambda c: f'一张{c}的明亮照片。',
lambda c: f'{c}的裁剪照片。',
lambda c: f'人造的{c}的照片。',
lambda c: f'一张关于{c}的照片。',
lambda c: f'损坏的{c}的jpeg照片。',
lambda c: f'{c}的模糊照片。',
lambda c: f'{c}的相片。',
lambda c: f'一张{c}的好照片。',
lambda c: f'{c}的渲染照。',
lambda c: f'视频游戏中的{c}',
lambda c: f'一张{c}的照片。',
lambda c: f'{c}的涂鸦。',
lambda c: f'{c}的近距离照片。',
lambda c: f'{c}的折纸。',
lambda c: f'{c}在视频游戏中。',
lambda c: f'{c}的草图。',
lambda c: f'{c}的涂鸦照。',
lambda c: f'{c}的折纸形状。',
lambda c: f'低分辨率的{c}的照片。',
lambda c: f'玩具{c}',
lambda c: f'{c}的副本。',
lambda c: f'{c}的干净的照片。',
lambda c: f'一张大{c}的照片。',
lambda c: f'{c}的重现。',
lambda c: f'一张漂亮的{c}的照片。',
lambda c: f'一张奇怪的{c}的照片。',
lambda c: f'模糊的{c}的照片。',
lambda c: f'卡通{c}',
lambda c: f'{c}的艺术作品。',
lambda c: f'{c}的素描。',
lambda c: f'刺绣{c}',
lambda c: f'{c}的像素照。',
lambda c: f'{c}的拍照。',
lambda c: f'{c}的损坏的照片。',
lambda c: f'高质量的{c}的照片。',
lambda c: f'毛绒玩具{c}',
lambda c: f'漂亮的{c}的照片。',
lambda c: f'{c}的照片。',
lambda c: f'照片是奇怪的{c}',
lambda c: f'漫画{c}',
lambda c: f'{c}的艺术照。',
lambda c: f'{c}的图形。',
lambda c: f'{c}的照片。',
lambda c: f'黑白的{c}的照片。',
lambda c: f'{c}毛绒玩具。',
lambda c: f'一张{c}的深色照片。',
lambda c: f'{c}的摄影图。',
lambda c: f'{c}的涂鸦照。',
lambda c: f'玩具形状的{c}',
lambda c: f'拍了{c}的照片。',
lambda c: f'酷酷的{c}的照片。',
lambda c: f'照片里的小{c}',
lambda c: f'{c}的刺青。',
lambda c: f'{c}的可爱的照片。',
lambda c: f'一张{c}可爱的照片。',
lambda c: f'{c}可爱图片。',
lambda c: f'{c}酷炫图片。',
lambda c: f'一张{c}的酷炫的照片。',
lambda c: f'一张{c}的酷炫图片。',
lambda c: f'这是{c}',
lambda c: f'{c}的好看照片。',
lambda c: f'一张{c}的好看的图片。',
lambda c: f'{c}的好看图片。',
lambda c: f'一种叫{c}的花的照片',
lambda c: f'一种叫{c}的食物的照片',
lambda c: f'{c}的卫星照片',
]

@@ -83,9 +83,10 @@ __all__ = [
if WITH_MULTIMODAL:
from .huggingface import (no_load_hf_pretrained_model, register_hf_model,
register_hf_tokenizer)
from .tokenizer import Blip2Tokenizer, BlipTokenizer, OFATokenizer
from .tokenizer import (Blip2Tokenizer, BlipTokenizer, FullTokenizer,
OFATokenizer)
__all__.extend([
'BlipTokenizer', 'OFATokenizer', 'Blip2Tokenizer', 'register_hf_model',
'register_hf_tokenizer', 'no_load_hf_pretrained_model'
'register_hf_tokenizer', 'no_load_hf_pretrained_model', 'FullTokenizer'
])

@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os
from transformers import (AutoTokenizer, BartTokenizer, BertTokenizer,
BertTokenizerFast, LlamaTokenizer)
from mmengine.fileio import list_from_file
from transformers import (AutoTokenizer, BartTokenizer, BasicTokenizer,
BertTokenizer, BertTokenizerFast, LlamaTokenizer,
WordpieceTokenizer)
from mmpretrain.registry import TOKENIZER
from .huggingface import register_hf_tokenizer
register_hf_tokenizer(AutoTokenizer)
@@ -110,3 +114,74 @@ class OFATokenizer(BartTokenizer):
tokenizer.bin_offset = length + 8192
tokenizer.num_bins = num_bins
return tokenizer
@TOKENIZER.register_module()
class FullTokenizer(BertTokenizer):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = self.load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(
vocab=self.vocab, unk_token='[UNK]', max_input_chars_per_word=200)
def load_vocab(self, vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
vocab_list = list_from_file(vocab_file)
for token in vocab_list:
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_by_vocab(self, vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(self, tokens):
return self.convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return self.convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
"""Converts a sequence of tokens (string) in a single string."""
def clean_up_tokenization(out_string):
"""Clean up a list of simple English tokenization artifacts like
spaces before punctuations and abbreviated forms."""
out_string = (
out_string.replace(' .', '.').replace(' ?', '?').replace(
' !', '!').replace(' ,', ',').replace(" ' ", "'").replace(
" n't", "n't").replace(" 'm", "'m").replace(
" 's", "'s").replace(" 've",
"'ve").replace(" 're", "'re"))
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
def vocab_size(self):
return len(self.vocab)
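
For illustration (not part of the file above), a minimal usage sketch of `FullTokenizer`; it assumes the multimodal dependencies are installed and that `list_from_file` accepts the vocab URL used in the configs. Chinese text comes out roughly character by character under this WordPiece vocab.

```python
from mmpretrain.models.utils import FullTokenizer

tokenizer = FullTokenizer(
    vocab_file='https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt')

tokens = tokenizer.tokenize('一张猫的照片')
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)  # roughly one token per character, e.g. ['一', '张', '猫', '的', '照', '片']
print(ids)
```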

@@ -1,19 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import mmengine.dist as dist
import rich.progress as progress
from rich.live import Live
disable_progress_bar = False
global_progress = progress.Progress(
'{task.description}',
progress.BarColumn(),
progress.TaskProgressColumn(show_speed=True),
progress.TimeRemainingColumn(),
)
global_live = Live(global_progress, refresh_per_second=10)
def track(sequence, *args, **kwargs):
def track(sequence, description: str = '', total: Optional[float] = None):
if disable_progress_bar:
return sequence
yield from sequence
else:
return progress.track(sequence, *args, **kwargs)
global_live.start()
task_id = global_progress.add_task(description, total=total)
task = global_progress._tasks[task_id]
try:
yield from global_progress.track(sequence, task_id=task_id)
finally:
if task.total is None:
global_progress.update(task_id, total=task.completed)
if all(task.finished for task in global_progress.tasks):
global_live.stop()
for task_id in global_progress.task_ids:
global_progress.remove_task(task_id)
def track_on_main_process(sequence, *args, **kwargs):
def track_on_main_process(sequence, description='', total=None):
if not dist.is_main_process() or disable_progress_bar:
yield from sequence
else:
yield from progress.track(sequence, *args, **kwargs)
yield from track(sequence, total=total, description=description)
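
For illustration (not part of the file above), a minimal usage sketch of the reworked `track`; the import path assumes it is re-exported from `mmpretrain.utils` (the function is defined in `mmpretrain/utils/progress.py`). The new `total` argument lets callers that iterate over a generator, like the inferencer change above, still render a determinate bar.

```python
from mmpretrain.utils import track  # assumed re-export of the function above


def fake_batches(n):
    for i in range(n):
        yield i  # stand-in for a preprocessed data batch


# The generator has no len(), so pass `total` explicitly.
for batch in track(fake_batches(8), 'Inference', total=8):
    pass
```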

@@ -75,3 +75,4 @@ Import:
- configs/blip/metafile.yml
- configs/flamingo/metafile.yml
- configs/blip2/metafile.yml
- configs/chinese_clip/metafile.yml