[Feature] Support Chinese CLIP. (#1576)
* support cn-clip * update README * Update progress bar * update order of category * fix lint * update * update readme and metafile * update * update docstring * refactor tokenizer * fix lint * Update README and progress bar --------- Co-authored-by: mzr1996 <mzr1996@163.com>pull/1586/head
parent
d04ef8a29e
commit
1e478462b8
mmpretrain
apis
datasets
models
multimodal
utils
|
@ -0,0 +1,69 @@
|
|||
# ChineseCLIP
|
||||
|
||||
> [Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese](https://arxiv.org/abs/2211.01335)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
The tremendous success of CLIP (Radford et al., 2021) has promoted the research and application of contrastive learning for vision-language pretraining. In this work, we construct a large-scale dataset of image-text pairs in Chinese, where most data are retrieved from publicly available datasets, and we pretrain Chinese CLIP models on the new dataset. We develop 5 Chinese CLIP models of multiple sizes, spanning from 77 to 958 million parameters. Furthermore, we propose a two-stage pretraining method, where the model is first trained with the image encoder frozen and then trained with all parameters being optimized, to achieve enhanced model performance. Our comprehensive experiments demonstrate that Chinese CLIP can achieve the state-of-the-art performance on MUGE, Flickr30K-CN, and COCO-CN in the setups of zero-shot learning and finetuning, and it is able to achieve competitive performance in zero-shot image classification based on the evaluation on the ELEVATER benchmark (Li et al., 2022). We have released our codes, models, and demos in https://github.com/OFA-Sys/Chinese-CLIP
|
||||
|
||||
<div align=center>
|
||||
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/4d05e51f-d834-4ef5-bbf0-0e2f80fea461" width="80%"/>
|
||||
</div>
|
||||
|
||||
## How to use it?
|
||||
|
||||
<!-- [TABS-BEGIN] -->
|
||||
|
||||
**Use the model for zero-shot classification**
|
||||
|
||||
```python
|
||||
from mmpretrain import ImageClassificationInferencer
|
||||
|
||||
inferencer = ImageClassificationInferencer(
|
||||
'cn-clip_resnet50_zeroshot-cls_cifar100',
|
||||
pretrained=True,
|
||||
classes=['鸟', '狗', '猫', '蛇'],
|
||||
text_prototype=['鸟', '狗', '猫', '蛇'],
|
||||
)
|
||||
|
||||
prediction = inferencer('./demo/bird.JPEG')[0]
|
||||
print('Results:', prediction['pred_class'])
|
||||
```
|
||||
|
||||
**Train/Test Command**
|
||||
|
||||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
|
||||
|
||||
Test:
|
||||
|
||||
```shell
|
||||
python tools/test.py configs/chinese_clip/cn-clip_resnet50_zeroshot-cls_cifar100.py https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_resnet50_3rdparty_20230519-6a2b3eb2.pth
|
||||
```
|
||||
|
||||
<!-- [TABS-END] -->
|
||||
|
||||
## Models and results
|
||||
|
||||
### Image Classification on CIFAR100
|
||||
|
||||
| Model | Params (M) | Top-1 (%) | Config | Download |
|
||||
| :---------------------------------------------- | :--------: | :-------: | :------------------------------------------------------: | :----------------------------------------------------------------------------: |
|
||||
| `cn-clip_resnet50_zeroshot-cls_cifar100`\* | 77.00 | 40.70 | [config](cn-clip_resnet50_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_resnet50_3rdparty_20230519-6a2b3eb2.pth) |
|
||||
| `cn-clip_vit-base-p16_zeroshot-cls_cifar100`\* | 188.00 | 64.50 | [config](cn-clip_vit-base-p16_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-base-p16_3rdparty_20230519-37fbc59e.pth) |
|
||||
| `cn-clip_vit-large-p14_zeroshot-cls_cifar100`\* | 406.00 | 74.80 | [config](cn-clip_vit-large-p14_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-large-p14_3rdparty_20230519-3f844503.pth) |
|
||||
| `cn-clip_vit-huge-p14_zeroshot-cls_cifar100`\* | 958.00 | 79.10 | [config](cn-clip_vit-huge-p14_zeroshot-cls_cifar100.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-huge-p14_3rdparty_20230519-e4f49b00.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/OFA-Sys/Chinese-CLIP). The config files of these models are only for inference. We haven't reprodcue the training results.*
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chinese-clip,
|
||||
title={Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese},
|
||||
author={Yang, An and Pan, Junshu and Lin, Junyang and Men, Rui and Zhang, Yichang and Zhou, Jingren and Zhou, Chang},
|
||||
journal={arXiv preprint arXiv:2211.01335},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,72 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ChineseCLIP',
|
||||
vision_backbone=dict(
|
||||
type='ModifiedResNet',
|
||||
depth=50,
|
||||
base_channels=64,
|
||||
input_size=224,
|
||||
num_attn_heads=32,
|
||||
output_dim=1024,
|
||||
),
|
||||
text_backbone=dict(
|
||||
type='BertModelCN',
|
||||
config=dict(
|
||||
vocab_size=21128,
|
||||
pad_token_id=0,
|
||||
add_type_embeddings=True,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
hidden_size=768,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=3072,
|
||||
max_position_embeddings=512,
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=3,
|
||||
type_vocab_size=2,
|
||||
layer_norm_eps=1e-12)),
|
||||
tokenizer=dict(
|
||||
type='FullTokenizer',
|
||||
vocab_file= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
|
||||
),
|
||||
proj_dim=1024,
|
||||
text_prototype='cifar100',
|
||||
)
|
|
@ -0,0 +1,76 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ChineseCLIP',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
norm_cfg=dict(type='LN', eps=1e-5),
|
||||
final_norm=True,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
text_backbone=dict(
|
||||
type='BertModelCN',
|
||||
config=dict(
|
||||
vocab_size=21128,
|
||||
pad_token_id=0,
|
||||
add_type_embeddings=True,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
hidden_size=768,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=3072,
|
||||
max_position_embeddings=512,
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=12,
|
||||
type_vocab_size=2,
|
||||
layer_norm_eps=1e-12)),
|
||||
tokenizer=dict(
|
||||
type='FullTokenizer',
|
||||
vocab_file= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
|
||||
),
|
||||
proj_dim=512,
|
||||
text_prototype='cifar100',
|
||||
)
|
|
@ -0,0 +1,75 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ChineseCLIP',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='huge',
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
norm_cfg=dict(type='LN', eps=1e-5),
|
||||
final_norm=True,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
text_backbone=dict(
|
||||
type='BertModelCN',
|
||||
config=dict(
|
||||
vocab_size=21128,
|
||||
pad_token_id=0,
|
||||
add_type_embeddings=True,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=512,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=24,
|
||||
type_vocab_size=2,
|
||||
layer_norm_eps=1e-12)),
|
||||
tokenizer=dict(
|
||||
type='FullTokenizer',
|
||||
vocab_file= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
|
||||
),
|
||||
proj_dim=1024,
|
||||
text_prototype='cifar100',
|
||||
)
|
|
@ -0,0 +1,75 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, ))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ChineseCLIP',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='large',
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
norm_cfg=dict(type='LN', eps=1e-5),
|
||||
final_norm=True,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
out_type='cls_token',
|
||||
),
|
||||
text_backbone=dict(
|
||||
type='BertModelCN',
|
||||
config=dict(
|
||||
vocab_size=21128,
|
||||
pad_token_id=0,
|
||||
add_type_embeddings=True,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
hidden_size=768,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=3072,
|
||||
max_position_embeddings=512,
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=12,
|
||||
type_vocab_size=2,
|
||||
layer_norm_eps=1e-12)),
|
||||
tokenizer=dict(
|
||||
type='FullTokenizer',
|
||||
vocab_file= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/vocab.txt'
|
||||
),
|
||||
proj_dim=768,
|
||||
text_prototype='cifar100',
|
||||
)
|
|
@ -0,0 +1,79 @@
|
|||
Collections:
|
||||
- Name: ChineseCLIP
|
||||
Metadata:
|
||||
Training Data:
|
||||
- LAION-5B
|
||||
- WuKong
|
||||
- VisualGenome
|
||||
- MSCOCO
|
||||
Architecture:
|
||||
- Transformer
|
||||
Paper:
|
||||
Title: 'Chinese CLIP: Contrastive Vision-Language Pretraining in Chinese'
|
||||
URL: https://arxiv.org/abs/2211.01335
|
||||
README: configs/chinese_clip/README.md
|
||||
|
||||
Models:
|
||||
- Name: cn-clip_resnet50_zeroshot-cls_cifar100
|
||||
Metadata:
|
||||
FLOPs: null
|
||||
Parameters: 77000000
|
||||
In Collection: ChineseCLIP
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: CIFAR100
|
||||
Metrics:
|
||||
Top 1 Accuracy: 40.7
|
||||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_resnet50_3rdparty_20230519-6a2b3eb2.pth
|
||||
Config: configs/chinese_clip/cn-clip_resnet50_zeroshot-cls_cifar100.py
|
||||
Converted From:
|
||||
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_rn50.pt
|
||||
Code: https://github.com/OFA-Sys/Chinese-CLIP
|
||||
|
||||
- Name: cn-clip_vit-base-p16_zeroshot-cls_cifar100
|
||||
Metadata:
|
||||
FLOPs: null
|
||||
Parameters: 188000000
|
||||
In Collection: ChineseCLIP
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: CIFAR100
|
||||
Metrics:
|
||||
Top 1 Accuracy: 64.5
|
||||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-base-p16_3rdparty_20230519-37fbc59e.pth
|
||||
Config: configs/chinese_clip/cn-clip_vit-base-p16_zeroshot-cls_cifar100.py
|
||||
Converted From:
|
||||
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt
|
||||
Code: https://github.com/OFA-Sys/Chinese-CLIP
|
||||
|
||||
- Name: cn-clip_vit-large-p14_zeroshot-cls_cifar100
|
||||
Metadata:
|
||||
FLOPs: null
|
||||
Parameters: 406000000
|
||||
In Collection: ChineseCLIP
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: CIFAR100
|
||||
Metrics:
|
||||
Top 1 Accuracy: 74.8
|
||||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-large-p14_3rdparty_20230519-3f844503.pth
|
||||
Config: configs/chinese_clip/cn-clip_vit-large-p14_zeroshot-cls_cifar100.py
|
||||
Converted From:
|
||||
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14.pt
|
||||
Code: https://github.com/OFA-Sys/Chinese-CLIP
|
||||
|
||||
- Name: cn-clip_vit-huge-p14_zeroshot-cls_cifar100
|
||||
Metadata:
|
||||
FLOPs: null
|
||||
Parameters: 958000000
|
||||
In Collection: ChineseCLIP
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: CIFAR100
|
||||
Metrics:
|
||||
Top 1 Accuracy: 79.1
|
||||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/chinese_clip/cn-clip_vit-huge-p14_3rdparty_20230519-e4f49b00.pth
|
||||
Config: configs/chinese_clip/cn-clip_vit-huge-p14_zeroshot-cls_cifar100.py
|
||||
Converted From:
|
||||
Weights: https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-h-14.pt
|
||||
Code: https://github.com/OFA-Sys/Chinese-CLIP
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from math import ceil
|
||||
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -154,7 +155,8 @@ class BaseInferencer:
|
|||
inputs = self.preprocess(
|
||||
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
|
||||
preds = []
|
||||
for data in track(inputs, 'Inference'):
|
||||
for data in track(
|
||||
inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)):
|
||||
preds.extend(self.forward(data, **forward_kwargs))
|
||||
visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
|
||||
results = self.postprocess(preds, visualization, return_datasamples,
|
||||
|
|
|
@ -1427,3 +1427,14 @@ FOOD101_CATEGORIES = (
|
|||
'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
|
||||
'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos',
|
||||
'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles')
|
||||
|
||||
CIFAR100_CATEGORIES_CN = (
|
||||
'苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩',
|
||||
'桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云',
|
||||
'蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩',
|
||||
'仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树',
|
||||
'摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树',
|
||||
'田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海',
|
||||
'海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
|
||||
'桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
|
||||
'柳树', '狼', '女人', '蠕虫')
|
||||
|
|
|
@ -4,6 +4,7 @@ from mmpretrain.utils.dependency import WITH_MULTIMODAL
|
|||
if WITH_MULTIMODAL:
|
||||
from .blip import * # noqa: F401,F403
|
||||
from .blip2 import * # noqa: F401,F403
|
||||
from .chinese_clip import * # noqa: F401, F403
|
||||
from .flamingo import * # noqa: F401, F403
|
||||
from .ofa import * # noqa: F401, F403
|
||||
else:
|
||||
|
@ -13,5 +14,5 @@ else:
|
|||
register_multimodal_placeholder([
|
||||
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
|
||||
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
|
||||
'OFA'
|
||||
'OFA', 'ChineseCLIP'
|
||||
], MODELS)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .bert import BertModelCN
|
||||
from .chinese_clip import ChineseCLIP, ModifiedResNet
|
||||
|
||||
__all__ = ['ChineseCLIP', 'ModifiedResNet', 'BertModelCN']
|
|
@ -0,0 +1,263 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# flake8: noqa
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
try:
|
||||
from transformers.models.bert.configuration_bert import BertConfig
|
||||
except:
|
||||
BertConfig = None
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..blip.language_model import BertAttention, BertIntermediate, BertOutput
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Original Implementation of the gelu activation function in Google Bert
|
||||
repo when initially created.
|
||||
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives
|
||||
slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
""" # noqa
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def gelu_new(x):
|
||||
"""Implementation of the gelu activation function currently in Google Bert
|
||||
repo (identical to OpenAI GPT) https://arxiv.org/abs/1606.08415."""
|
||||
return 0.5 * x * (1 + torch.tanh(
|
||||
math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {
|
||||
'gelu': gelu,
|
||||
'relu': torch.nn.functional.relu,
|
||||
'swish': swish,
|
||||
'gelu_new': gelu_new
|
||||
}
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type
|
||||
embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertEmbeddings, self).__init__()
|
||||
self.word_embeddings = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, padding_idx=0)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
|
||||
config.hidden_size)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
|
||||
config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model
|
||||
# variable name and be able to load any TensorFlow checkpoint file
|
||||
self.LayerNorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings \
|
||||
+ token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertLayer, self).__init__()
|
||||
self.attention = BertAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask,
|
||||
head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output, ) + attention_outputs[
|
||||
1:] # add attentions if we output them
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertEncoder, self).__init__()
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.grad_checkpointing = False
|
||||
self.layer = nn.ModuleList(
|
||||
[BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
layer_outputs = checkpoint(layer_module, hidden_states,
|
||||
attention_mask, head_mask[i])
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, attention_mask,
|
||||
head_mask[i])
|
||||
if not isinstance(layer_outputs, tuple):
|
||||
layer_outputs = (layer_outputs, )
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1], )
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
outputs = (hidden_states, )
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states, )
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions, )
|
||||
# last-layer hidden state, (all hidden states), (all attentions)
|
||||
return outputs
|
||||
|
||||
|
||||
class BertPreTrainedModel(nn.Module):
|
||||
base_model_prefix = 'bert'
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertPreTrainedModel, self).__init__()
|
||||
self.config = config
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version
|
||||
# which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(
|
||||
mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BertModelCN(BertPreTrainedModel):
|
||||
"""The BERT model implementation for Chinese CLIP."""
|
||||
|
||||
def __init__(self, config):
|
||||
config = BertConfig.from_dict(config)
|
||||
super(BertModelCN, self).__init__(config)
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = BertEncoder(config)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
if enable:
|
||||
assert not self.config.output_attentions, \
|
||||
'Grad checkpointing is currently conflict with ' \
|
||||
'output_attentions for BertEncoder, ' \
|
||||
'please set it to False in BertConfig'
|
||||
|
||||
self.encoder.grad_checkpointing = enable
|
||||
|
||||
def forward(self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(
|
||||
dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(
|
||||
-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1,
|
||||
-1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
|
||||
-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters(
|
||||
)).dtype) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output, extended_attention_mask, head_mask=head_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
# pooled_output = self.pooler(sequence_output)
|
||||
pooled_output = None
|
||||
|
||||
# add hidden_states and attentions if they are here
|
||||
outputs = (
|
||||
sequence_output,
|
||||
pooled_output,
|
||||
) + encoder_outputs[1:]
|
||||
|
||||
# sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
return outputs
|
|
@ -0,0 +1,446 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModel, BaseModule
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN
|
||||
from mmpretrain.registry import MODELS, TOKENIZER
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.utils import track_on_main_process
|
||||
from .utils import OPENAI_PROMPT
|
||||
|
||||
PROTOTYPE_MAP = {'cifar100': CIFAR100_CATEGORIES_CN}
|
||||
PROMPT_MAP = {'openai': OPENAI_PROMPT}
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = None
|
||||
self.stride = stride
|
||||
|
||||
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||
self.downsample = nn.Sequential(
|
||||
OrderedDict([('-1', nn.AvgPool2d(stride)),
|
||||
('0',
|
||||
nn.Conv2d(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
stride=1,
|
||||
bias=False)),
|
||||
('1', nn.BatchNorm2d(planes * self.expansion))]))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
identity = x
|
||||
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.relu(self.bn2(self.conv2(out)))
|
||||
out = self.avgpool(out)
|
||||
out = self.bn3(self.conv3(out))
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
output_dim: int = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
x.shape[2] * x.shape[3]).permute(2, 0,
|
||||
1) # NCHW -> (HW)NC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||
x, _ = F.multi_head_attention_forward(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat(
|
||||
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False)
|
||||
|
||||
return x[0]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ModifiedResNet(BaseModule):
|
||||
"""A modified ResNet contains the following changes:
|
||||
|
||||
- Apply deep stem with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is
|
||||
prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
""" # noqa
|
||||
|
||||
arch_settings = {
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
depth: int = 50,
|
||||
base_channels: int = 64,
|
||||
input_size: int = 224,
|
||||
num_attn_heads: int = 32,
|
||||
output_dim: int = 1024,
|
||||
init_cfg: Optional[dict] = None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.input_size = input_size
|
||||
self.block, stage_blocks = self.arch_settings[depth]
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(
|
||||
3,
|
||||
base_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(base_channels // 2)
|
||||
self.conv2 = nn.Conv2d(
|
||||
base_channels // 2,
|
||||
base_channels // 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(base_channels // 2)
|
||||
self.conv3 = nn.Conv2d(
|
||||
base_channels // 2,
|
||||
base_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(base_channels)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
# residual layers
|
||||
# this is a *mutable* variable used during construction
|
||||
self._inplanes = base_channels
|
||||
self.layer1 = self._make_layer(base_channels, stage_blocks[0])
|
||||
self.layer2 = self._make_layer(
|
||||
base_channels * 2, stage_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(
|
||||
base_channels * 4, stage_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(
|
||||
base_channels * 8, stage_blocks[3], stride=2)
|
||||
|
||||
embed_dim = base_channels * 32
|
||||
self.attnpool = AttentionPool2d(input_size // 32, embed_dim,
|
||||
num_attn_heads, output_dim)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
|
||||
self._inplanes = planes * Bottleneck.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Bottleneck(self._inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def stem(x):
|
||||
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
|
||||
(self.conv3, self.bn3)]:
|
||||
x = self.relu(bn(conv(x)))
|
||||
x = self.avgpool(x)
|
||||
return x
|
||||
|
||||
x = x.type(self.conv1.weight.dtype)
|
||||
x = stem(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.attnpool(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ChineseCLIP(BaseModel):
|
||||
"""The implementation of `ChineseCLIP <https://arxiv.org/abs/2211.01335>`_.
|
||||
|
||||
Args:
|
||||
vision_backbone (dict): Config dict for vision backbone.
|
||||
text_backbone (dict): Config dict for text backbone.
|
||||
tokenizer (dict): Config dict for text tokenizer.
|
||||
proj_dim (int): Projection dimension for similarity computation.
|
||||
text_prototype (str): Text prototype, which can be a key in
|
||||
`PROTOTYPE_MAP` or list of text.
|
||||
text_prompt (str): The prompt for text prototype. Defaults to 'openai'.
|
||||
context_length (int): The context length to use. Defaults to 52.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
preprocessing input data. If None or no specified type, it will use
|
||||
"MultiModalDataPreprocessor" as type.
|
||||
See :class:`MultiModalDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vision_backbone: dict,
|
||||
text_backbone: dict,
|
||||
tokenizer: dict,
|
||||
proj_dim: int,
|
||||
text_prototype: Union[str, List[str]],
|
||||
text_prompt: str = 'openai',
|
||||
context_length: int = 52,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
|
||||
self.vision_backbone = MODELS.build(vision_backbone)
|
||||
self.text_backbone = MODELS.build(text_backbone)
|
||||
|
||||
if not isinstance(self.vision_backbone, ModifiedResNet):
|
||||
self.vision_projection = nn.Parameter(
|
||||
torch.empty(self.vision_backbone.embed_dims, proj_dim))
|
||||
text_hidden_size = text_backbone['config']['hidden_size']
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(text_hidden_size, proj_dim))
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.tokenizer = TOKENIZER.build(tokenizer)
|
||||
self.context_length = context_length
|
||||
|
||||
# for zero-shot classification
|
||||
if isinstance(text_prototype,
|
||||
str) and text_prototype in PROTOTYPE_MAP.keys():
|
||||
self.prototype = PROTOTYPE_MAP[text_prototype]
|
||||
else:
|
||||
self.prototype = text_prototype
|
||||
self.text_prototype_embeds = None
|
||||
|
||||
self.prompt = PROMPT_MAP[text_prompt]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
data_samples: Optional[list] = None,
|
||||
mode: str = 'predict',
|
||||
**kwargs,
|
||||
):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
The method accepts the following modes:
|
||||
|
||||
- "predict": Forward and return a list of data samples contain the
|
||||
predict results.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): the preprocessed image tensor of shape
|
||||
``(N, C, H, W)``.
|
||||
data_samples (List[DataSample], optional): The annotation data
|
||||
of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'predict'.
|
||||
"""
|
||||
if mode == 'predict':
|
||||
return self.predict(images, data_samples, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""The function to extract image latent features."""
|
||||
if isinstance(self.vision_backbone, ModifiedResNet):
|
||||
return self.vision_backbone(images)
|
||||
return self.vision_backbone(images)[-1] @ self.vision_projection
|
||||
|
||||
def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
|
||||
"""The function to extract text latent features."""
|
||||
pad_index = self.tokenizer.vocab['[PAD]']
|
||||
attn_mask = texts.ne(pad_index)
|
||||
# [batch_size, seq_length, hidden_size]
|
||||
x = self.text_backbone(texts, attention_mask=attn_mask)[0]
|
||||
return x[:, 0, :] @ self.text_projection
|
||||
|
||||
def extract_feat(
|
||||
self, images: torch.Tensor,
|
||||
texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
"""The function to extract image and text latent features, the input
|
||||
image or text can not both be None."""
|
||||
|
||||
assert images is not None or texts is not None, \
|
||||
'text and image cannot both be None!'
|
||||
if images is None:
|
||||
return self.extract_text_feat(texts)
|
||||
elif texts is None:
|
||||
return self.extract_image_feat(images)
|
||||
|
||||
image_features = self.extract_image_feat(images)
|
||||
text_features = self.extract_text_feat(texts)
|
||||
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
return image_features, text_features
|
||||
|
||||
def compute_similarity(self, images, texts):
|
||||
"""Extract images and texts features and compute cosine similarity."""
|
||||
image_features, text_features = self.extract_feat(
|
||||
images=images, texts=texts)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logits_per_image.t()
|
||||
|
||||
# shape (N, N)
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
"""Predict the classes of the input images.
|
||||
|
||||
The prediction is for zero-shot classification and the text prototypes
|
||||
will be prepared in thisfunction.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): The input images.
|
||||
data_samples (DataSample): The data samples with information from
|
||||
dataset.
|
||||
|
||||
Returns:
|
||||
DataSample: The results of prediction.
|
||||
"""
|
||||
|
||||
if self.text_prototype_embeds is None:
|
||||
self.prepare_text_prototype(device=images.device)
|
||||
|
||||
image_features = self.extract_image_feat(images=images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logits_per_image = image_features @ self.text_prototype_embeds.to(
|
||||
image_features.device) * self.logit_scale.exp()
|
||||
|
||||
pred_scores = F.softmax(logits_per_image, dim=1)
|
||||
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
|
||||
|
||||
out_data_samples = []
|
||||
if data_samples is None:
|
||||
data_samples = [None for _ in range(pred_scores.size(0))]
|
||||
|
||||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
if data_sample is None:
|
||||
data_sample = DataSample()
|
||||
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
out_data_samples.append(data_sample)
|
||||
return out_data_samples
|
||||
|
||||
def prepare_text_prototype(self, device) -> None:
|
||||
"""The function to prepare text prototypes with prompt."""
|
||||
class_embeddings = []
|
||||
for classname in track_on_main_process(self.prototype,
|
||||
'Prepare text prototype...'):
|
||||
# format with class
|
||||
texts = [prompt(classname) for prompt in self.prompt]
|
||||
tokenized_texts = self.tokenize(texts)
|
||||
class_features = self.extract_text_feat(tokenized_texts.to(device))
|
||||
class_features /= class_features.norm(dim=-1, keepdim=True)
|
||||
class_feature = class_features.mean(dim=0)
|
||||
class_feature /= class_feature.norm()
|
||||
class_embeddings.append(class_feature)
|
||||
self.text_prototype_embeds = torch.stack(
|
||||
class_embeddings, dim=1).to(device)
|
||||
|
||||
def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
|
||||
"""Returns the tokenized representation of given input string(s)
|
||||
|
||||
Args:
|
||||
texts (Union[str, List[str]]): An input string or a list of input
|
||||
strings to tokenize
|
||||
context_length (int): The context length to use. Defaults to 52.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Resulting tokens.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
all_tokens = []
|
||||
for text in texts:
|
||||
# adapt the text to Chinese BERT vocab
|
||||
text = text.lower().replace('“', "\"").replace('”', "\"")
|
||||
|
||||
# add special tokens
|
||||
all_tokens.append(
|
||||
[self.tokenizer.vocab['[CLS]']] +
|
||||
self.tokenizer.convert_tokens_to_ids(
|
||||
self.tokenizer.tokenize(text))[:self.context_length - 2] +
|
||||
[self.tokenizer.vocab['[SEP]']])
|
||||
|
||||
result = torch.zeros(
|
||||
len(all_tokens), self.context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
assert len(tokens) <= self.context_length
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
|
@ -0,0 +1,186 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
OPENAI_PROMPT = [
|
||||
lambda c: f'{c}的照片',
|
||||
lambda c: f'质量差的{c}的照片',
|
||||
lambda c: f'许多{c}的照片',
|
||||
lambda c: f'{c}的雕塑',
|
||||
lambda c: f'难以看到{c}的照片',
|
||||
lambda c: f'{c}的低分辨率照片',
|
||||
lambda c: f'{c}的渲染',
|
||||
lambda c: f'涂鸦{c}',
|
||||
lambda c: f'{c}的糟糕照片',
|
||||
lambda c: f'{c}的裁剪照片',
|
||||
lambda c: f'{c}的纹身',
|
||||
lambda c: f'{c}的刺绣照片',
|
||||
lambda c: f'很难看到{c}的照片',
|
||||
lambda c: f'{c}的明亮照片',
|
||||
lambda c: f'一张干净的{c}的照片',
|
||||
lambda c: f'一张包含{c}的照片',
|
||||
lambda c: f'{c}的深色照片',
|
||||
lambda c: f'{c}的手绘画',
|
||||
lambda c: f'我的{c}的照片',
|
||||
lambda c: f'不自然的{c}的照片',
|
||||
lambda c: f'一张酷的{c}的照片',
|
||||
lambda c: f'{c}的特写照片',
|
||||
lambda c: f'{c}的黑白照片',
|
||||
lambda c: f'一幅{c}的画',
|
||||
lambda c: f'一幅{c}的绘画',
|
||||
lambda c: f'一张{c}的像素照片',
|
||||
lambda c: f'{c}的雕像',
|
||||
lambda c: f'一张{c}的明亮照片',
|
||||
lambda c: f'{c}的裁剪照片',
|
||||
lambda c: f'人造的{c}的照片',
|
||||
lambda c: f'一张关于{c}的照片',
|
||||
lambda c: f'损坏的{c}的jpeg照片',
|
||||
lambda c: f'{c}的模糊照片',
|
||||
lambda c: f'{c}的相片',
|
||||
lambda c: f'一张{c}的好照片',
|
||||
lambda c: f'{c}的渲染照',
|
||||
lambda c: f'视频游戏中的{c}',
|
||||
lambda c: f'一张{c}的照片',
|
||||
lambda c: f'{c}的涂鸦',
|
||||
lambda c: f'{c}的近距离照片',
|
||||
lambda c: f'{c}的折纸',
|
||||
lambda c: f'{c}在视频游戏中',
|
||||
lambda c: f'{c}的草图',
|
||||
lambda c: f'{c}的涂鸦照',
|
||||
lambda c: f'{c}的折纸形状',
|
||||
lambda c: f'低分辨率的{c}的照片',
|
||||
lambda c: f'玩具{c}',
|
||||
lambda c: f'{c}的副本',
|
||||
lambda c: f'{c}的干净的照片',
|
||||
lambda c: f'一张大{c}的照片',
|
||||
lambda c: f'{c}的重现',
|
||||
lambda c: f'一张漂亮的{c}的照片',
|
||||
lambda c: f'一张奇怪的{c}的照片',
|
||||
lambda c: f'模糊的{c}的照片',
|
||||
lambda c: f'卡通{c}',
|
||||
lambda c: f'{c}的艺术作品',
|
||||
lambda c: f'{c}的素描',
|
||||
lambda c: f'刺绣{c}',
|
||||
lambda c: f'{c}的像素照',
|
||||
lambda c: f'{c}的拍照',
|
||||
lambda c: f'{c}的损坏的照片',
|
||||
lambda c: f'高质量的{c}的照片',
|
||||
lambda c: f'毛绒玩具{c}',
|
||||
lambda c: f'漂亮的{c}的照片',
|
||||
lambda c: f'小{c}的照片',
|
||||
lambda c: f'照片是奇怪的{c}',
|
||||
lambda c: f'漫画{c}',
|
||||
lambda c: f'{c}的艺术照',
|
||||
lambda c: f'{c}的图形',
|
||||
lambda c: f'大{c}的照片',
|
||||
lambda c: f'黑白的{c}的照片',
|
||||
lambda c: f'{c}毛绒玩具',
|
||||
lambda c: f'一张{c}的深色照片',
|
||||
lambda c: f'{c}的摄影图',
|
||||
lambda c: f'{c}的涂鸦照',
|
||||
lambda c: f'玩具形状的{c}',
|
||||
lambda c: f'拍了{c}的照片',
|
||||
lambda c: f'酷酷的{c}的照片',
|
||||
lambda c: f'照片里的小{c}',
|
||||
lambda c: f'{c}的刺青',
|
||||
lambda c: f'{c}的可爱的照片',
|
||||
lambda c: f'一张{c}可爱的照片',
|
||||
lambda c: f'{c}可爱图片',
|
||||
lambda c: f'{c}酷炫图片',
|
||||
lambda c: f'一张{c}的酷炫的照片',
|
||||
lambda c: f'一张{c}的酷炫图片',
|
||||
lambda c: f'这是{c}',
|
||||
lambda c: f'{c}的好看照片',
|
||||
lambda c: f'一张{c}的好看的图片',
|
||||
lambda c: f'{c}的好看图片',
|
||||
lambda c: f'{c}的照片。',
|
||||
lambda c: f'质量差的{c}的照片。',
|
||||
lambda c: f'许多{c}的照片。',
|
||||
lambda c: f'{c}的雕塑。',
|
||||
lambda c: f'难以看到{c}的照片。',
|
||||
lambda c: f'{c}的低分辨率照片。',
|
||||
lambda c: f'{c}的渲染。',
|
||||
lambda c: f'涂鸦{c}。',
|
||||
lambda c: f'{c}的糟糕照片。',
|
||||
lambda c: f'{c}的裁剪照片。',
|
||||
lambda c: f'{c}的纹身。',
|
||||
lambda c: f'{c}的刺绣照片。',
|
||||
lambda c: f'很难看到{c}的照片。',
|
||||
lambda c: f'{c}的明亮照片。',
|
||||
lambda c: f'一张干净的{c}的照片。',
|
||||
lambda c: f'一张包含{c}的照片。',
|
||||
lambda c: f'{c}的深色照片。',
|
||||
lambda c: f'{c}的手绘画。',
|
||||
lambda c: f'我的{c}的照片。',
|
||||
lambda c: f'不自然的{c}的照片。',
|
||||
lambda c: f'一张酷的{c}的照片。',
|
||||
lambda c: f'{c}的特写照片。',
|
||||
lambda c: f'{c}的黑白照片。',
|
||||
lambda c: f'一幅{c}的画。',
|
||||
lambda c: f'一幅{c}的绘画。',
|
||||
lambda c: f'一张{c}的像素照片。',
|
||||
lambda c: f'{c}的雕像。',
|
||||
lambda c: f'一张{c}的明亮照片。',
|
||||
lambda c: f'{c}的裁剪照片。',
|
||||
lambda c: f'人造的{c}的照片。',
|
||||
lambda c: f'一张关于{c}的照片。',
|
||||
lambda c: f'损坏的{c}的jpeg照片。',
|
||||
lambda c: f'{c}的模糊照片。',
|
||||
lambda c: f'{c}的相片。',
|
||||
lambda c: f'一张{c}的好照片。',
|
||||
lambda c: f'{c}的渲染照。',
|
||||
lambda c: f'视频游戏中的{c}。',
|
||||
lambda c: f'一张{c}的照片。',
|
||||
lambda c: f'{c}的涂鸦。',
|
||||
lambda c: f'{c}的近距离照片。',
|
||||
lambda c: f'{c}的折纸。',
|
||||
lambda c: f'{c}在视频游戏中。',
|
||||
lambda c: f'{c}的草图。',
|
||||
lambda c: f'{c}的涂鸦照。',
|
||||
lambda c: f'{c}的折纸形状。',
|
||||
lambda c: f'低分辨率的{c}的照片。',
|
||||
lambda c: f'玩具{c}。',
|
||||
lambda c: f'{c}的副本。',
|
||||
lambda c: f'{c}的干净的照片。',
|
||||
lambda c: f'一张大{c}的照片。',
|
||||
lambda c: f'{c}的重现。',
|
||||
lambda c: f'一张漂亮的{c}的照片。',
|
||||
lambda c: f'一张奇怪的{c}的照片。',
|
||||
lambda c: f'模糊的{c}的照片。',
|
||||
lambda c: f'卡通{c}。',
|
||||
lambda c: f'{c}的艺术作品。',
|
||||
lambda c: f'{c}的素描。',
|
||||
lambda c: f'刺绣{c}。',
|
||||
lambda c: f'{c}的像素照。',
|
||||
lambda c: f'{c}的拍照。',
|
||||
lambda c: f'{c}的损坏的照片。',
|
||||
lambda c: f'高质量的{c}的照片。',
|
||||
lambda c: f'毛绒玩具{c}。',
|
||||
lambda c: f'漂亮的{c}的照片。',
|
||||
lambda c: f'小{c}的照片。',
|
||||
lambda c: f'照片是奇怪的{c}。',
|
||||
lambda c: f'漫画{c}。',
|
||||
lambda c: f'{c}的艺术照。',
|
||||
lambda c: f'{c}的图形。',
|
||||
lambda c: f'大{c}的照片。',
|
||||
lambda c: f'黑白的{c}的照片。',
|
||||
lambda c: f'{c}毛绒玩具。',
|
||||
lambda c: f'一张{c}的深色照片。',
|
||||
lambda c: f'{c}的摄影图。',
|
||||
lambda c: f'{c}的涂鸦照。',
|
||||
lambda c: f'玩具形状的{c}。',
|
||||
lambda c: f'拍了{c}的照片。',
|
||||
lambda c: f'酷酷的{c}的照片。',
|
||||
lambda c: f'照片里的小{c}。',
|
||||
lambda c: f'{c}的刺青。',
|
||||
lambda c: f'{c}的可爱的照片。',
|
||||
lambda c: f'一张{c}可爱的照片。',
|
||||
lambda c: f'{c}可爱图片。',
|
||||
lambda c: f'{c}酷炫图片。',
|
||||
lambda c: f'一张{c}的酷炫的照片。',
|
||||
lambda c: f'一张{c}的酷炫图片。',
|
||||
lambda c: f'这是{c}。',
|
||||
lambda c: f'{c}的好看照片。',
|
||||
lambda c: f'一张{c}的好看的图片。',
|
||||
lambda c: f'{c}的好看图片。',
|
||||
lambda c: f'一种叫{c}的花的照片',
|
||||
lambda c: f'一种叫{c}的食物的照片',
|
||||
lambda c: f'{c}的卫星照片',
|
||||
]
|
|
@ -83,9 +83,10 @@ __all__ = [
|
|||
if WITH_MULTIMODAL:
|
||||
from .huggingface import (no_load_hf_pretrained_model, register_hf_model,
|
||||
register_hf_tokenizer)
|
||||
from .tokenizer import Blip2Tokenizer, BlipTokenizer, OFATokenizer
|
||||
from .tokenizer import (Blip2Tokenizer, BlipTokenizer, FullTokenizer,
|
||||
OFATokenizer)
|
||||
|
||||
__all__.extend([
|
||||
'BlipTokenizer', 'OFATokenizer', 'Blip2Tokenizer', 'register_hf_model',
|
||||
'register_hf_tokenizer', 'no_load_hf_pretrained_model'
|
||||
'register_hf_tokenizer', 'no_load_hf_pretrained_model', 'FullTokenizer'
|
||||
])
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import collections
|
||||
import os
|
||||
|
||||
from transformers import (AutoTokenizer, BartTokenizer, BertTokenizer,
|
||||
BertTokenizerFast, LlamaTokenizer)
|
||||
from mmengine.fileio import list_from_file
|
||||
from transformers import (AutoTokenizer, BartTokenizer, BasicTokenizer,
|
||||
BertTokenizer, BertTokenizerFast, LlamaTokenizer,
|
||||
WordpieceTokenizer)
|
||||
|
||||
from mmpretrain.registry import TOKENIZER
|
||||
from .huggingface import register_hf_tokenizer
|
||||
|
||||
register_hf_tokenizer(AutoTokenizer)
|
||||
|
@ -110,3 +114,74 @@ class OFATokenizer(BartTokenizer):
|
|||
tokenizer.bin_offset = length + 8192
|
||||
tokenizer.num_bins = num_bins
|
||||
return tokenizer
|
||||
|
||||
|
||||
@TOKENIZER.register_module()
|
||||
class FullTokenizer(BertTokenizer):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = self.load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(
|
||||
vocab=self.vocab, unk_token='[UNK]', max_input_chars_per_word=200)
|
||||
|
||||
def load_vocab(self, vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
vocab_list = list_from_file(vocab_file)
|
||||
for token in vocab_list:
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_by_vocab(self, vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return self.convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return self.convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
@staticmethod
|
||||
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
|
||||
def clean_up_tokenization(out_string):
|
||||
"""Clean up a list of simple English tokenization artifacts like
|
||||
spaces before punctuations and abbreviated forms."""
|
||||
out_string = (
|
||||
out_string.replace(' .', '.').replace(' ?', '?').replace(
|
||||
' !', '!').replace(' ,', ',').replace(" ' ", "'").replace(
|
||||
" n't", "n't").replace(" 'm", "'m").replace(
|
||||
" 's", "'s").replace(" 've",
|
||||
"'ve").replace(" 're", "'re"))
|
||||
return out_string
|
||||
|
||||
text = ' '.join(tokens).replace(' ##', '').strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = clean_up_tokenization(text)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
|
|
@ -1,19 +1,40 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import mmengine.dist as dist
|
||||
import rich.progress as progress
|
||||
from rich.live import Live
|
||||
|
||||
disable_progress_bar = False
|
||||
global_progress = progress.Progress(
|
||||
'{task.description}',
|
||||
progress.BarColumn(),
|
||||
progress.TaskProgressColumn(show_speed=True),
|
||||
progress.TimeRemainingColumn(),
|
||||
)
|
||||
global_live = Live(global_progress, refresh_per_second=10)
|
||||
|
||||
|
||||
def track(sequence, *args, **kwargs):
|
||||
def track(sequence, description: str = '', total: Optional[float] = None):
|
||||
if disable_progress_bar:
|
||||
return sequence
|
||||
yield from sequence
|
||||
else:
|
||||
return progress.track(sequence, *args, **kwargs)
|
||||
global_live.start()
|
||||
task_id = global_progress.add_task(description, total=total)
|
||||
task = global_progress._tasks[task_id]
|
||||
try:
|
||||
yield from global_progress.track(sequence, task_id=task_id)
|
||||
finally:
|
||||
if task.total is None:
|
||||
global_progress.update(task_id, total=task.completed)
|
||||
if all(task.finished for task in global_progress.tasks):
|
||||
global_live.stop()
|
||||
for task_id in global_progress.task_ids:
|
||||
global_progress.remove_task(task_id)
|
||||
|
||||
|
||||
def track_on_main_process(sequence, *args, **kwargs):
|
||||
def track_on_main_process(sequence, description='', total=None):
|
||||
if not dist.is_main_process() or disable_progress_bar:
|
||||
yield from sequence
|
||||
else:
|
||||
yield from progress.track(sequence, *args, **kwargs)
|
||||
yield from track(sequence, total=total, description=description)
|
||||
|
|
|
@ -75,3 +75,4 @@ Import:
|
|||
- configs/blip/metafile.yml
|
||||
- configs/flamingo/metafile.yml
|
||||
- configs/blip2/metafile.yml
|
||||
- configs/chinese_clip/metafile.yml
|
||||
|
|
Loading…
Reference in New Issue