Merge branch 'open-mmlab:main' into cky/starnet_backbone

commit e898b1fd9e

README.md
@@ -86,13 +86,15 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351

## What's new

🌟 v1.0.2 was released in 15/08/2023
🌟 v1.2.0 was released in 04/01/2024

Support [MFF](./configs/mff/) self-supervised algorithm and enhance the codebase. More details can be found in the [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html).
- Support LLaVA 1.5.
- Implement RAM with a Gradio interface.

🌟 v1.0.1 was released in 28/07/2023
🌟 v1.1.0 was released in 12/10/2023

Fix some bugs and enhance the codebase. Please refer to [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) for more details.
- Support Mini-GPT4 training and provide a Chinese model (based on Baichuan-7B).
- Support zero-shot classification based on CLIP.

🌟 v1.0.0 was released in 04/07/2023
@@ -84,13 +84,15 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351

## 更新日志

🌟 2023/8/15 发布了 v1.0.2 版本
🌟 2024/01/04 发布了 v1.2.0 版本

支持了 [MFF](./configs/mff/) 自监督算法,增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。
- 支持了 LLaVA 1.5
- 实现了一个 RAM 模型的 gradio 推理例程

🌟 2023/7/28 发布了 v1.0.1 版本
🌟 2023/10/12 发布了 v1.1.0 版本

修复部分 bug 和增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。
- 支持 Mini-GPT4 训练并提供一个基于 Baichuan-7B 的中文模型
- 支持基于 CLIP 的零样本分类。

🌟 2023/7/4 发布了 v1.0.0 版本
@@ -333,10 +335,10 @@ MMPreTrain 是一款由不同学校和公司共同贡献的开源项目。我们

## 欢迎加入 OpenMMLab 社区

扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 OpenMMLab 团队的 [官方交流 QQ 群](https://jq.qq.com/?_wv=1027&k=aCvMxdr3) 或联络 OpenMMLab 官方微信小助手
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),扫描下方微信二维码添加喵喵好友,进入 MMPretrain 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】

<div align="center">
<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/xiaozhushou_weixin_qrcode.jpeg" height="400"/>
<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/miaomiao_qrcode.jpg" height="400"/>
</div>

我们会在 OpenMMLab 社区为大家
@@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct

<!-- [TABS-BEGIN] -->

**Prepare the checkpoint**

According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the
script below to download and get the merged checkpoint.

```shell
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
```

**Use the model**

```python
import torch
from mmpretrain import get_model, inference_model

model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
print(out)
# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
```

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
```

<!-- [TABS-END] -->

## Models and results

### Image Caption on COCO

| Model | Params (M) | BLEU-4 | CIDER | Config | Download |
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
| Model | Params (M) | Config | Download |
| :---------------------- | :--------: | :--------------------------------: | :---------------------------------------------------------------------------------------------------------------: |
| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) |
| `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |

## Citation
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."  # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
    type='Llava',
    tokenizer=dict(
        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        img_size=image_size,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
        'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
    ),
    mm_hidden_size=1024,
    use_im_patch=False,
    use_im_start_end=False,
    mm_proj_depth=2,
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
    ),
    task='caption',
    prompt_tmpl=prompt_tmpl,
    generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(image_size, image_size),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

test_dataloader = dict(
    batch_size=8,
    num_workers=5,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
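The config above can be exercised end to end through MMEngine's `Runner`; the following is a minimal sketch, assuming the config path recorded in the metafile and a locally downloaded checkpoint (both paths should be adjusted to your setup).

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Assumed paths: the config added above and a checkpoint downloaded from
# the table in configs/llava/README.md.
cfg = Config.fromfile('configs/llava/llava-7b-v1.5_caption.py')
cfg.load_from = 'llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth'
cfg.work_dir = 'work_dirs/llava-7b-v1.5_caption'

Runner.from_cfg(cfg).test()  # same effect as `python tools/test.py <config> <ckpt>`
```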
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."  # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

# model settings
model = dict(
    type='Llava',
    tokenizer=dict(
        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        img_size=image_size,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
        'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
    ),
    mm_hidden_size=1024,
    use_im_patch=False,
    use_im_start_end=False,
    mm_proj_depth=2,
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
    ),
    task='vqa',
    prompt_tmpl=prompt_tmpl,
    generation_cfg=dict(max_new_tokens=100),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(image_size, image_size),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id', 'question']),
]

test_dataloader = dict(
    batch_size=8,
    num_workers=5,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
@@ -1,16 +1,9 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.'  # noqa: E501
im_patch_token = '<im_patch>'
patch_size = 14
image_size = 224
num_patches = (image_size // patch_size)**2
caption_prompt = ' '.join([
    meta_prompt,
    'User: a photo of\n',
    im_patch_token * num_patches,
    'ASSISTANT:',
])
prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
@@ -22,6 +15,7 @@ model = dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        img_size=image_size,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
@@ -32,15 +26,16 @@ model = dict(
        'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
    ),
    mm_hidden_size=1024,
    use_im_start_end=False,
    use_mm_proj=True,
    use_im_patch=False,
    use_im_start_end=True,
    mm_proj_depth=1,
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
    ),
    task='caption',
    prompt_tmpl=caption_prompt,
    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
    prompt_tmpl=prompt_tmpl,
    generation_cfg=dict(max_new_tokens=50),
)

# data settings
@@ -21,5 +21,31 @@ Models:
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: null
    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
    Config: configs/llava/llava-7b-v1_caption.py
  - Name: llava-7b-v1.5_caption
    Metadata:
      FLOPs: null
      Parameters: 7062900736
    In Collection: LLaVA
    Results:
      - Task: Image Caption
        Dataset: COCO
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
    Config: configs/llava/llava-7b-v1.5_caption.py
  - Name: llava-7b-v1.5_vqa
    Metadata:
      FLOPs: null
      Parameters: 7062900736
    In Collection: LLaVA
    Results:
      - Task: Visual Question Answering
        Dataset: COCO
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
    Config: configs/llava/llava-7b-v1.5_vqa.py
@@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI

### Pretrained models

| Model | Params (M) | Flops (G) | Config | Download |
| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) |
| Model | Params (M) | Flops (G) | Config | Download |
| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: |
| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) |
| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) |

*Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
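The Baichuan-7B based weights can be tried with the same high-level inference API used for the other multimodal models; below is a minimal sketch, assuming the checkpoint from the table above has been downloaded locally and that any test image is available (both are placeholders).

```python
from mmpretrain import get_model, inference_model

# Placeholders: point `pretrained` at the checkpoint downloaded from the
# table above and use any local test image.
model = get_model(
    'minigpt-4_baichuan-7b_caption',
    pretrained='minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth',
    device='cuda')
print(inference_model(model, 'demo/cat-dog.png'))
```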
@@ -19,8 +19,19 @@ Models:
      - Task: Image Caption
        Dataset: COCO
        Metrics: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth
    Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth
    Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py
    Converted From:
      Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
      Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
  - Name: minigpt-4_baichuan-7b_caption
    Metadata:
      FLOPs: null
      Parameters: 8094769024
    In Collection: MiniGPT4
    Results:
      - Task: Image Caption
        Dataset: COCO
        Metrics: null
    Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth
    Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py
@ -0,0 +1,190 @@
|
|||
_base_ = [
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(224, 224),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='CleanCaption',
|
||||
keys='chat_content',
|
||||
remove_chars='',
|
||||
lowercase=False),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['chat_content', 'lang'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type='MiniGPT4Dataset',
|
||||
data_root='YOUR_DATA_DIRECTORY',
|
||||
ann_file='YOUR_DATA_FILE',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(224, 224),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='PackInputs', meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
test_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
|
||||
)
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
type='COCOCaption',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/coco_karpathy_val.json',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='MiniGPT4',
|
||||
vision_encoder=dict(
|
||||
type='BEiTViT',
|
||||
# eva-g without the final layer
|
||||
arch=dict(
|
||||
embed_dims=1408,
|
||||
num_layers=39,
|
||||
num_heads=16,
|
||||
feedforward_channels=6144,
|
||||
),
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
layer_scale_init_value=0.0,
|
||||
frozen_stages=39,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
final_norm=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
out_type='raw',
|
||||
pretrained= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
|
||||
),
|
||||
q_former_model=dict(
|
||||
type='Qformer',
|
||||
model_style='bert-base-uncased',
|
||||
vision_model_width=1408,
|
||||
add_cross_attention=True,
|
||||
cross_attention_freq=2,
|
||||
num_query_token=32,
|
||||
pretrained= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
|
||||
),
|
||||
lang_encoder=dict(
|
||||
type='AutoModelForCausalLM',
|
||||
name_or_path='baichuan-inc/baichuan-7B',
|
||||
trust_remote_code=True),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='baichuan-inc/baichuan-7B',
|
||||
trust_remote_code=True),
|
||||
task='caption',
|
||||
prompt_template=dict([('en', '###Ask: {} ###Answer: '),
|
||||
('zh', '###问:{} ###答:')]),
|
||||
raw_prompts=dict([
|
||||
('en', [('<Img><ImageHere></Img> '
|
||||
'Describe this image in detail.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Take a look at this image and describe what you notice.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Please provide a detailed description of the picture.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Could you describe the contents of this image for me?')]),
|
||||
('zh', [('<Img><ImageHere></Img> '
|
||||
'详细描述这张图片。'), ('<Img><ImageHere></Img> '
|
||||
'浏览这张图片并描述你注意到什么。'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'请对这张图片进行详细的描述。'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'你能为我描述这张图片的内容吗?')])
|
||||
]),
|
||||
max_txt_len=160,
|
||||
end_sym='###')
|
||||
|
||||
strategy = dict(
|
||||
type='DeepSpeedStrategy',
|
||||
fp16=dict(
|
||||
enabled=True,
|
||||
auto_cast=False,
|
||||
fp16_master_weights_and_grads=False,
|
||||
loss_scale=0,
|
||||
loss_scale_window=1000,
|
||||
hysteresis=1,
|
||||
min_loss_scale=1,
|
||||
initial_scale_power=16,
|
||||
),
|
||||
inputs_to_half=[0],
|
||||
zero_optimization=dict(
|
||||
stage=2,
|
||||
allgather_partitions=True,
|
||||
allgather_bucket_size=2e8,
|
||||
reduce_scatter=True,
|
||||
reduce_bucket_size='auto',
|
||||
overlap_comm=True,
|
||||
contiguous_gradients=True,
|
||||
),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
type='DeepSpeedOptimWrapper',
|
||||
optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-3 / 500,
|
||||
by_epoch=False,
|
||||
begin=0,
|
||||
end=500,
|
||||
),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=2e-4,
|
||||
by_epoch=False,
|
||||
begin=500,
|
||||
),
|
||||
]
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=6)
|
||||
test_cfg = dict()
|
||||
|
||||
runner_type = 'FlexibleRunner'
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook',
|
||||
interval=1,
|
||||
by_epoch=True,
|
||||
save_last=True,
|
||||
max_keep_ckpts=1,
|
||||
))
|
@@ -55,13 +55,25 @@ model = dict(
        type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
    tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
    task='caption',
    prompt_template='###Human: {} ###Assistant: ',
    raw_prompts=[
        '<Img><ImageHere></Img> Describe this image in detail.',
        '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
        '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
        '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
    ],
    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
                          ('zh', '###问:{} ###答:')]),
    raw_prompts=dict([
        ('en', [('<Img><ImageHere></Img> '
                 'Describe this image in detail.'),
                ('<Img><ImageHere></Img> '
                 'Take a look at this image and describe what you notice.'),
                ('<Img><ImageHere></Img> '
                 'Please provide a detailed description of the picture.'),
                ('<Img><ImageHere></Img> '
                 'Could you describe the contents of this image for me?')]),
        ('zh', [('<Img><ImageHere></Img> '
                 '详细描述这张图片。'), ('<Img><ImageHere></Img> '
                                '浏览这张图片并描述你注意到什么。'),
                ('<Img><ImageHere></Img> '
                 '请对这张图片进行详细的描述。'),
                ('<Img><ImageHere></Img> '
                 '你能为我描述这张图片的内容吗?')])
    ]),
    max_txt_len=160,
    end_sym='###')
@@ -1,11 +1,11 @@
imagenet1k:
  dataset: ImageNet-1K
  dataset: OpenDataLab/ImageNet-1K
  download_root: data
  data_root: data/imagenet
  script: tools/dataset_converters/odl_imagenet1k_preprocess.sh

cub:
  dataset: CUB-200-2011
  dataset: OpenDataLab/CUB-200-2011
  download_root: data
  data_root: data/CUB_200_2011
  script: tools/dataset_converters/odl_cub_preprocess.sh
@@ -1,9 +1,9 @@
ARG PYTORCH="1.12.1"
ARG CUDA="11.3"
ARG PYTORCH="2.0.1"
ARG CUDA="11.7"
ARG CUDNN="8"
FROM pytorch/torchserve:latest-gpu

ARG MMPRE="1.0.2"
ARG MMPRE="1.2.0"

ENV PYTHONUNBUFFERED TRUE
@@ -1,5 +1,38 @@
# Changelog (MMPreTrain)

## v1.2.0 (04/01/2024)

### New Features

- [Feature] Support LLaVA 1.5 ([#1853](https://github.com/open-mmlab/mmpretrain/pull/1853))
- [Feature] Implement RAM with a Gradio interface ([#1802](https://github.com/open-mmlab/mmpretrain/pull/1802))

### Bug Fix

- [Fix] Fix resize mix argument bug.

## v1.1.0 (12/10/2023)

### New Features

- [Feature] Implement Zero-Shot CLIP Classifier ([#1737](https://github.com/open-mmlab/mmpretrain/pull/1737))
- [Feature] Add MiniGPT-4 Gradio demo and training script ([#1758](https://github.com/open-mmlab/mmpretrain/pull/1758))

### Improvements

- [Config] New version of config adapting the MobileNet algorithm ([#1774](https://github.com/open-mmlab/mmpretrain/pull/1774))
- [Config] Support DINO self-supervised learning in project ([#1756](https://github.com/open-mmlab/mmpretrain/pull/1756))
- [Config] New version of config adapting the Swin Transformer algorithm ([#1780](https://github.com/open-mmlab/mmpretrain/pull/1780))
- [Enhance] Add iTPN support for non-three-channel images ([#1735](https://github.com/open-mmlab/mmpretrain/pull/1735))
- [Docs] Update dataset download script from OpenDataLab to OpenXLab ([#1765](https://github.com/open-mmlab/mmpretrain/pull/1765))
- [Docs] Update COCO-Retrieval dataset docs ([#1806](https://github.com/open-mmlab/mmpretrain/pull/1806))

### Bug Fix

- Update `train.py` to be compatible with the new config.
- Update the OFA module to be compatible with the latest HuggingFace.
- Fix a pipeline bug in ImageRetrievalInferencer.

## v1.0.2 (15/08/2023)

### New Features
@@ -16,7 +16,8 @@ and make sure you fill in all required information in the template.

| MMPretrain version | MMEngine version | MMCV version |
| :----------------: | :---------------: | :--------------: |
| 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
| 1.2.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
| 1.1.1 | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
| 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 |
| 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 |
| 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 |
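A quick way to check which row of the table applies to your environment is to print the installed versions; a small illustrative snippet:

```python
import mmcv
import mmengine
import mmpretrain

# Compare the printed versions against the compatibility table above.
print('mmpretrain:', mmpretrain.__version__)
print('mmengine:', mmengine.__version__)
print('mmcv:', mmcv.__version__)
```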
@@ -144,15 +144,15 @@ ImageNet has multiple versions, but the most commonly used one is [ILSVRC 2012](

````{group-tab} Download by MIM

MIM supports downloading from [OpenDataLab](https://opendatalab.com/) and preprocessing ImageNet dataset with one command line.
MIM supports downloading from [OpenXLab](https://openxlab.org.cn/datasets) and preprocessing the ImageNet dataset with one command line.

_You need to register an account at [OpenDataLab official website](https://opendatalab.com/) and login by CLI._
_You need to register an account at the [OpenXLab official website](https://openxlab.org.cn/datasets) and log in via the CLI._

```Bash
# install OpenDataLab CLI tools
pip install -U opendatalab
# log in OpenDataLab, register if you don't have an account.
odl login
# install OpenXLab CLI tools
pip install -U openxlab
# log in to OpenXLab
openxlab login
# download and preprocess by MIM, better to execute in $MMPreTrain directory.
mim download mmpretrain --dataset imagenet1k
```
@@ -278,7 +278,7 @@ test_dataloader = val_dataloader
| [`SUN397`](mmpretrain.datasets.SUN397)(data_root[, split, pipeline, ...]) | ["train", "test"] | [SUN397](https://vision.princeton.edu/projects/2010/SUN/) Dataset. |
| [`VOC`](mmpretrain.datasets.VOC)(data_root[, image_set_path, pipeline, ...]) | ["train", "val", "trainval", "test"] | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) Dataset. |

Some dataset homepage links may be unavailable, and you can download datasets through [OpenDataLab](https://opendatalab.com/), such as [Stanford Cars](https://opendatalab.com/Stanford_Cars/download).
Some dataset homepage links may be unavailable, and you can download datasets through [OpenXLab](https://openxlab.org.cn/datasets), such as [Stanford Cars](https://openxlab.org.cn/datasets/OpenDataLab/Stanford_Cars).

## Supported Multi-modality Datasets
@@ -13,7 +13,8 @@

| MMPretrain 版本 | MMEngine 版本 | MMCV 版本 |
| :-------------: | :---------------: | :--------------: |
| 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
| 1.2.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
| 1.1.1 | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
| 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 |
| 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 |
| 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 |
@@ -142,15 +142,15 @@ ImageNet 有多个版本,但最常用的一个是 [ILSVRC 2012](http://www.ima

````{group-tab} MIM 下载

MIM 支持使用一条命令行从 [OpenDataLab](https://opendatalab.com/) 下载并预处理 ImageNet 数据集。
MIM 支持使用一条命令行从 [OpenXLab](https://openxlab.org.cn/datasets?lang=zh-CN) 下载并预处理 ImageNet 数据集。

_需要在 [OpenDataLab 官网](https://opendatalab.com/) 注册账号并命令行登录_。
_需要在 [OpenXLab 官网](https://openxlab.org.cn/datasets?lang=zh-CN) 注册账号并命令行登录_。

```Bash
# 安装 opendatalab 库
pip install -U opendatalab
# 登录到 OpenDataLab,如果还没有注册,请到官网注册一个
odl login
# 安装 OpenXLab CLI 工具
pip install -U openxlab
# 登录 OpenXLab
openxlab login
# 使用 MIM 下载数据集,最好在 $MMPreTrain 目录执行
mim download mmpretrain --dataset imagenet1k
```
@@ -276,7 +276,7 @@ test_dataloader = val_dataloader
| [`SUN397`](mmpretrain.datasets.SUN397)(data_root[, split, pipeline, ...]) | ["train", "test"] | [SUN397](https://vision.princeton.edu/projects/2010/SUN/) 数据集 |
| [`VOC`](mmpretrain.datasets.VOC)(data_root[, image_set_path, pipeline, ...]) | ["train", "val", "trainval", "test"] | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) 数据集 |

有些数据集主页链接可能已经失效,您可以通过 [OpenDataLab](https://opendatalab.com/) 下载数据集,例如 [Stanford Cars](https://opendatalab.com/Stanford_Cars/download) 数据集。
有些数据集主页链接可能已经失效,您可以通过 [OpenXLab](https://openxlab.org.cn/datasets?lang=zh-CN) 下载数据集,例如 [Stanford Cars](https://openxlab.org.cn/datasets/OpenDataLab/Stanford_Cars) 数据集。

## OpenMMLab 2.0 标准数据集
@@ -7,7 +7,7 @@ from .apis import *  # noqa: F401, F403
from .version import __version__

mmcv_minimum_version = '2.0.0'
mmcv_maximum_version = '2.1.0'
mmcv_maximum_version = '2.4.0'
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.8.3'
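These bounds typically feed an import-time guard of the following shape; this is a sketch, and the exact assertion message in `mmpretrain/__init__.py` may differ.

```python
# Sketch of the compatibility guard the bounds above are used in;
# `digit_version` comes from mmengine.
assert (digit_version(mmcv_minimum_version) <= mmcv_version
        < digit_version(mmcv_maximum_version)), (
    f'MMCV=={mmcv.__version__} is incompatible with this MMPreTrain; '
    f'please install mmcv>={mmcv_minimum_version},<{mmcv_maximum_version}.')
```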
@@ -108,6 +108,7 @@ class ImageRetrievalInferencer(BaseInferencer):
            # A config of dataset
            from mmpretrain.registry import DATASETS
            test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
            prototype.setdefault('pipeline', test_pipeline)
            dataset = DATASETS.build(prototype)
            dataloader = build_dataloader(dataset)
        elif isinstance(prototype, DataLoader):
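For reference, a minimal usage sketch of this inferencer, assuming a retrieval model name from the model zoo and a few local images as the prototype gallery (all of these names and paths are placeholders); a dataset config dict, as handled by the branch above, works the same way.

```python
from mmpretrain import ImageRetrievalInferencer

# Placeholder model name and image paths; the prototype gallery may also be
# a dataset config dict, which is the code path touched above.
inferencer = ImageRetrievalInferencer(
    'resnet50-arcface_inshop',
    prototype=['demo/demo.JPEG', 'demo/bird.JPEG'],
)
print(inferencer('demo/dog.jpg'))
```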
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile,
|
||||
PackInputs, RandomCrop, RandomFlip, Resize)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = CUB
|
||||
data_preprocessor = dict(
|
||||
num_classes=200,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=Resize, scale=510),
|
||||
dict(type=RandomCrop, crop_size=384),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=Resize, scale=510),
|
||||
dict(type=CenterCrop, crop_size=384),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, ))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
|
||||
PackInputs, RandAugment, RandomErasing,
|
||||
RandomFlip, RandomResizedCrop, ResizeEdge)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = ImageNet
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=RandomResizedCrop,
|
||||
scale=256,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type=RandAugment,
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type=RandomErasing,
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=ResizeEdge,
|
||||
scale=292, # ( 256 / 224 * 256 )
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=CenterCrop, crop_size=256),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
|
||||
ImageClassifier, LinearClsHead, SwinTransformer)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(
|
||||
type=SwinTransformer,
|
||||
arch='base',
|
||||
img_size=384,
|
||||
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (GlobalAveragePooling, ImageClassifier,
|
||||
LabelSmoothLoss, LinearClsHead,
|
||||
SwinTransformerV2)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(
|
||||
type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False))
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.optim import CosineAnnealingLR, LinearLR
|
||||
from torch.optim import SGD
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True))
|
||||
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
# warm up learning rate scheduler
|
||||
dict(
|
||||
type=LinearLR,
|
||||
start_factor=0.01,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=5,
|
||||
# update by iter
|
||||
convert_to_iter_based=True),
|
||||
# main learning rate scheduler
|
||||
dict(
|
||||
type=CosineAnnealingLR,
|
||||
T_max=95,
|
||||
by_epoch=True,
|
||||
begin=5,
|
||||
end=100,
|
||||
)
|
||||
]
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=64)
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(img_size=224, drop_path_rate=0.5, stage_cfgs=None),
|
||||
head=dict(
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='large', img_size=224, stage_cfgs=None),
|
||||
head=dict(in_channels=1536),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='large'),
|
||||
head=dict(in_channels=1536),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.hooks import CheckpointHook, LoggerHook
|
||||
from mmengine.model import PretrainedInit
|
||||
from torch.optim.adamw import AdamW
|
||||
|
||||
from mmpretrain.models import ImageClassifier
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.cub_bs8_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.cub_bs64 import *
|
||||
|
||||
# model settings
|
||||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa
|
||||
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
init_cfg=dict(
|
||||
type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')),
|
||||
head=dict(num_classes=200, in_channels=1536))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
_delete_=True,
|
||||
type=AdamW,
|
||||
lr=5e-6,
|
||||
weight_decay=0.0005,
|
||||
eps=1e-8,
|
||||
betas=(0.9, 0.999)),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0.0,
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={
|
||||
'.absolute_pos_embed': dict(decay_mult=0.0),
|
||||
'.relative_position_bias_table': dict(decay_mult=0.0)
|
||||
}),
|
||||
clip_grad=dict(max_norm=5.0),
|
||||
)
|
||||
|
||||
default_hooks = dict(
|
||||
# log every 20 intervals
|
||||
logger=dict(type=LoggerHook, interval=20),
|
||||
# save last three checkpoints
|
||||
checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3))
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='small', img_size=224, drop_path_rate=0.3, stage_cfgs=None),
|
||||
head=dict(
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='tiny', img_size=224, drop_path_rate=0.2, stage_cfgs=None),
|
||||
head=dict(
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet21k_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]),
|
||||
head=dict(num_classes=21841),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# dataset settings
|
||||
data_preprocessor = dict(num_classes=21841)
|
||||
|
||||
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
|
||||
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
|
||||
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=256, drop_path_rate=0.5, window_size=[16, 16, 16, 8]),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=256,
|
||||
window_size=[16, 16, 16, 8],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
window_size=[24, 24, 24, 12], pretrained_window_sizes=[12, 12, 12, 6]))
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(img_size=256, drop_path_rate=0.5),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet21k_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]),
|
||||
head=dict(num_classes=21841),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# dataset settings
|
||||
data_preprocessor = dict(num_classes=21841)
|
||||
|
||||
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
|
||||
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
|
||||
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop
|
|
@ -0,0 +1,24 @@
|
|||
# Only for evaluation
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
from mmpretrain.models import CrossEntropyLoss
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
img_size=256,
|
||||
window_size=[16, 16, 16, 8],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
head=dict(
|
||||
in_channels=1536,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,24 @@
|
|||
# Only for evaluation
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
from mmpretrain.models import CrossEntropyLoss
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
img_size=384,
|
||||
window_size=[24, 24, 24, 12],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
head=dict(
|
||||
in_channels=1536,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='small',
|
||||
img_size=256,
|
||||
drop_path_rate=0.3,
|
||||
window_size=[16, 16, 16, 8]),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='small', img_size=256, drop_path_rate=0.3),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmengine.model import ConstantInit, TruncNormalInit

from mmpretrain.models import CutMix, Mixup

with read_base():
    from .._base_.datasets.imagenet_bs64_swin_256 import *
    from .._base_.default_runtime import *
    from .._base_.models.swin_transformer_v2_base import *
    from .._base_.schedules.imagenet_bs1024_adamw_swin import *

# model settings
model.update(
    backbone=dict(
        arch='tiny',
        img_size=256,
        drop_path_rate=0.2,
        window_size=[16, 16, 16, 8]),
    head=dict(in_channels=768),
    init_cfg=[
        dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
        dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(
        augments=[dict(type=Mixup, alpha=0.8),
                  dict(type=CutMix, alpha=1.0)]))
@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmengine.model import ConstantInit, TruncNormalInit

from mmpretrain.models import CutMix, Mixup

with read_base():
    from .._base_.datasets.imagenet_bs64_swin_256 import *
    from .._base_.default_runtime import *
    from .._base_.models.swin_transformer_v2_base import *
    from .._base_.schedules.imagenet_bs1024_adamw_swin import *

# model settings
model.update(
    backbone=dict(arch='tiny', img_size=256, drop_path_rate=0.2),
    head=dict(in_channels=768),
    init_cfg=[
        dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
        dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(
        augments=[dict(type=Mixup, alpha=0.8),
                  dict(type=CutMix, alpha=1.0)]))
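These new-format configs are plain Python modules. Below is a minimal, hedged sketch of loading and running one of them with MMEngine; the config path and `work_dir` are hypothetical examples, not values fixed by this change.

```python
# A minimal sketch, assuming one of the new-format configs above is saved at
# the (hypothetical) path below; MMEngine's Runner builds everything from it.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_8xb256_in1k_256px.py')
cfg.work_dir = './work_dirs/swinv2_tiny_w8'  # where logs and checkpoints go

runner = Runner.from_cfg(cfg)
runner.train()
```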
@ -43,6 +43,7 @@ if WITH_MULTIMODAL:
    from .gqa_dataset import GQA
    from .iconqa import IconQA
    from .infographic_vqa import InfographicVQA
    from .minigpt4_dataset import MiniGPT4Dataset
    from .nocaps import NoCaps
    from .ocr_vqa import OCRVQA
    from .refcoco import RefCOCO

@ -56,5 +57,6 @@ if WITH_MULTIMODAL:
        'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
        'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
        'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
        'MiniGPT4Dataset'
    ])
@ -1,18 +1,45 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
from os import PathLike
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from mmengine import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS, TRANSFORMS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
def expanduser(data_prefix):
|
||||
if isinstance(data_prefix, (str, PathLike)):
|
||||
return osp.expanduser(data_prefix)
|
||||
else:
|
||||
return data_prefix
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class COCORetrieval(BaseDataset):
|
||||
"""COCO Retrieval dataset.
|
||||
|
||||
COCO (Common Objects in Context): The COCO dataset contains more than
330K images, each of which has approximately 5 descriptive annotations.
This dataset was released in collaboration between Microsoft and Carnegie
Mellon University.
|
||||
|
||||
COCO_2014 dataset directory: ::
|
||||
|
||||
COCO_2014
|
||||
├── val2014
|
||||
├── train2014
|
||||
├── annotations
|
||||
├── instances_train2014.json
|
||||
├── instances_val2014.json
|
||||
├── person_keypoints_train2014.json
|
||||
├── person_keypoints_val2014.json
|
||||
├── captions_train2014.json
|
||||
├── captions_val2014.json
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path.
|
||||
test_mode (bool): Whether dataset is used for evaluation. This will
|
||||
|
@ -23,8 +50,52 @@ class COCORetrieval(BaseDataset):
|
|||
data_prefix (str | dict): Prefix for training data. Defaults to ''.
|
||||
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
|
||||
Examples:
|
||||
>>> from mmpretrain.datasets import COCORetrieval
|
||||
>>> train_dataset=COCORetrieval(data_root='coco2014/')
|
||||
>>> train_dataset
|
||||
Dataset COCORetrieval
|
||||
Number of samples: 414113
|
||||
Annotation file: /coco2014/annotations/captions_train2014.json
|
||||
Prefix of images: /coco2014/
|
||||
>>> from mmpretrain.datasets import COCORetrieval
|
||||
>>> val_dataset = COCORetrieval(data_root='coco2014/', test_mode=True)
|
||||
>>> val_dataset
|
||||
Dataset COCORetrieval
|
||||
Number of samples: 202654
|
||||
Annotation file: /coco2014/annotations/captions_val2014.json
|
||||
Prefix of images: /coco2014/
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str,
|
||||
test_mode: bool = False,
|
||||
data_prefix: Union[str, dict] = '',
|
||||
data_root: str = '',
|
||||
pipeline: Sequence = (),
|
||||
**kwargs):
|
||||
|
||||
if isinstance(data_prefix, str):
|
||||
data_prefix = dict(img_path=expanduser(data_prefix))
|
||||
|
||||
ann_file = expanduser(ann_file)
|
||||
transforms = []
|
||||
for transform in pipeline:
|
||||
if isinstance(transform, dict):
|
||||
transforms.append(TRANSFORMS.build(transform))
|
||||
else:
|
||||
transforms.append(transform)
|
||||
|
||||
super().__init__(
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
test_mode=test_mode,
|
||||
pipeline=transforms,
|
||||
ann_file=ann_file,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load data list."""
|
||||
# get file backend
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import mmengine
|
||||
from mmengine.dataset import BaseDataset
|
||||
from mmengine.fileio import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MiniGPT4Dataset(BaseDataset):
|
||||
"""Dataset for training MiniGPT4.
|
||||
|
||||
MiniGPT4 dataset directory:
|
||||
|
||||
minigpt4_dataset
|
||||
├── image
|
||||
│ ├── id0.jpg
|
||||
│ │── id1.jpg
|
||||
│ │── id2.jpg
|
||||
│ └── ...
|
||||
└── conversation_data.json
|
||||
|
||||
The structure of conversation_data.json:
|
||||
|
||||
[
|
||||
// English data
|
||||
{
|
||||
"id": str(id0),
|
||||
"conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
|
||||
###Answer: [Answer content]"
|
||||
},
|
||||
|
||||
// Chinese data
|
||||
{
|
||||
"id": str(id1),
|
||||
"conversation": "###问:<Img><ImageHere></Img> [Ask content]
|
||||
###答:[Answer content]"
|
||||
},
|
||||
|
||||
...
|
||||
]
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for ``ann_file`` and ``image``.
|
||||
ann_file (str): Conversation file path.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
file_backend = get_file_backend(self.data_root)
|
||||
conversation_path = file_backend.join_path(self.data_root,
|
||||
self.ann_file)
|
||||
conversation = mmengine.load(conversation_path)
|
||||
img_ids = {}
|
||||
n = 0
|
||||
for conv in conversation:
|
||||
img_id = conv['id']
|
||||
if img_id not in img_ids.keys():
|
||||
img_ids[img_id] = n
|
||||
n += 1
|
||||
|
||||
img_root = file_backend.join_path(self.data_root, 'image')
|
||||
data_list = []
|
||||
for conv in conversation:
|
||||
img_file = '{}.jpg'.format(conv['id'])
|
||||
chat_content = conv['conversation']
|
||||
lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
|
||||
data_info = {
|
||||
'image_id': img_ids[conv['id']],
|
||||
'img_path': file_backend.join_path(img_root, img_file),
|
||||
'chat_content': chat_content,
|
||||
'lang': lang,
|
||||
}
|
||||
|
||||
data_list.append(data_info)
|
||||
|
||||
return data_list
|
|
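For reference, a hedged sketch of how `MiniGPT4Dataset` might be plugged into a dataloader config; the paths, batch size and pipeline below are illustrative assumptions, not values shipped with this change.

```python
# Illustrative only: data_root/ann_file are made-up paths; the pipeline uses
# standard mmpretrain transforms and packs the fields produced by
# load_data_list() above ('chat_content' and 'lang').
train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='MiniGPT4Dataset',
        data_root='data/minigpt4_dataset',
        ann_file='conversation_data.json',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
            dict(type='PackInputs', algorithm_keys=['chat_content', 'lang']),
        ]))
```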
@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
        norm_pix_loss (bool): Whether or not normalize target.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
        in_channels (int): Number of input channels. Defaults to 3.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = False,
                 patch_size: int = 16) -> None:
                 patch_size: int = 16,
                 in_channels: int = 3) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.loss_module = MODELS.build(loss)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:

@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule):

        Args:
            imgs (torch.Tensor): A batch of images. The shape should
                be :math:`(B, 3, H, W)`.
                be :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor: Patchified images. The shape is
                :math:`(B, L, \text{patch_size}^2 \times 3)`.
                :math:`(B, L, \text{patch_size}^2 \times C)`.
        """
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:

@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule):

        Args:
            x (torch.Tensor): The shape is
                :math:`(B, L, \text{patch_size}^2 \times 3)`.
                :math:`(B, L, \text{patch_size}^2 \times C)`.

        Returns:
            torch.Tensor: The shape is :math:`(B, 3, H, W)`.
            torch.Tensor: The shape is :math:`(B, C, H, W)`.
        """
        p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
        return imgs

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:

@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule):
        normalize the image according to ``norm_pix``.

        Args:
            target (torch.Tensor): Image with the shape of B x 3 x H x W
            target (torch.Tensor): Image with the shape of B x C x H x W

        Returns:
            torch.Tensor: Tokenized images with the shape of B x L x C
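A small sanity check of the generalized `in_channels` handling above; this is a sketch under the assumption that `MAEPretrainHead` is exposed under `mmpretrain.models` and can be built with a pixel-reconstruction loss.

```python
# Sketch: patchify/unpatchify should round-trip for non-RGB inputs once
# in_channels is configurable (values below are arbitrary).
import torch
from mmpretrain.models import MAEPretrainHead

head = MAEPretrainHead(
    loss=dict(type='PixelReconstructionLoss', criterion='L2'),
    patch_size=16,
    in_channels=4)

imgs = torch.rand(2, 4, 224, 224)   # B x C x H x W
tokens = head.patchify(imgs)        # B x L x (patch_size**2 * C)
assert tokens.shape == (2, 14 * 14, 16 * 16 * 4)
assert torch.allclose(head.unpatchify(tokens), imgs)
```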
@ -11,6 +11,7 @@ if WITH_MULTIMODAL:
    from .minigpt4 import *  # noqa: F401, F403
    from .ofa import *  # noqa: F401, F403
    from .otter import *  # noqa: F401, F403
    from .ram import *  # noqa: F401, F403
else:
    from mmpretrain.registry import MODELS
    from mmpretrain.utils.dependency import register_multimodal_placeholder

@ -19,5 +20,5 @@ else:
        'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
        'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
        'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP',
        'CLIPZeroShot'
        'CLIPZeroShot', 'RAM', 'RAMNormal', 'RAMOpenset'
    ], MODELS)
@ -24,8 +24,8 @@ class Llava(BaseModel):
|
|||
use_im_start_end (bool): Whether to use the im_start and im_end tokens
|
||||
mm_vision_select_layer (int): The index from vision encoder output.
|
||||
Defaults to -1.
|
||||
use_mm_proj (bool): Whether to enable multi-modal projection.
|
||||
Defaults to True.
|
||||
mm_proj_depth (int): The number of linear layers for multi-modal
|
||||
projection. Defaults to 1.
|
||||
load_lang_pretrained (bool): Whether to load the pretrained model of
|
||||
language encoder. Defaults to False.
|
||||
generation_cfg (dict): The extra generation config, accept the keyword
|
||||
|
@ -51,9 +51,10 @@ class Llava(BaseModel):
|
|||
mm_hidden_size: int,
|
||||
prompt_tmpl: str,
|
||||
task: str = 'caption',
|
||||
use_im_patch: bool = True,
|
||||
use_im_start_end: bool = False,
|
||||
mm_vision_select_layer: int = -1,
|
||||
use_mm_proj: bool = True,
|
||||
mm_proj_depth: int = 1,
|
||||
generation_cfg: dict = dict(),
|
||||
load_lang_pretrained: bool = False,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
|
@ -75,7 +76,9 @@ class Llava(BaseModel):
|
|||
# init tokenizer
|
||||
self.tokenizer = TOKENIZER.build(tokenizer)
|
||||
# add Llava special tokens to the tokenizer
|
||||
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
|
||||
if use_im_patch:
|
||||
self.tokenizer.add_tokens([self.im_patch_token],
|
||||
special_tokens=True)
|
||||
if use_im_start_end:
|
||||
self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
|
||||
special_tokens=True)
|
||||
|
@ -108,14 +111,12 @@ class Llava(BaseModel):
|
|||
vision_encoder=vision_encoder,
|
||||
lang_encoder=lang_encoder,
|
||||
mm_hidden_size=mm_hidden_size,
|
||||
use_mm_proj=use_mm_proj,
|
||||
mm_proj_depth=mm_proj_depth,
|
||||
use_im_start_end=use_im_start_end,
|
||||
im_start_token=self.tokenizer.convert_tokens_to_ids(
|
||||
self.im_start_token),
|
||||
im_end_token=self.tokenizer.convert_tokens_to_ids(
|
||||
self.im_end_token),
|
||||
im_patch_token=self.tokenizer.convert_tokens_to_ids(
|
||||
self.im_patch_token),
|
||||
mm_vision_select_layer=mm_vision_select_layer)
|
||||
|
||||
self.generation_cfg = generation_cfg
|
||||
|
@ -207,16 +208,24 @@ class Llava(BaseModel):
|
|||
Returns:
|
||||
List[DataSample]: Return list of data samples.
|
||||
"""
|
||||
prompts = []
|
||||
tokens = []
|
||||
for sample in data_samples:
|
||||
final_prompt = self.prompt_tmpl.format(**sample.to_dict())
|
||||
prompts.append(final_prompt)
|
||||
prompt = self.prompt_tmpl.format(**sample.to_dict())
|
||||
input_ids = []
|
||||
while '<image>' in prompt:
|
||||
prefix, _, prompt = prompt.partition('<image>')
|
||||
input_ids.extend(
|
||||
self.tokenizer(prefix, add_special_tokens=False).input_ids)
|
||||
input_ids.append(-200)
|
||||
if prompt:
|
||||
input_ids.extend(
|
||||
self.tokenizer(prompt, add_special_tokens=False).input_ids)
|
||||
tokens.append(dict(input_ids=input_ids))
|
||||
|
||||
self.tokenizer.padding_side = 'left'
|
||||
input_text = self.tokenizer(
|
||||
prompts,
|
||||
input_text = self.tokenizer.pad(
|
||||
tokens,
|
||||
padding='longest',
|
||||
truncation=True,
|
||||
return_tensors='pt',
|
||||
max_length=2000,
|
||||
).to(device)
|
||||
|
|
|
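The `<image>` handling above interleaves text token ids with a sentinel index (-200) that the vision tower later replaces with image features. A standalone sketch of that splitting logic, assuming a HuggingFace-style tokenizer object:

```python
# Sketch only: mirrors the preprocessing loop above; `tokenizer` is assumed
# to be a HuggingFace-style tokenizer whose output exposes .input_ids.
def tokenize_with_image_token(prompt, tokenizer, image_token_index=-200):
    input_ids = []
    while '<image>' in prompt:
        prefix, _, prompt = prompt.partition('<image>')
        input_ids.extend(tokenizer(prefix, add_special_tokens=False).input_ids)
        input_ids.append(image_token_index)  # placeholder for the image slot
    if prompt:
        input_ids.extend(tokenizer(prompt, add_special_tokens=False).input_ids)
    return input_ids
```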
@ -31,10 +31,10 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
|
|||
lang_encoder,
|
||||
mm_hidden_size,
|
||||
use_im_start_end=True,
|
||||
use_mm_proj=True,
|
||||
mm_proj_depth=1,
|
||||
im_start_token: Optional[int] = None,
|
||||
im_end_token: Optional[int] = None,
|
||||
im_patch_token: Optional[int] = None,
|
||||
im_token_index: int = -200,
|
||||
mm_vision_select_layer: int = -1):
|
||||
super().__init__(lang_encoder.config)
|
||||
self.vision_tower = vision_encoder
|
||||
|
@ -43,16 +43,26 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
|
|||
self.use_im_start_end = use_im_start_end
|
||||
self.im_start_token = im_start_token
|
||||
self.im_end_token = im_end_token
|
||||
self.im_patch_token = im_patch_token
|
||||
self.mm_hidden_size = mm_hidden_size
|
||||
self.mm_vision_select_layer = mm_vision_select_layer
|
||||
self.im_token_index = im_token_index
|
||||
self.lang_hidden_size = lang_encoder.config.hidden_size
|
||||
|
||||
if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
|
||||
if mm_proj_depth == 1:
|
||||
# Llava V1
|
||||
mm_projector = nn.Linear(self.mm_hidden_size,
|
||||
self.lang_hidden_size)
|
||||
self.lang_encoder.model.add_module('mm_projector', mm_projector)
|
||||
elif not use_mm_proj:
|
||||
elif mm_proj_depth > 1:
|
||||
# Llava V1.5
|
||||
modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)]
|
||||
for _ in range(1, mm_proj_depth):
|
||||
modules.append(nn.GELU())
|
||||
modules.append(
|
||||
nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
|
||||
mm_projector = nn.Sequential(*modules)
|
||||
self.lang_encoder.model.add_module('mm_projector', mm_projector)
|
||||
elif mm_proj_depth == 0:
|
||||
self.lang_encoder.model.add_module('mm_projector', nn.Identity())
|
||||
|
||||
self.post_init()
|
||||
|
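The `mm_proj_depth` branching above selects between the LLaVA v1 single linear projector, the LLaVA v1.5 MLP projector, and an identity mapping. A compact sketch of the same selection; the dimensions in the usage line are illustrative assumptions.

```python
import torch.nn as nn


def build_mm_projector(mm_hidden_size: int, lang_hidden_size: int,
                       mm_proj_depth: int = 1) -> nn.Module:
    if mm_proj_depth == 0:
        return nn.Identity()                                  # no projection
    if mm_proj_depth == 1:
        return nn.Linear(mm_hidden_size, lang_hidden_size)    # LLaVA v1
    modules = [nn.Linear(mm_hidden_size, lang_hidden_size)]   # LLaVA v1.5 MLP
    for _ in range(1, mm_proj_depth):
        modules.append(nn.GELU())
        modules.append(nn.Linear(lang_hidden_size, lang_hidden_size))
    return nn.Sequential(*modules)


projector = build_mm_projector(1024, 4096, mm_proj_depth=2)
```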
@ -80,16 +90,12 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
|
|||
return_dict
|
||||
if return_dict is not None else self.config.use_return_dict)
|
||||
|
||||
# decoder outputs consists of
|
||||
# (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)
|
||||
|
||||
inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
|
||||
images)
|
||||
(input_ids, attention_mask, past_key_values, inputs_embeds,
|
||||
labels) = self.forward_vision_tower(input_ids, attention_mask,
|
||||
past_key_values, labels, images)
|
||||
|
||||
return self.lang_encoder(
|
||||
input_ids=None,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -127,106 +133,93 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
|
|||
def forward_vision_tower(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
images: Union[torch.FloatTensor, list, None] = None,
|
||||
attention_mask: torch.LongTensor,
|
||||
past_key_values: torch.FloatTensor,
|
||||
labels: torch.LongTensor,
|
||||
images: Union[torch.FloatTensor, None] = None,
|
||||
):
|
||||
if self.use_im_start_end:
|
||||
assert self.im_start_token is not None
|
||||
assert self.im_end_token is not None
|
||||
if images is not None:
|
||||
assert self.im_patch_token is not None
|
||||
|
||||
if self.vision_tower is None or images is None or (
|
||||
input_ids.shape[1] == 1 and not self.training):
|
||||
return inputs_embeds
|
||||
if self.vision_tower is None or images is None or input_ids.shape[
|
||||
1] == 1:
|
||||
if (past_key_values is not None and self.vision_tower is not None
|
||||
and images is not None and input_ids.shape[1] == 1):
|
||||
attention_mask = torch.ones(
|
||||
(attention_mask.shape[0],
|
||||
past_key_values[-1][-1].shape[-2] + 1),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device)
|
||||
return input_ids, attention_mask, past_key_values, None, labels
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(images, (list, tuple)):
|
||||
# variable length images
|
||||
image_features = []
|
||||
for image in images:
|
||||
feats = self.vision_tower(image.unsqueeze(0))
|
||||
image_feature = feats[self.mm_vision_select_layer][:, 1:]
|
||||
image_features.append(image_feature)
|
||||
else:
|
||||
feats = self.vision_tower(images)
|
||||
image_features = feats[self.mm_vision_select_layer][:, 1:]
|
||||
# TODO: support variable number of images (single now)
|
||||
feats = self.vision_tower(images)
|
||||
image_features = feats[-1][:, 1:]
|
||||
|
||||
mm_projector = self.lang_encoder.model.mm_projector
|
||||
if isinstance(images, (list, tuple)):
|
||||
image_features = [
|
||||
mm_projector(image_feature)[0]
|
||||
for image_feature in image_features
|
||||
]
|
||||
else:
|
||||
image_features = mm_projector(image_features)
|
||||
|
||||
dummy_image_features = torch.zeros(
|
||||
256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
|
||||
dummy_image_features = mm_projector(dummy_image_features)
|
||||
image_features = self.lang_encoder.model.mm_projector(image_features)
|
||||
|
||||
new_input_embeds = []
|
||||
cur_image_idx = 0
|
||||
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
|
||||
if (cur_input_ids != self.im_patch_token).all():
|
||||
# multimodal LLM, but the current sample is not multimodal
|
||||
cur_input_embeds = cur_input_embeds + (
|
||||
0. * dummy_image_features).sum()
|
||||
new_input_embeds.append(cur_input_embeds)
|
||||
cur_image_idx += 1
|
||||
continue
|
||||
if self.use_im_start_end:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.im_start_token).sum() != (
|
||||
cur_input_ids == self.im_end_token).sum():
|
||||
raise ValueError('The number of image start tokens and '
|
||||
'image end tokens should be the same.')
|
||||
image_start_tokens = torch.where(
|
||||
cur_input_ids == self.im_start_token)[0]
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = image_features[cur_image_idx].to(
|
||||
device=cur_input_embeds.device)
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if cur_input_ids[image_start_token_pos + num_patches +
|
||||
1] != self.im_end_token:
|
||||
raise ValueError('The image end token should follow '
|
||||
'the image start token.')
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:image_start_token_pos + 1],
|
||||
cur_image_features,
|
||||
cur_input_embeds[image_start_token_pos + num_patches +
|
||||
1:]),
|
||||
dim=0)
|
||||
cur_image_idx += 1
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
else:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.im_patch_token).sum() != num_patches:
|
||||
print(f'Debug: num_patches: {num_patches}')
|
||||
raise ValueError(
|
||||
'The number of image patch tokens should '
|
||||
'be the same as the number of image patches.')
|
||||
masked_indices = torch.where(
|
||||
cur_input_ids == self.im_patch_token)[0]
|
||||
mask_index_start = masked_indices[0]
|
||||
if (masked_indices != torch.arange(
|
||||
mask_index_start,
|
||||
mask_index_start + num_patches,
|
||||
device=masked_indices.device,
|
||||
dtype=masked_indices.dtype)).any():
|
||||
raise ValueError(
|
||||
'The image patch tokens should be consecutive.')
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:mask_index_start], cur_image_features,
|
||||
cur_input_embeds[mask_index_start + num_patches:]),
|
||||
dim=0)
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
cur_image_idx += 1
|
||||
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
||||
new_labels = [] if labels is not None else None
|
||||
new_attn_mask = [] if attention_mask is not None else None
|
||||
for batch_idx, cur_input_ids in enumerate(input_ids):
|
||||
cur_img = image_features[batch_idx]
|
||||
|
||||
return inputs_embeds
|
||||
if (cur_input_ids != self.im_token_index).all():
|
||||
# multimodal LLM, but the current sample is not multimodal
|
||||
new_input_embeds.append(self.embed_tokens(cur_input_ids))
|
||||
if labels is not None:
|
||||
new_labels.append(labels[batch_idx])
|
||||
if attention_mask is not None:
|
||||
new_attn_mask.append(attention_mask[batch_idx])
|
||||
continue
|
||||
|
||||
img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0]
|
||||
if self.use_im_start_end:
|
||||
cur_new_input_embeds = torch.cat(
|
||||
[
|
||||
self.embed_tokens(cur_input_ids[:img_idx - 1]),
|
||||
self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]),
|
||||
cur_img,
|
||||
self.embed_tokens(
|
||||
cur_input_ids[img_idx + 1:img_idx + 2]),
|
||||
self.embed_tokens(cur_input_ids[img_idx + 2:]),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
cur_new_input_embeds = torch.cat(
|
||||
[
|
||||
self.embed_tokens(cur_input_ids[:img_idx]),
|
||||
cur_img,
|
||||
self.embed_tokens(cur_input_ids[img_idx + 1:]),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
|
||||
if labels is not None:
|
||||
cur_new_labels = torch.cat([
|
||||
labels[batch_idx, :img_idx],
|
||||
labels.new_full((cur_img.size(0), ), -100),
|
||||
labels[batch_idx, img_idx + 1:],
|
||||
],
|
||||
dim=0)
|
||||
new_labels.append(cur_new_labels)
|
||||
|
||||
if attention_mask is not None:
|
||||
cur_attn_mask = torch.cat([
|
||||
attention_mask[batch_idx, :img_idx],
|
||||
attention_mask.new_full((cur_img.size(0), ), True),
|
||||
attention_mask[batch_idx, img_idx + 1:],
|
||||
],
|
||||
dim=0)
|
||||
new_attn_mask.append(cur_attn_mask)
|
||||
|
||||
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
||||
if labels is not None:
|
||||
labels = torch.stack(new_labels, dim=0)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.stack(new_attn_mask, dim=0)
|
||||
|
||||
return None, attention_mask, past_key_values, inputs_embeds, labels
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
|
@ -236,3 +229,6 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
|
|||
past_state.index_select(0, beam_idx)
|
||||
for past_state in layer_past), )
|
||||
return reordered_past
|
||||
|
||||
def embed_tokens(self, input_ids):
|
||||
return self.lang_encoder.model.embed_tokens(input_ids)
|
||||
|
|
|
@ -31,12 +31,12 @@ class MiniGPT4(BaseModel):
|
|||
True.
|
||||
num_query_token (int): Number of query tokens of Qformer. Defaults to
|
||||
32.
|
||||
prompt_template (str): Prompt template of the model. Defaults to
|
||||
'###Human: {} ###Assistant: '.
|
||||
raw_prompts (list): Prompts for training. Defaults to None.
|
||||
prompt_template (dict): Multi-language prompt template of the model.
    Defaults to dict([('en', '###Ask: {} ###Answer: '),
                      ('zh', '###问:{} ###答:')]).
|
||||
raw_prompts (dict): Prompts for training. Defaults to dict().
|
||||
max_txt_len (int): Max token length while doing tokenization. Defaults
|
||||
to 32.
|
||||
end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
|
||||
end_sym (str): Ended symbol of the sequence. Defaults to '###'.
|
||||
generation_cfg (dict): The config of text generation. Defaults to
|
||||
dict().
|
||||
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
|
||||
|
@ -54,10 +54,12 @@ class MiniGPT4(BaseModel):
|
|||
freeze_vit: bool = True,
|
||||
freeze_q_former: bool = True,
|
||||
num_query_token: int = 32,
|
||||
prompt_template: str = '###Human: {} ###Assistant: ',
|
||||
raw_prompts: Optional[list] = None,
|
||||
prompt_template: dict = dict([('en',
|
||||
'###Ask: {} ###Answer: '),
|
||||
('zh', '###问:{} ###答:')]),
|
||||
raw_prompts: dict = dict(),
|
||||
max_txt_len: int = 32,
|
||||
end_sym: str = '\n',
|
||||
end_sym: str = '###',
|
||||
generation_cfg: dict = dict(),
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
|
@ -135,16 +137,23 @@ class MiniGPT4(BaseModel):
|
|||
self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]
|
||||
|
||||
# set prompts
|
||||
if raw_prompts is not None:
|
||||
filted_prompts = [
|
||||
raw_prompt for raw_prompt in raw_prompts
|
||||
self.en_prompt_list, self.zh_prompt_list = [], []
|
||||
if raw_prompts.get('en') is not None:
|
||||
en_filted_prompts = [
|
||||
raw_prompt for raw_prompt in raw_prompts['en']
|
||||
if '<ImageHere>' in raw_prompt
|
||||
]
|
||||
self.prompt_list = [
|
||||
prompt_template.format(p) for p in filted_prompts
|
||||
self.en_prompt_list = [
|
||||
prompt_template['en'].format(p) for p in en_filted_prompts
|
||||
]
|
||||
if raw_prompts.get('zh') is not None:
|
||||
zh_filted_prompts = [
|
||||
raw_prompt for raw_prompt in raw_prompts['zh']
|
||||
if '<ImageHere>' in raw_prompt
|
||||
]
|
||||
self.zh_prompt_list = [
|
||||
prompt_template['zh'].format(p) for p in zh_filted_prompts
|
||||
]
|
||||
else:
|
||||
self.prompt_list = []
|
||||
|
||||
# update generation configs
|
||||
self.generation_cfg = dict(
|
||||
|
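With the multi-language handling above, `raw_prompts` and `prompt_template` are keyed by language. A hedged example of the expected shapes (the prompt strings are illustrative, not shipped defaults):

```python
# Illustrative values: only prompts containing '<ImageHere>' are kept, then
# formatted into the language-specific template, as in the code above.
raw_prompts = dict(
    en=['<Img><ImageHere></Img> Describe this image in detail.'],
    zh=['<Img><ImageHere></Img> 请详细描述这张图片。'],
)
prompt_template = dict(
    en='###Ask: {} ###Answer: ',
    zh='###问:{} ###答:',
)
en_prompt_list = [
    prompt_template['en'].format(p) for p in raw_prompts['en']
    if '<ImageHere>' in p
]
```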
@ -153,7 +162,7 @@ class MiniGPT4(BaseModel):
|
|||
do_sample=True,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
repetition_penalty=1.1,
|
||||
length_penalty=1.0,
|
||||
temperature=1.0)
|
||||
self.generation_cfg.update(**generation_cfg)
|
||||
|
@ -161,6 +170,10 @@ class MiniGPT4(BaseModel):
|
|||
if hasattr(self, 'register_load_state_dict_post_hook'):
|
||||
self.register_load_state_dict_post_hook(self._load_llama_proj_hook)
|
||||
|
||||
def half(self):
|
||||
self.llama_model = self.llama_model.half()
|
||||
return self
|
||||
|
||||
def encode_img(self,
|
||||
images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The function to encode the images."""
|
||||
|
@ -184,33 +197,39 @@ class MiniGPT4(BaseModel):
|
|||
return inputs_llama, atts_llama
|
||||
|
||||
def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
|
||||
prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The function to wrap the image and prompt.
|
||||
|
||||
Currently, the function only supports applying one prompt to all input
|
||||
images in the one batch.
|
||||
Make sure that len(prompt) == img_embeds.shape[0].
|
||||
|
||||
Args:
|
||||
img_embeds (torch.Tensor): The embedding of the input images.
|
||||
atts_img (torch.Tensor): Attention map of the image embeddings.
|
||||
prompt (str): The prompt of the batch data.
|
||||
prompt (List[str]): The prompt of the batch data.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
|
||||
"""
|
||||
if prompt:
|
||||
batch_size = img_embeds.shape[0]
|
||||
p_before, p_after = prompt.split('<ImageHere>')
|
||||
if len(prompt) > 0:
|
||||
p_before_list, p_after_list = [], []
|
||||
for pro in prompt:
|
||||
p_before, p_after = pro.split('<ImageHere>')
|
||||
p_before_list.append(p_before)
|
||||
p_after_list.append(p_after)
|
||||
p_before_tokens = self.llama_tokenizer(
|
||||
p_before, return_tensors='pt',
|
||||
p_before_list,
|
||||
return_tensors='pt',
|
||||
padding='longest',
|
||||
add_special_tokens=False).to(img_embeds.device)
|
||||
p_after_tokens = self.llama_tokenizer(
|
||||
p_after, return_tensors='pt',
|
||||
p_after_list,
|
||||
return_tensors='pt',
|
||||
padding='longest',
|
||||
add_special_tokens=False).to(img_embeds.device)
|
||||
p_before_embeds = self.llama_model.model.embed_tokens(
|
||||
p_before_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
p_before_tokens.input_ids)
|
||||
p_after_embeds = self.llama_model.model.embed_tokens(
|
||||
p_after_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
p_after_tokens.input_ids)
|
||||
wrapped_img_embeds = torch.cat(
|
||||
[p_before_embeds, img_embeds, p_after_embeds], dim=1)
|
||||
wrapped_atts_img = atts_img[:, :1].expand(
|
||||
|
@ -234,17 +253,22 @@ class MiniGPT4(BaseModel):
|
|||
"""
|
||||
img_embeds, atts_img = self.encode_img(images)
|
||||
|
||||
if self.task == 'caption' and self.prompt_list:
|
||||
prompt = random.choice(self.prompt_list)
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
|
||||
prompt)
|
||||
|
||||
self.llama_tokenizer.padding_side = 'right'
|
||||
|
||||
text = [t + self.end_sym for t in data_samples['text_input']]
|
||||
prompts, texts = [], []
|
||||
for t in data_samples:
|
||||
chat_content = t.chat_content
|
||||
split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
|
||||
prompt, text = chat_content.split(split_mark)
|
||||
prompt += split_mark
|
||||
text += self.end_sym
|
||||
prompts.append(prompt)
|
||||
texts.append(text)
|
||||
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
|
||||
|
||||
to_regress_tokens = self.llama_tokenizer(
|
||||
text,
|
||||
texts,
|
||||
return_tensors='pt',
|
||||
padding='longest',
|
||||
truncation=True,
|
||||
|
@ -295,10 +319,12 @@ class MiniGPT4(BaseModel):
|
|||
with torch.no_grad():
|
||||
img_embeds, atts_img = self.encode_img(images)
|
||||
|
||||
if self.task == 'caption' and self.prompt_list:
|
||||
prompt = random.choice(self.prompt_list)
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
|
||||
prompt)
|
||||
prompts = [
|
||||
random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
|
||||
and t.lang == 'zh' else random.choice(self.en_prompt_list)
|
||||
for t in data_samples
|
||||
]
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
|
||||
|
||||
batch_size = img_embeds.shape[0]
|
||||
bos = torch.ones(
|
||||
|
@ -336,7 +362,6 @@ class MiniGPT4(BaseModel):
|
|||
for output, data_sample in zip(outputs, data_samples):
|
||||
if self.task == 'caption':
|
||||
output = output.split('###')[0]
|
||||
output = output.split('Assistant:')[-1].strip()
|
||||
data_sample.pred_caption = output
|
||||
else:
|
||||
# raw output
|
||||
|
|
|
@ -1301,6 +1301,7 @@ class OFAEncoderDecoder(BaseModule, GenerationMixin):
            Defaults to an empty dict.
        init_cfg (dict, optional): The initialization config. Defaults to None.
    """
    base_model_prefix = ''

    def __init__(
        self,
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ram import RAM, RAMNormal, RAMOpenset

__all__ = ['RAM', 'RAMNormal', 'RAMOpenset']
File diff suppressed because it is too large
|
@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
@ -0,0 +1,93 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# data settings
|
||||
test_transforms_cfg = [
|
||||
dict(type='Resize', scale=(384, 384), interpolation='bicubic'),
|
||||
dict(
|
||||
type='mmpretrain.PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_ram_cfg(mode='normal'):
|
||||
assert mode in ['normal', 'openset'], 'mode must be "normal" or "openset"'
|
||||
model_type = 'RAMNormal' if mode == 'normal' else 'RAMOpenset'
|
||||
model_cfg = dict(
|
||||
type=model_type,
|
||||
tokenizer=dict(
|
||||
type='BertTokenizer',
|
||||
name_or_path='/public/DATA/qbw/ckpt/bert-base-uncased',
|
||||
use_fast=False),
|
||||
vision_backbone=dict(
|
||||
type='SwinTransformer',
|
||||
arch='large',
|
||||
img_size=384,
|
||||
window_size=12,
|
||||
),
|
||||
tag_encoder={
|
||||
'architectures': ['BertModel'],
|
||||
'attention_probs_dropout_prob': 0.1,
|
||||
'hidden_act': 'gelu',
|
||||
'hidden_dropout_prob': 0.1,
|
||||
'hidden_size': 768,
|
||||
'initializer_range': 0.02,
|
||||
'intermediate_size': 3072,
|
||||
'layer_norm_eps': 1e-12,
|
||||
'max_position_embeddings': 512,
|
||||
'model_type': 'bert',
|
||||
'num_attention_heads': 12,
|
||||
'num_hidden_layers': 12,
|
||||
'pad_token_id': 0,
|
||||
'type_vocab_size': 2,
|
||||
'vocab_size': 30524,
|
||||
'encoder_width': 512,
|
||||
'add_cross_attention': True
|
||||
},
|
||||
text_decoder={
|
||||
'architectures': ['BertModel'],
|
||||
'attention_probs_dropout_prob': 0.1,
|
||||
'hidden_act': 'gelu',
|
||||
'hidden_dropout_prob': 0.1,
|
||||
'hidden_size': 768,
|
||||
'initializer_range': 0.02,
|
||||
'intermediate_size': 3072,
|
||||
'layer_norm_eps': 1e-12,
|
||||
'max_position_embeddings': 512,
|
||||
'model_type': 'bert',
|
||||
'num_attention_heads': 12,
|
||||
'num_hidden_layers': 12,
|
||||
'pad_token_id': 0,
|
||||
'type_vocab_size': 2,
|
||||
'vocab_size': 30524,
|
||||
'encoder_width': 768,
|
||||
'add_cross_attention': True
|
||||
},
|
||||
tagging_head={
|
||||
'architectures': ['BertModel'],
|
||||
'attention_probs_dropout_prob': 0.1,
|
||||
'hidden_act': 'gelu',
|
||||
'hidden_dropout_prob': 0.1,
|
||||
'hidden_size': 768,
|
||||
'initializer_range': 0.02,
|
||||
'intermediate_size': 3072,
|
||||
'layer_norm_eps': 1e-12,
|
||||
'max_position_embeddings': 512,
|
||||
'model_type': 'bert',
|
||||
'num_attention_heads': 4,
|
||||
'num_hidden_layers': 2,
|
||||
'pad_token_id': 0,
|
||||
'type_vocab_size': 2,
|
||||
'vocab_size': 30522,
|
||||
'encoder_width': 512,
|
||||
'add_cross_attention': True,
|
||||
'add_tag_cross_attention': False
|
||||
},
|
||||
data_preprocessor=dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
),
|
||||
)
|
||||
return model_cfg
|
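A hedged usage sketch for the helpers above: build the RAM model and its test-time transforms from the registries. Checkpoint loading is omitted here, and the tokenizer path inside `get_ram_cfg` would need to point at an accessible copy of `bert-base-uncased`.

```python
from mmpretrain.registry import MODELS, TRANSFORMS

# build the test-time pipeline and the model from the dicts defined above
test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg)
model = MODELS.build(get_ram_cfg(mode='normal'))
```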
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
from mmpretrain.registry import MODELS, TRANSFORMS
|
||||
from .config.ram_swin_large_14m import get_ram_cfg, test_transforms_cfg
|
||||
from .run.inference import inference
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='RAM (Recognize Anything Model) demo')
|
||||
parser.add_argument(
|
||||
'ram_ckpt', type=str, help='pretrained file for ram (absolute path)')
|
||||
parser.add_argument(
|
||||
'clip_ckpt',
|
||||
type=str,
|
||||
help='clip vit-base-p16 pretrained file (absolute path)')
|
||||
args = parser.parse_args()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
devices = [
|
||||
torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
|
||||
]
|
||||
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
devices = [torch.device('mps')]
|
||||
else:
|
||||
devices = [torch.device('cpu')]
|
||||
|
||||
|
||||
def get_free_device():
|
||||
if hasattr(torch.cuda, 'mem_get_info'):
|
||||
free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
|
||||
select = max(zip(free, range(len(free))))[1]
|
||||
else:
|
||||
import random
|
||||
select = random.randint(0, len(devices) - 1)
|
||||
return devices[select]
|
||||
|
||||
|
||||
device = get_free_device()
|
||||
|
||||
|
||||
def ram_inference(image, tag_list, mode, threshold):
|
||||
test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg)
|
||||
model = MODELS.build(get_ram_cfg(mode=mode))
|
||||
model.load_state_dict(torch.load(args.ram_ckpt))
|
||||
model.device = device
|
||||
|
||||
if mode == 'openset':
|
||||
categories = tag_list
|
||||
if categories != '':
|
||||
categories = categories.strip().split()
|
||||
else:
|
||||
categories = None
|
||||
model.set_openset(
|
||||
categories=categories,
|
||||
clip_ckpt=args.clip_ckpt,
|
||||
threshold=threshold)
|
||||
|
||||
sample = dict(img=image)
|
||||
result = inference(sample, model, test_transforms, mode=mode)
|
||||
tag, tag_chinese, logits = \
|
||||
result.get('tag_output')[0][0], result.get('tag_output')[1][0],\
|
||||
result.get('logits_output')[0]
|
||||
|
||||
def wrap(tags, logits):
|
||||
if tags is None:
|
||||
return 'Openset mode has no tag_en'
|
||||
tag_lst = tags.split('|')
|
||||
rt_lst = []
|
||||
for i, tag in enumerate(tag_lst):
|
||||
tag = tag.strip()
|
||||
rt_lst.append(tag + f': {logits[i]:.2f}')
|
||||
return ' | '.join(rt_lst)
|
||||
|
||||
return [wrap(tag, logits), wrap(tag_chinese, logits)]
|
||||
|
||||
|
||||
def build_gradio():
|
||||
inputs = [
|
||||
gr.components.Image(label='image'),
|
||||
gr.components.Textbox(
|
||||
lines=2,
|
||||
label='tag_list',
|
||||
placeholder=
|
||||
'please input the categories, separated by spaces: ',
|
||||
value=''),
|
||||
gr.components.Radio(['normal', 'openset'],
|
||||
label='mode',
|
||||
value='normal'),
|
||||
gr.components.Slider(
|
||||
minimum=0, maximum=1, value=0.68, step=0.01, label='threshold')
|
||||
]
|
||||
return gr.Interface(
|
||||
fn=ram_inference,
|
||||
inputs=inputs,
|
||||
outputs=[
|
||||
gr.components.Textbox(),
|
||||
gr.components.Textbox(info="it's translated from the English tags")
|
||||
])
|
||||
|
||||
|
||||
def main():
|
||||
build_gradio().launch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,212 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
def article(name):
|
||||
return 'an' if name[0] in 'aeiou' else 'a'
|
||||
|
||||
|
||||
def processed_name(name, rm_dot=False):
|
||||
# _ for lvis
|
||||
# / for obj365
|
||||
res = name.replace('_', ' ').replace('/', ' or ').lower()
|
||||
if rm_dot:
|
||||
res = res.rstrip('.')
|
||||
return res
|
||||
|
||||
|
||||
single_template = ['a photo of a {}.']
|
||||
|
||||
multiple_templates = [
|
||||
'There is {article} {} in the scene.',
|
||||
'There is the {} in the scene.',
|
||||
'a photo of {article} {} in the scene.',
|
||||
'a photo of the {} in the scene.',
|
||||
'a photo of one {} in the scene.',
|
||||
'itap of {article} {}.',
|
||||
'itap of my {}.', # itap: I took a picture of
|
||||
'itap of the {}.',
|
||||
'a photo of {article} {}.',
|
||||
'a photo of my {}.',
|
||||
'a photo of the {}.',
|
||||
'a photo of one {}.',
|
||||
'a photo of many {}.',
|
||||
'a good photo of {article} {}.',
|
||||
'a good photo of the {}.',
|
||||
'a bad photo of {article} {}.',
|
||||
'a bad photo of the {}.',
|
||||
'a photo of a nice {}.',
|
||||
'a photo of the nice {}.',
|
||||
'a photo of a cool {}.',
|
||||
'a photo of the cool {}.',
|
||||
'a photo of a weird {}.',
|
||||
'a photo of the weird {}.',
|
||||
'a photo of a small {}.',
|
||||
'a photo of the small {}.',
|
||||
'a photo of a large {}.',
|
||||
'a photo of the large {}.',
|
||||
'a photo of a clean {}.',
|
||||
'a photo of the clean {}.',
|
||||
'a photo of a dirty {}.',
|
||||
'a photo of the dirty {}.',
|
||||
'a bright photo of {article} {}.',
|
||||
'a bright photo of the {}.',
|
||||
'a dark photo of {article} {}.',
|
||||
'a dark photo of the {}.',
|
||||
'a photo of a hard to see {}.',
|
||||
'a photo of the hard to see {}.',
|
||||
'a low resolution photo of {article} {}.',
|
||||
'a low resolution photo of the {}.',
|
||||
'a cropped photo of {article} {}.',
|
||||
'a cropped photo of the {}.',
|
||||
'a close-up photo of {article} {}.',
|
||||
'a close-up photo of the {}.',
|
||||
'a jpeg corrupted photo of {article} {}.',
|
||||
'a jpeg corrupted photo of the {}.',
|
||||
'a blurry photo of {article} {}.',
|
||||
'a blurry photo of the {}.',
|
||||
'a pixelated photo of {article} {}.',
|
||||
'a pixelated photo of the {}.',
|
||||
'a black and white photo of the {}.',
|
||||
'a black and white photo of {article} {}.',
|
||||
'a plastic {}.',
|
||||
'the plastic {}.',
|
||||
'a toy {}.',
|
||||
'the toy {}.',
|
||||
'a plushie {}.',
|
||||
'the plushie {}.',
|
||||
'a cartoon {}.',
|
||||
'the cartoon {}.',
|
||||
'an embroidered {}.',
|
||||
'the embroidered {}.',
|
||||
'a painting of the {}.',
|
||||
'a painting of a {}.',
|
||||
]
|
||||
|
||||
openimages_rare_unseen = [
|
||||
'Aerial photography', 'Aircraft engine', 'Ale', 'Aloe', 'Amphibian',
|
||||
'Angling', 'Anole', 'Antique car', 'Arcade game', 'Arthropod',
|
||||
'Assault rifle', 'Athletic shoe', 'Auto racing', 'Backlighting',
|
||||
'Bagpipes', 'Ball game', 'Barbecue chicken', 'Barechested', 'Barquentine',
|
||||
'Beef tenderloin', 'Billiard room', 'Billiards', 'Bird of prey',
|
||||
'Black swan', 'Black-and-white', 'Blond', 'Boating', 'Bonbon',
|
||||
'Bottled water', 'Bouldering', 'Bovine', 'Bratwurst', 'Breadboard',
|
||||
'Briefs', 'Brisket', 'Brochette', 'Calabaza', 'Camera operator', 'Canola',
|
||||
'Childbirth', 'Chordophone', 'Church bell', 'Classical sculpture',
|
||||
'Close-up', 'Cobblestone', 'Coca-cola', 'Combat sport', 'Comics',
|
||||
'Compact car', 'Computer speaker', 'Cookies and crackers',
|
||||
'Coral reef fish', 'Corn on the cob', 'Cosmetics', 'Crocodilia',
|
||||
'Digital camera', 'Dishware', 'Divemaster', 'Dobermann', 'Dog walking',
|
||||
'Domestic rabbit', 'Domestic short-haired cat', 'Double-decker bus',
|
||||
'Drums', 'Electric guitar', 'Electric piano', 'Electronic instrument',
|
||||
'Equestrianism', 'Equitation', 'Erinaceidae', 'Extreme sport', 'Falafel',
|
||||
'Figure skating', 'Filling station', 'Fire apparatus', 'Firearm',
|
||||
'Flatbread', 'Floristry', 'Forklift truck', 'Freight transport',
|
||||
'Fried food', 'Fried noodles', 'Frigate', 'Frozen yogurt', 'Frying',
|
||||
'Full moon', 'Galleon', 'Glacial landform', 'Gliding', 'Go-kart', 'Goats',
|
||||
'Grappling', 'Great white shark', 'Gumbo', 'Gun turret', 'Hair coloring',
|
||||
'Halter', 'Headphones', 'Heavy cruiser', 'Herding', 'High-speed rail',
|
||||
'Holding hands', 'Horse and buggy', 'Horse racing', 'Hound',
|
||||
'Hunting knife', 'Hurdling', 'Inflatable', 'Jackfruit', 'Jeans', 'Jiaozi',
|
||||
'Junk food', 'Khinkali', 'Kitesurfing', 'Lawn game', 'Leaf vegetable',
|
||||
'Lechon', 'Lifebuoy', 'Locust', 'Lumpia', 'Luxury vehicle', 'Machine tool',
|
||||
'Medical imaging', 'Melee weapon', 'Microcontroller', 'Middle ages',
|
||||
'Military person', 'Military vehicle', 'Milky way', 'Miniature Poodle',
|
||||
'Modern dance', 'Molluscs', 'Monoplane', 'Motorcycling', 'Musical theatre',
|
||||
'Narcissus', 'Nest box', 'Newsagent\'s shop', 'Nile crocodile',
|
||||
'Nordic skiing', 'Nuclear power plant', 'Orator', 'Outdoor shoe',
|
||||
'Parachuting', 'Pasta salad', 'Peafowl', 'Pelmeni', 'Perching bird',
|
||||
'Performance car', 'Personal water craft', 'Pit bull', 'Plant stem',
|
||||
'Pork chop', 'Portrait photography', 'Primate', 'Procyonidae',
|
||||
'Prosciutto', 'Public speaking', 'Racewalking', 'Ramen',
|
||||
'Rear-view mirror', 'Residential area', 'Ribs', 'Rice ball',
|
||||
'Road cycling', 'Roller skating', 'Roman temple', 'Rowing', 'Rural area',
|
||||
'Sailboat racing', 'Scaled reptile', 'Scuba diving', 'Senior citizen',
|
||||
'Shallot', 'Shinto shrine', 'Shooting range', 'Siberian husky', 'Sledding',
|
||||
'Soba', 'Solar energy', 'Sport climbing', 'Sport utility vehicle',
|
||||
'Steamed rice', 'Stemware', 'Sumo', 'Surfing Equipment', 'Team sport',
|
||||
'Touring car', 'Toy block', 'Trampolining', 'Underwater diving',
|
||||
'Vegetarian food', 'Wallaby', 'Water polo', 'Watercolor paint', 'Whiskers',
|
||||
'Wind wave', 'Woodwind instrument', 'Yakitori', 'Zeppelin'
|
||||
]
|
||||
|
||||
|
||||
def get_clip_model():
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(
|
||||
type='CLIPProjection', in_channels=768, out_channels=512),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=512,
|
||||
layers=12,
|
||||
heads=8,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-base-patch16',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=512,
|
||||
proj_dim=512,
|
||||
context_length=77,
|
||||
data_preprocessor=dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
),
|
||||
)
|
||||
return MODELS.build(model)
|
||||
|
||||
|
||||
def build_openset_label_embedding(categories=None, clip_ckpt_path=''):
|
||||
if categories is None:
|
||||
print('Categories is None, so using rare_unseen categories')
|
||||
categories = openimages_rare_unseen
|
||||
model = get_clip_model()
|
||||
model.load_state_dict(torch.load(clip_ckpt_path))
|
||||
templates = multiple_templates
|
||||
|
||||
run_on_gpu = torch.cuda.is_available()
|
||||
|
||||
with torch.no_grad():
|
||||
openset_label_embedding = []
|
||||
for category in categories:
|
||||
texts = [
|
||||
template.format(
|
||||
processed_name(category, rm_dot=True),
|
||||
article=article(category)) for template in templates
|
||||
]
|
||||
texts = [
|
||||
'This is ' + text
|
||||
if text.startswith('a') or text.startswith('the') else text
|
||||
for text in texts
|
||||
]
|
||||
texts = model.tokenize(texts) # tokenize
|
||||
if run_on_gpu:
|
||||
texts = texts.cuda()
|
||||
model = model.cuda()
|
||||
text_embeddings = model.extract_text_feat(texts)
|
||||
text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
|
||||
text_embedding = text_embeddings.mean(dim=0)
|
||||
text_embedding /= text_embedding.norm()
|
||||
openset_label_embedding.append(text_embedding)
|
||||
openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
|
||||
if run_on_gpu:
|
||||
openset_label_embedding = openset_label_embedding.cuda()
|
||||
|
||||
openset_label_embedding = openset_label_embedding.t()
|
||||
return openset_label_embedding, categories
|
|
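A hedged example of calling the helper above; the category list and the CLIP checkpoint path are placeholders.

```python
# Returns one ensembled CLIP text embedding per category (N x 512 here,
# matching the 512-d CLIPProjection configured in get_clip_model()).
embedding, categories = build_openset_label_embedding(
    categories=['street art', 'espresso machine'],
    clip_ckpt_path='checkpoints/clip-vit-base-p16.pth')
print(embedding.shape)  # torch.Size([2, 512])
```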
@ -0,0 +1,332 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import pickle
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
from mmpretrain.registry import MODELS, TOKENIZER
|
||||
from mmpretrain.structures import DataSample
|
||||
from .bert import BertConfig, BertLMHeadModel, BertModel
|
||||
from .openset_utils import build_openset_label_embedding
|
||||
from .utils import tie_encoder_decoder_weights
|
||||
|
||||
|
||||
def get_path(path):
|
||||
file_path = os.path.abspath(os.path.dirname(__file__))
|
||||
if not os.path.isabs(path):
|
||||
return os.path.join(file_path, path)
|
||||
|
||||
|
||||
class RAM(BaseModel):
|
||||
"""The implementation of `RAM <https://arxiv.org/abs/2306.03514>`_."""
|
||||
|
||||
def __init__(self,
|
||||
tokenizer: dict,
|
||||
vision_backbone: dict,
|
||||
tag_encoder: dict,
|
||||
tagging_head: dict,
|
||||
text_decoder: dict,
|
||||
device: str = 'cpu',
|
||||
vision_width: int = 1536,
|
||||
prompt='a picture of ',
|
||||
threshold=0.68,
|
||||
delete_tag_index=[],
|
||||
tag_list='./data/ram_tag_list.pickle',
|
||||
tag_list_chinese='./data/ram_tag_list_chinese.pickle',
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
|
||||
self.device = device
|
||||
# build the visual encoder
|
||||
self.visual_encoder = MODELS.build(vision_backbone)
|
||||
|
||||
# build the tokenizer
|
||||
self.tokenizer = TOKENIZER.build(tokenizer)
|
||||
self.tokenizer.add_special_tokens({'bos_token': '[DEC]'})
|
||||
self.tokenizer.add_special_tokens(
|
||||
{'additional_special_tokens': ['[ENC]']})
|
||||
self.tokenizer.enc_token_id = \
|
||||
self.tokenizer.additional_special_tokens_ids[0]
|
||||
|
||||
# build the tag encoder
|
||||
# encoder_config = BertConfig.from_json_file(med_config)
|
||||
# encoder_config.encoder_width = 512
|
||||
encoder_config = BertConfig.from_dict(tag_encoder)
|
||||
self.tag_encoder = BertModel(
|
||||
config=encoder_config, add_pooling_layer=False)
|
||||
|
||||
# build image-tag-text decoder
|
||||
# decoder_config = BertConfig.from_json_file(med_config)
|
||||
decoder_config = BertConfig.from_dict(text_decoder)
|
||||
self.text_decoder = BertLMHeadModel(config=decoder_config)
|
||||
|
||||
self.delete_tag_index = delete_tag_index
|
||||
self.prompt = prompt
|
||||
self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
|
||||
|
||||
# load tag list
|
||||
self.tag_list = self.load_tag_list(get_path(tag_list))
|
||||
self.tag_list_chinese = self.load_tag_list(get_path(tag_list_chinese))
|
||||
|
||||
# create image-tag recognition decoder
|
||||
self.threshold = threshold
|
||||
self.num_class = len(self.tag_list)
|
||||
# q2l_config = \
|
||||
# BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
|
||||
# q2l_config.encoder_width = 512
|
||||
q2l_config = BertConfig.from_dict(tagging_head)
|
||||
self.tagging_head = BertModel(
|
||||
config=q2l_config, add_pooling_layer=False)
|
||||
self.tagging_head.resize_token_embeddings(len(self.tokenizer))
|
||||
self.label_embed = nn.Parameter(
|
||||
torch.zeros(self.num_class, q2l_config.encoder_width))
|
||||
|
||||
if q2l_config.hidden_size != 512:
|
||||
self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
|
||||
else:
|
||||
self.wordvec_proj = nn.Identity()
|
||||
|
||||
self.fc = nn.Linear(q2l_config.hidden_size, 1)
|
||||
|
||||
self.del_selfattention()
|
||||
|
||||
# share weights of the lowest 2-layer of
|
||||
# "image-tag interaction encoder" with
|
||||
# the "image-tag recogntion decoder"
|
||||
tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
|
||||
' ')
|
||||
self.image_proj = nn.Linear(vision_width, 512)
|
||||
# self.label_embed = nn.Parameter(torch.load(
|
||||
# f'{CONFIG_PATH}/data/textual_label_embedding.pth',
|
||||
# map_location='cpu').float())
|
||||
|
||||
# adjust thresholds for some tags
|
||||
self.class_threshold = torch.ones(self.num_class) * self.threshold
|
||||
ram_class_threshold_path = get_path(
|
||||
'./data/ram_tag_list_threshold.pickle')
|
||||
with open(ram_class_threshold_path, 'rb') as f:
|
||||
ram_class_threshold = pickle.load(f)
|
||||
for key, value in enumerate(ram_class_threshold):
|
||||
self.class_threshold[key] = value
|
||||
|
||||
def load_tag_list(self, tag_list_file):
|
||||
with open(tag_list_file, 'rb') as f:
|
||||
tag_list = pickle.load(f)
|
||||
tag_list = np.array(tag_list)
|
||||
return tag_list
|
||||
|
||||
# delete self-attention layer of image-tag recognition decoder
|
||||
# to reduce computation, following Query2Label
|
||||
def del_selfattention(self):
|
||||
del self.tagging_head.embeddings
|
||||
for layer in self.tagging_head.encoder.layer:
|
||||
del layer.attention
|
||||
|
||||
def get_label_embed(self):
|
||||
return torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
|
||||
|
||||
def extract_visual_feature(self, images):
|
||||
image_embeds = self.visual_encoder(images)[0]
|
||||
image_embeds = image_embeds.flatten(2, 3)
|
||||
attn_pool = nn.AdaptiveAvgPool1d(1)
|
||||
cls_token = attn_pool(image_embeds).permute(0, 2, 1).contiguous()
|
||||
image_embeds = image_embeds.permute(0, 2, 1).contiguous()
|
||||
image_embeds = torch.cat([cls_token, image_embeds], dim=1)
|
||||
image_embeds = self.image_proj(image_embeds)
|
||||
image_atts = torch.ones(
|
||||
image_embeds.size()[:-1], dtype=torch.long).to(images.device)
|
||||
return image_embeds, image_atts
|
||||
|
||||
def image2tag(self, label_embed, image_embeds, image_atts):
|
||||
# recognize image tags using the image-tag recognition decoder
|
||||
# image_cls_embeds = image_embeds[:, 0, :]
|
||||
image_spatial_embeds = image_embeds[:, 1:, :]
|
||||
|
||||
bs = image_spatial_embeds.shape[0]
|
||||
label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
|
||||
tagging_embed = self.tagging_head(
|
||||
encoder_embeds=label_embed,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=False,
|
||||
mode='tagging',
|
||||
)
|
||||
|
||||
logits = self.fc(tagging_embed[0]).squeeze(-1)
|
||||
return logits
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
data_samples: Optional[list] = None,
|
||||
mode: str = 'predict',
|
||||
**kwargs,
|
||||
):
|
||||
if mode == 'predict':
|
||||
return self.predict(images, data_samples, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
@abstractmethod
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RAMNormal(RAM):
|
||||
|
||||
def __init__(self,
|
||||
tokenizer: dict,
|
||||
vision_backbone: dict,
|
||||
tag_encoder: dict,
|
||||
tagging_head: dict,
|
||||
text_decoder: dict,
|
||||
device: str = 'cpu',
|
||||
vision_width: int = 1536,
|
||||
prompt='a picture of ',
|
||||
threshold=0.68,
|
||||
delete_tag_index=[],
|
||||
tag_list='./data/ram_tag_list.pickle',
|
||||
tag_list_chinese='./data/ram_tag_list_chinese.pickle',
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
super().__init__(
|
||||
tokenizer,
|
||||
vision_backbone,
|
||||
tag_encoder,
|
||||
tagging_head,
|
||||
text_decoder,
|
||||
device,
|
||||
vision_width,
|
||||
prompt,
|
||||
threshold,
|
||||
delete_tag_index,
|
||||
tag_list,
|
||||
tag_list_chinese,
|
||||
data_preprocessor,
|
||||
init_cfg,
|
||||
)
|
||||
|
||||
def tag_process(self, logits):
|
||||
targets = torch.where(
|
||||
torch.sigmoid(logits) > self.class_threshold.to(logits.device),
|
||||
torch.tensor(1.0).to(logits.device),
|
||||
torch.zeros(self.num_class).to(logits.device))
|
||||
|
||||
tag = targets.cpu().numpy()
|
||||
tag[:, self.delete_tag_index] = 0
|
||||
tag_output = []
|
||||
tag_output_chinese = []
|
||||
logits_output = []
|
||||
|
||||
bs = logits.shape[0]
|
||||
for b in range(bs):
|
||||
index = np.argwhere(tag[b] == 1)
|
||||
token = self.tag_list[index].squeeze(axis=1)
|
||||
logits_output.append(
|
||||
torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy())
|
||||
tag_output.append(' | '.join(token))
|
||||
token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
|
||||
tag_output_chinese.append(' | '.join(token_chinese))
|
||||
|
||||
return [(tag_output, tag_output_chinese), logits_output]
|
||||
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
self.eval()
|
||||
self.to(self.device)
|
||||
images = images.to(self.device)
|
||||
label_embed = self.get_label_embed()
|
||||
image_embeds, image_atts = self.extract_visual_feature(images)
|
||||
logits = self.image2tag(label_embed, image_embeds, image_atts)
|
||||
tag_output, logits_output = self.tag_process(logits)
|
||||
data_samples.set_field(logits_output, 'logits_output')
|
||||
data_samples.set_field(tag_output, 'tag_output')
|
||||
return data_samples
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RAMOpenset(RAMNormal):
|
||||
|
||||
def __init__(self,
|
||||
tokenizer: dict,
|
||||
vision_backbone: dict,
|
||||
tag_encoder: dict,
|
||||
tagging_head: dict,
|
||||
text_decoder: dict,
|
||||
device: str = 'cpu',
|
||||
vision_width: int = 1536,
|
||||
prompt='a picture of ',
|
||||
threshold=0.68,
|
||||
delete_tag_index=[],
|
||||
tag_list='./data/ram_tag_list.pickle',
|
||||
tag_list_chinese='./data/ram_tag_list_chinese.pickle',
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
super().__init__(
|
||||
tokenizer,
|
||||
vision_backbone,
|
||||
tag_encoder,
|
||||
tagging_head,
|
||||
text_decoder,
|
||||
device,
|
||||
vision_width,
|
||||
prompt,
|
||||
threshold,
|
||||
delete_tag_index,
|
||||
tag_list,
|
||||
tag_list_chinese,
|
||||
data_preprocessor,
|
||||
init_cfg,
|
||||
)
|
||||
|
||||
def set_openset(self,
|
||||
categories: List[str] = None,
|
||||
clip_ckpt: str = '',
|
||||
threshold: float = 0.68):
|
||||
openset_label_embedding, openset_categories = \
|
||||
build_openset_label_embedding(
|
||||
categories, clip_ckpt
|
||||
)
|
||||
self.tag_list = np.array(openset_categories)
|
||||
self.label_embed = nn.Parameter(openset_label_embedding.float())
|
||||
self.num_class = len(openset_categories)
|
||||
|
||||
# the threshold for unseen categories is often lower
|
||||
self.class_threshold = torch.ones(self.num_class) * threshold
|
||||
|
||||
def tag_process(self, logits):
|
||||
targets = torch.where(
|
||||
torch.sigmoid(logits) > self.class_threshold.to(logits.device),
|
||||
torch.tensor(1.0).to(logits.device),
|
||||
torch.zeros(self.num_class).to(logits.device))
|
||||
|
||||
tag = targets.cpu().numpy()
|
||||
tag[:, self.delete_tag_index] = 0
|
||||
|
||||
bs = logits.shape[0]
|
||||
tag_output = []
|
||||
logits_output = []
|
||||
for b in range(bs):
|
||||
index = np.argwhere(tag[b] == 1)
|
||||
token = self.tag_list[index].squeeze(axis=1)
|
||||
logits_output.append(
|
||||
torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy())
|
||||
tag_output.append(' | '.join(token))
|
||||
|
||||
return [(tag_output, [None]), logits_output]
|
|
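The open-set variant only swaps in CLIP-derived label embeddings and per-class thresholds; the tagging pipeline is otherwise the same as in `RAMNormal`. A minimal sketch of driving it directly, assuming `model` is an already-built `RAMOpenset` instance and `images` is a preprocessed batch tensor (the category list and CLIP checkpoint path below are placeholders):

```python
# Sketch only: `model`, `images` and the paths below are assumed, not defined here.
categories = ['cat', 'dog', 'bicycle']
model.set_openset(
    categories=categories,
    clip_ckpt='PATH/TO/CLIP_CKPT',  # placeholder path
    threshold=0.5)

with torch.no_grad():
    label_embed = model.get_label_embed()
    image_embeds, image_atts = model.extract_visual_feature(images)
    logits = model.image2tag(label_embed, image_embeds, image_atts)

(tags, _), scores = model.tag_process(logits)
print(tags[0])  # e.g. 'bicycle | dog'
```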
@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def inference_ram(sample, model):

    with torch.no_grad():
        result = model.test_step(sample)

    return result


def inference_ram_openset(sample, model):
    with torch.no_grad():
        result = model.test_step(sample)

    return result


def inference(sample, model, transforms, mode='normal'):
    sample = transforms(sample)
    if sample['inputs'].ndim == 3:
        sample['inputs'] = sample['inputs'].unsqueeze(dim=0)
    assert mode in ['normal', 'openset'
                    ], 'mode of inference must be "normal" or "openset"'
    if mode == 'normal':
        return inference_ram(sample, model)
    else:
        return inference_ram_openset(sample, model)
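Both helpers are thin wrappers around `model.test_step`; `inference` only normalizes the input dict and dispatches on the mode. A hedged usage sketch, where `model` is a built RAM model, `transforms` is its test-time pipeline, and the image path is a placeholder (none of these are defined in this file):

```python
# Assumed setup: `model` and `transforms` come from a RAM config and checkpoint.
sample = {'img_path': 'demo/demo.JPEG'}  # input dict the pipeline is assumed to accept
result = inference(sample, model, transforms, mode='normal')
# `result` carries the fields set in RAM.predict: 'tag_output' and 'logits_output'.
```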
@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from torch import nn


def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
                                base_model_prefix: str, skip_key: str):
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        print(f'''{decoder.__class__} and {encoder.__class__} are not equal.
            In this case make sure that
            all encoder weights are correctly initialized.''')

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f'{decoder_pointer} and {encoder_pointer}' + \
            'have to be of type torch.nn.Module'
        if hasattr(decoder_pointer, 'weight') and skip_key not in module_name:
            assert hasattr(encoder_pointer, 'weight')
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, 'bias'):
                assert hasattr(encoder_pointer, 'bias')
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name + ' is tied')
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (len(encoder_modules) >
                    0), f'''Encoder module {encoder_pointer}
                does not match decoder module {decoder_pointer}'''

            all_encoder_weights = set([
                module_name + '/' + sub_name
                for sub_name in encoder_modules.keys()
            ])
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name])) and len(
                                encoder_modules) != len(decoder_modules):
                        # This can happen when the name is a position in a
                        # ModuleList of layers and the decoder has an extra
                        # cross-attention layer the encoder doesn't have:
                        # skip this step and subtract one layer position
                        # from the encoder index.
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        '''Max depth of recursive function `tie_encoder_to_decoder` reached.
                        It seems that there is a circular dependency
                        between two or more `nn.Modules` of your model.''')
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + '/' + name,
                    uninitialized_encoder_weights,
                    skip_key,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + '/' + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
                                       uninitialized_encoder_weights, skip_key)
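A toy demonstration of the tying behavior, using two structurally identical `nn.Sequential` stacks (the module sizes and the `skip_key` value are arbitrary for the example):

```python
import torch.nn as nn

enc = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
dec = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Share every weight/bias between encoder and decoder, skipping any submodule
# whose qualified name contains 'crossattention'.
tie_encoder_decoder_weights(enc, dec, base_model_prefix='', skip_key='crossattention')
assert enc[0].weight is dec[0].weight  # the parameters are now the same object
```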
@ -64,6 +64,7 @@ class iTPNHiViT(HiViT):
                 layer_scale_init_value: float = 0.0,
                 mask_ratio: float = 0.75,
                 reconstruction_type: str = 'pixel',
                 **kwargs,
                 ):
        super().__init__(
            arch=arch,

@ -80,7 +81,9 @@ class iTPNHiViT(HiViT):
            norm_cfg=norm_cfg,
            ape=ape,
            rpe=rpe,
            layer_scale_init_value=layer_scale_init_value)
            layer_scale_init_value=layer_scale_init_value,
            **kwargs,
        )

        self.pos_embed.requires_grad = False
        self.mask_ratio = mask_ratio
@ -87,7 +87,7 @@ class ResizeMix(CutMix):
        (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam)
        batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate(
            batch_inputs[index],
            size=(y2 - y1, x2 - x1),
            size=(int(y2 - y1), int(x2 - x1)),
            mode=self.interpolation,
            align_corners=False)
        mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :]
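The only change here is the explicit `int()` cast: `F.interpolate` wants plain Python integers in `size`, while the bbox coordinates can come back as numpy or tensor scalars depending on how `cutmix_bbox_and_lam` computed them. A small self-contained illustration (shapes and values are arbitrary):

```python
import numpy as np
import torch
import torch.nn.functional as F

y1, y2, x1, x2 = np.int64(0), np.int64(8), np.int64(0), np.int64(8)
patch = F.interpolate(
    torch.rand(2, 3, 16, 16),
    size=(int(y2 - y1), int(x2 - x1)),  # the cast keeps the call well-typed
    mode='bilinear',
    align_corners=False)
print(patch.shape)  # torch.Size([2, 3, 8, 8])
```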
@ -12,6 +12,7 @@ from .huggingface import register_hf_tokenizer

register_hf_tokenizer(AutoTokenizer)
register_hf_tokenizer(LlamaTokenizer)
register_hf_tokenizer(BertTokenizer)


@register_hf_tokenizer()
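With `BertTokenizer` registered, a tokenizer can be built from a config dict like any other registered component. The registry import path and the `name_or_path` value below are assumptions based on how the Hugging Face tokenizer wrappers are used elsewhere in the configs; treat the snippet as a sketch rather than a guaranteed API:

```python
from mmpretrain.registry import TOKENIZER  # assumed registry location

tokenizer = TOKENIZER.build(
    dict(type='BertTokenizer', name_or_path='bert-base-uncased'))
print(tokenizer('a picture of a cat')['input_ids'])
```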
@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved

__version__ = '1.0.2'
__version__ = '1.2.0'


def parse_version_info(version_str):
@ -0,0 +1,137 @@
# Modified from
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    sep: str = '###'

    def get_prompt(self):
        ret = self.system + self.sep
        for role, message in self.messages:
            if message:
                ret += role + ': ' + message + self.sep
            else:
                ret += role + ':'
        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=[role for role in self.roles],
            messages=[[y for y in x] for x in self.messages],
            sep=self.sep,
        )

    def dict(self):
        return {
            'system': self.system,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
            'sep': self.sep,
        }


EN_CONV_VISION = Conversation(
    system='Give the following image. '
    'You will be able to see the image once I provide it to you. '
    'Please answer my questions in detail.',
    roles=['Ask', 'Answer'],
    messages=[],
    sep='###',
)

ZH_CONV_VISION = Conversation(
    system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
    roles=['问', '答'],
    messages=[],
    sep='###',
)


class Chat:

    def __init__(self, inferencer, device, is_half=False):
        self.device = device
        self.inferencer = inferencer
        self.model = inferencer.model
        self.is_half = is_half
        if is_half:
            self.model = self.model.half()
        self.model = self.model.to(device)
        self.max_length = 2000

    def upload_img(self, image, conv, img_list):
        img = next(self.inferencer.preprocess([image]))
        img = self.model.data_preprocessor(img, False)['images']
        img = img.to(self.device)
        image_emb, _ = self.model.encode_img(img)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors='pt',
                add_special_tokens=(i == 0)).to(self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [
            self.model.llama_model.model.embed_tokens(seg_token)
            for seg_token in seg_tokens
        ]
        mixed_embs = [
            emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
        ] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
                0] and conv.messages[-1][1][-6:] == '</Img>':
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, generation_cfg):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)
        cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
        if cur_max_len > self.max_length:
            print('Warning: The number of tokens in current conversation '
                  'exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, cur_max_len - self.max_length)
        embs = embs[:, begin_idx:]
        if self.is_half:
            embs = embs.half()
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            eos_token_id=self.model.end_token_id,
            **generation_cfg)

        output_token = outputs[0]
        if output_token[0] == 0:
            output_token = output_token[1:]
        elif output_token[0] == 1:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(
            output_token,
            add_special_tokens=False,
            skip_special_tokens=True)
        output_text = output_text.split('###')[0]
        conv.messages[-1][1] = output_text
        return output_text
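A quick illustration of how the conversation template assembles a prompt (the question text is made up for the example):

```python
conv = EN_CONV_VISION.copy()
conv.append_message(conv.roles[0], '<Img><ImageHere></Img> What is in the image?')
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# Give the following image. You will be able to see the image once I provide it
# to you. Please answer my questions in detail.###Ask: <Img><ImageHere></Img>
# What is in the image?###Answer:
```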
@ -0,0 +1,144 @@
import argparse

import gradio as gr
import numpy as np
import torch
from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat

from mmpretrain import ImageCaptionInferencer

parser = argparse.ArgumentParser(description='MiniGPT4 demo')
parser.add_argument(
    'cfg', type=str, help='config file for minigpt4 (absolute path)')
parser.add_argument(
    'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)')
args = parser.parse_args()

if torch.cuda.is_available():
    devices = [
        torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
    ]
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    devices = [torch.device('mps')]
else:
    devices = [torch.device('cpu')]


def get_free_device():
    if hasattr(torch.cuda, 'mem_get_info'):
        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
        select = max(zip(free, range(len(free))))[1]
    else:
        import random
        select = random.randint(0, len(devices) - 1)
    return devices[select]


device = get_free_device()
inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt)
model = inferencer.model
chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))


def reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (None, gr.update(value=None, interactive=True),
            gr.update(
                value=None,
                placeholder='Please upload your image first',
                interactive=False),
            gr.update(value='Upload & Start Chat',
                      interactive=True), chat_state, img_list,
            gr.update(value='Restart', interactive=False),
            gr.update(value='English', interactive=True))


def upload_img(gr_img, language, chat_state):
    if gr_img is None:
        return (None,
                gr.update(
                    placeholder='Please upload your image first',
                    interactive=False),
                gr.update(value='Upload & Start Chat',
                          interactive=True), chat_state, None,
                gr.update(value='Restart', interactive=False),
                gr.update(value='English', interactive=True))

    if (language == 'English'):
        chat_state = EN_CONV_VISION.copy()
    else:
        chat_state = ZH_CONV_VISION.copy()
    img_list = []
    gr_img_array = np.asarray(gr_img)
    chat.upload_img(gr_img_array, chat_state, img_list)
    return (gr.update(interactive=False),
            gr.update(placeholder='Type and press Enter', interactive=True),
            gr.update(value='Start Chatting',
                      interactive=False), chat_state, img_list,
            gr.update(value='Restart',
                      interactive=True), gr.update(interactive=False))


def ask(user_message, chatbot, chat_state):
    if (len(user_message) == 0):
        return gr.update(
            value=None,
            placeholder='Input should not be empty!',
            interactive=True), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def answer(chatbot, chat_state, img_list):
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        generation_cfg=model.generation_cfg)
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


if __name__ == '__main__':
    title = 'MMPretrain MiniGPT-4 Inference Demo'
    with gr.Blocks(analytics_enabled=False, title=title) as demo:
        gr.Markdown(f'# {title}')
        with gr.Row():
            with gr.Column():
                image = gr.Image(type='pil')
                language = gr.Dropdown(['English', 'Chinese'],
                                       label='Language',
                                       info='Select chatbot\'s language',
                                       value='English',
                                       interactive=True)
                upload_button = gr.Button(
                    value='Upload & Start Chat', interactive=True)
                clear = gr.Button(value='Restart', interactive=False)

            with gr.Column():
                chat_state = gr.State()
                img_list = gr.State()
                chatbot = gr.Chatbot(
                    label='MiniGPT-4', min_width=320, height=600)
                text_input = gr.Textbox(
                    label='User',
                    placeholder='Please upload your image first',
                    interactive=False)

        upload_button.click(upload_img, [image, language, chat_state], [
            image, text_input, upload_button, chat_state, img_list, clear,
            language
        ])
        text_input.submit(ask, [text_input, chatbot, chat_state],
                          [text_input, chatbot, chat_state]).then(
                              answer, [chatbot, chat_state, img_list],
                              [chatbot, chat_state, img_list])
        clear.click(reset, [chat_state, img_list], [
            chatbot, image, text_input, upload_button, chat_state, img_list,
            clear, language
        ])

    demo.launch(share=True)
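The gradio callbacks ultimately just drive the `Chat` helper, so the same flow can be exercised without the UI. A sketch that reuses the `chat` and `model` objects built above and assumes a PIL image `pil_img` (a placeholder name):

```python
conv = EN_CONV_VISION.copy()
img_list = []
chat.upload_img(np.asarray(pil_img), conv, img_list)   # encode the image once
chat.ask('Describe the image in one sentence.', conv)  # queue the question
reply = chat.answer(conv, img_list, generation_cfg=model.generation_cfg)
print(reply)
```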
@ -1,2 +1,2 @@
mmcv>=2.0.0,<2.1.0
mmcv>=2.0.0,<2.4.0
mmengine>=0.8.3,<1.0.0
@ -1,4 +1,4 @@
albumentations>=0.3.2 --no-binary qudida,albumentations # For Albumentations data transform
grad-cam >= 1.3.7 # For CAM visualization
grad-cam >= 1.3.7,<1.5.0 # For CAM visualization
requests # For torchserve
scikit-learn # For t-SNE visualization and unit tests.
Binary file not shown. (new image added, 220 KiB)
@ -169,4 +169,5 @@ class TestRepMLP(TestCase):

        assert len(feats_) == len(feats__)
        for i in range(len(feats)):
            self.assertTrue(torch.allclose(feats__[i], feats_[i]))
            self.assertTrue(
                torch.allclose(feats__[i], feats_[i], rtol=0.01, atol=0.01))
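The change only relaxes the comparison tolerance: after reparameterization the two branches agree up to small numerical error, which the default `torch.allclose` tolerances are too strict for. A tiny illustration with made-up values:

```python
import torch

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0005, 1.9995])
print(torch.allclose(a, b))                        # False with default tolerances
print(torch.allclose(a, b, rtol=0.01, atol=0.01))  # True with the relaxed ones
```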
@ -9,23 +9,21 @@ from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict

prog_description = """\
Merge Llava delta weights and original weights,
and save as MMPreTrain checkpoint.
Convert Llava weights and original weights.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'src_path', type=str, help='The original checkpoint dir')
    parser.add_argument(
        'delta_path', type=str, help='The delta checkpoint dir')
    parser.add_argument('dst_path', type=str, help='The saved checkpoint path')
    parser.add_argument('src', type=str, help='The original checkpoint dir')
    parser.add_argument('dst', type=str, help='The saved checkpoint path')
    parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
    args = parser.parse_args()
    return args


def load_checkpoint(path: Path):
    path = Path(path)
    if path.is_file():
        return torch.load(path)


@ -41,19 +39,23 @@ def load_checkpoint(path: Path):
def main():
    args = parse_args()

    if Path(args.src_path).exists():
        src_path = Path(args.src_path)
    if Path(args.src).exists():
        src_path = args.src
    else:
        src_path = Path(snapshot_download(args.src_path))
        src_path = snapshot_download(
            args.src, allow_patterns='pytorch_model*.bin')
    src_state_dict = load_checkpoint(src_path)

    if Path(args.delta_path).exists():
        delta_path = Path(args.delta_path)
    if args.delta is None:
        delta_state_dict = {}
    elif Path(args.delta).exists():
        delta_state_dict = load_checkpoint(args.delta)
    else:
        delta_path = Path(snapshot_download(args.delta_path))
        delta_state_dict = load_checkpoint(delta_path)
        delta_path = snapshot_download(
            args.delta, allow_patterns='pytorch_model*.bin')
        delta_state_dict = load_checkpoint(delta_path)

    merged_state_dict = OrderedDict()
    new_state_dict = OrderedDict()
    for k, v in src_state_dict.items():
        if k in delta_state_dict:
            delta_v = delta_state_dict.pop(k)

@ -63,12 +65,13 @@ def main():
                v = delta_v
            else:
                v += delta_v
        merged_state_dict['model.lang_encoder.' + k] = v
        if 'rotary_emb.inv_freq' not in k:
            new_state_dict['model.lang_encoder.' + k] = v

    for k, v in delta_state_dict.items():
        merged_state_dict['model.lang_encoder.' + k] = v
        new_state_dict['model.lang_encoder.' + k] = v

    torch.save(merged_state_dict, args.dst_path)
    torch.save(new_state_dict, args.dst)
    print('Done!!')
@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
from copy import deepcopy

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_swin(ckpt):
    new_ckpt = OrderedDict()
    convert_mapping = dict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1,
                                            2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if 'attn_mask' in k:
            continue
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        elif k.startswith('norm'):
            new_v = v
            new_k = k.replace('norm', 'norm3')
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v
        convert_mapping[k] = new_k

    return new_ckpt, convert_mapping


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained RAM models to '
        'MMPretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    visual_ckpt = OrderedDict()
    for key in state_dict:
        if key.startswith('visual_encoder.'):
            new_key = key.replace('visual_encoder.', '')
            visual_ckpt[new_key] = state_dict[key]

    new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt)
    new_ckpt = deepcopy(state_dict)
    for key in state_dict:
        if key.startswith('visual_encoder.'):
            if 'attn_mask' in key:
                del new_ckpt[key]
                continue
            del new_ckpt[key]
            old_key = key.replace('visual_encoder.', '')
            new_ckpt[key.replace(old_key,
                                 convert_mapping[old_key])] = deepcopy(
                                     new_visual_ckpt[key.replace(
                                         old_key,
                                         convert_mapping[old_key]).replace(
                                             'visual_encoder.', '')])

    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(new_ckpt, args.dst)


if __name__ == '__main__':
    main()
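The two `correct_unfold_*` helpers undo the channel interleaving used by the official Swin patch-merging layers so the weights line up with how MMPretrain's patch-merging unfolds patches. Restating the reduction-weight reorder outside the function makes the effect easy to see (the weight values below are arbitrary):

```python
import torch

def correct_unfold_reduction_order(x):
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, 4, in_channel // 4)
    return x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)

w = torch.arange(8.).reshape(1, 8)       # a toy 1x8 'reduction' weight
print(correct_unfold_reduction_order(w))
# tensor([[0., 4., 2., 6., 1., 5., 3., 7.]])
```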
@ -91,10 +91,6 @@ def merge_args(cfg, args):

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
        assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
            '`--amp` is not supported custom optimizer wrapper type ' \
            f'`{optim_wrapper}.'
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
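In effect, passing `--amp` rewrites the optimizer wrapper section of the loaded config before the runner is built. A sketch of the resulting config fragment (the inner optimizer settings are placeholders, not taken from any particular config):

```python
optim_wrapper = dict(
    type='AmpOptimWrapper',        # switched from the default 'OptimWrapper'
    loss_scale='dynamic',          # filled in by setdefault(...)
    optimizer=dict(type='AdamW', lr=1e-4),  # placeholder inner optimizer
)
```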