Merge remote-tracking branch 'origin/main' into pr1802/ram
commit
4584a07e65
|
@ -86,13 +86,10 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351
|
|||
|
||||
## What's new
|
||||
|
||||
🌟 v1.0.2 was released in 15/08/2023
|
||||
🌟 v1.1.0 was released in 12/10/2023
|
||||
|
||||
Support [MFF](./configs/mff/) self-supervised algorithm and enhance the codebase. More details can be found in the [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html).
|
||||
|
||||
🌟 v1.0.1 was released in 28/07/2023
|
||||
|
||||
Fix some bugs and enhance the codebase. Please refer to [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) for more details.
|
||||
- Support Mini-GPT4 training and provide a Chinese model (based on Baichuan-7B)
|
||||
- Support zero-shot classification based on CLIP.
|
||||
|
||||
🌟 v1.0.0 was released in 04/07/2023
|
||||
|
||||
|
|
|
@ -84,13 +84,10 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351
|
|||
|
||||
## 更新日志
|
||||
|
||||
🌟 2023/8/15 发布了 v1.0.2 版本
|
||||
🌟 2023/10/12 发布了 v1.1.0 版本
|
||||
|
||||
支持了 [MFF](./configs/mff/) 自监督算法,增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。
|
||||
|
||||
🌟 2023/7/28 发布了 v1.0.1 版本
|
||||
|
||||
修复部分 bug 和增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。
|
||||
- 支持 Mini-GPT4 训练并提供一个基于 Baichuan-7B 的中文模型
|
||||
- 支持基于 CLIP 的零样本分类。
|
||||
|
||||
🌟 2023/7/4 发布了 v1.0.0 版本
|
||||
|
||||
|
@ -333,10 +330,10 @@ MMPreTrain 是一款由不同学校和公司共同贡献的开源项目。我们
|
|||
|
||||
## 欢迎加入 OpenMMLab 社区
|
||||
|
||||
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 OpenMMLab 团队的 [官方交流 QQ 群](https://jq.qq.com/?_wv=1027&k=aCvMxdr3) 或联络 OpenMMLab 官方微信小助手
|
||||
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),扫描下方微信二维码添加喵喵好友,进入 MMPretrain 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】
|
||||
|
||||
<div align="center">
|
||||
<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/xiaozhushou_weixin_qrcode.jpeg" height="400"/>
|
||||
<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/miaomiao_qrcode.jpg" height="400"/>
|
||||
</div>
|
||||
|
||||
我们会在 OpenMMLab 社区为大家
|
||||
|
|
|
@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI
|
|||
|
||||
### Pretrained models
|
||||
|
||||
| Model | Params (M) | Flops (G) | Config | Download |
|
||||
| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
|
||||
| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) |
|
||||
| Model | Params (M) | Flops (G) | Config | Download |
|
||||
| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: |
|
||||
| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) |
|
||||
| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduce the training results.*
|
||||
|
||||
|
|
|
@ -19,8 +19,19 @@ Models:
|
|||
- Task: Image Caption
|
||||
Dataset: COCO
|
||||
Metrics: null
|
||||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth
|
||||
Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py
|
||||
Converted From:
|
||||
Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
|
||||
Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
|
||||
- Name: minigpt-4_baichuan-7b_caption
|
||||
Metadata:
|
||||
FLOPs: null
|
||||
Parameters: 8094769024
|
||||
In Collection: MiniGPT4
|
||||
Results:
|
||||
- Task: Image Caption
|
||||
Dataset: COCO
|
||||
Metrics: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth
|
||||
Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
_base_ = [
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
# dataset settings
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(224, 224),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='CleanCaption',
|
||||
keys='chat_content',
|
||||
remove_chars='',
|
||||
lowercase=False),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['chat_content', 'lang'],
|
||||
meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type='MiniGPT4Dataset',
|
||||
data_root='YOUR_DATA_DIRECTORY',
|
||||
ann_file='YOUR_DATA_FILE',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(224, 224),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='PackInputs', meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
test_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
|
||||
)
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
type='COCOCaption',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/coco_karpathy_val.json',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='MiniGPT4',
|
||||
vision_encoder=dict(
|
||||
type='BEiTViT',
|
||||
# eva-g without the final layer
|
||||
arch=dict(
|
||||
embed_dims=1408,
|
||||
num_layers=39,
|
||||
num_heads=16,
|
||||
feedforward_channels=6144,
|
||||
),
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
layer_scale_init_value=0.0,
|
||||
frozen_stages=39,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
final_norm=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
out_type='raw',
|
||||
pretrained= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
|
||||
),
|
||||
q_former_model=dict(
|
||||
type='Qformer',
|
||||
model_style='bert-base-uncased',
|
||||
vision_model_width=1408,
|
||||
add_cross_attention=True,
|
||||
cross_attention_freq=2,
|
||||
num_query_token=32,
|
||||
pretrained= # noqa
|
||||
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
|
||||
),
|
||||
lang_encoder=dict(
|
||||
type='AutoModelForCausalLM',
|
||||
name_or_path='baichuan-inc/baichuan-7B',
|
||||
trust_remote_code=True),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='baichuan-inc/baichuan-7B',
|
||||
trust_remote_code=True),
|
||||
task='caption',
|
||||
prompt_template=dict([('en', '###Ask: {} ###Answer: '),
|
||||
('zh', '###问:{} ###答:')]),
|
||||
raw_prompts=dict([
|
||||
('en', [('<Img><ImageHere></Img> '
|
||||
'Describe this image in detail.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Take a look at this image and describe what you notice.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Please provide a detailed description of the picture.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Could you describe the contents of this image for me?')]),
|
||||
('zh', [('<Img><ImageHere></Img> '
|
||||
'详细描述这张图片。'), ('<Img><ImageHere></Img> '
|
||||
'浏览这张图片并描述你注意到什么。'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'请对这张图片进行详细的描述。'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'你能为我描述这张图片的内容吗?')])
|
||||
]),
|
||||
max_txt_len=160,
|
||||
end_sym='###')
|
||||
|
||||
strategy = dict(
|
||||
type='DeepSpeedStrategy',
|
||||
fp16=dict(
|
||||
enabled=True,
|
||||
auto_cast=False,
|
||||
fp16_master_weights_and_grads=False,
|
||||
loss_scale=0,
|
||||
loss_scale_window=1000,
|
||||
hysteresis=1,
|
||||
min_loss_scale=1,
|
||||
initial_scale_power=16,
|
||||
),
|
||||
inputs_to_half=[0],
|
||||
zero_optimization=dict(
|
||||
stage=2,
|
||||
allgather_partitions=True,
|
||||
allgather_bucket_size=2e8,
|
||||
reduce_scatter=True,
|
||||
reduce_bucket_size='auto',
|
||||
overlap_comm=True,
|
||||
contiguous_gradients=True,
|
||||
),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
type='DeepSpeedOptimWrapper',
|
||||
optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-3 / 500,
|
||||
by_epoch=False,
|
||||
begin=0,
|
||||
end=500,
|
||||
),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=2e-4,
|
||||
by_epoch=False,
|
||||
begin=500,
|
||||
),
|
||||
]
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=6)
|
||||
test_cfg = dict()
|
||||
|
||||
runner_type = 'FlexibleRunner'
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook',
|
||||
interval=1,
|
||||
by_epoch=True,
|
||||
save_last=True,
|
||||
max_keep_ckpts=1,
|
||||
))
|
|
@ -55,13 +55,25 @@ model = dict(
|
|||
type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
|
||||
tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
|
||||
task='caption',
|
||||
prompt_template='###Human: {} ###Assistant: ',
|
||||
raw_prompts=[
|
||||
'<Img><ImageHere></Img> Describe this image in detail.',
|
||||
'<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa
|
||||
'<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa
|
||||
'<Img><ImageHere></Img> Could you describe the contents of this image for me?', # noqa
|
||||
],
|
||||
prompt_template=dict([('en', '###Ask: {} ###Answer: '),
|
||||
('zh', '###问:{} ###答:')]),
|
||||
raw_prompts=dict([
|
||||
('en', [('<Img><ImageHere></Img> '
|
||||
'Describe this image in detail.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Take a look at this image and describe what you notice.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Please provide a detailed description of the picture.'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'Could you describe the contents of this image for me?')]),
|
||||
('zh', [('<Img><ImageHere></Img> '
|
||||
'详细描述这张图片。'), ('<Img><ImageHere></Img> '
|
||||
'浏览这张图片并描述你注意到什么。'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'请对这张图片进行详细的描述。'),
|
||||
('<Img><ImageHere></Img> '
|
||||
'你能为我描述这张图片的内容吗?')])
|
||||
]),
|
||||
max_txt_len=160,
|
||||
end_sym='###')
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
imagenet1k:
|
||||
dataset: ImageNet-1K
|
||||
dataset: OpenDataLab/ImageNet-1K
|
||||
download_root: data
|
||||
data_root: data/imagenet
|
||||
script: tools/dataset_converters/odl_imagenet1k_preprocess.sh
|
||||
|
||||
cub:
|
||||
dataset: CUB-200-2011
|
||||
dataset: OpenDataLab/CUB-200-2011
|
||||
download_root: data
|
||||
data_root: data/CUB_200_2011
|
||||
script: tools/dataset_converters/odl_cub_preprocess.sh
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
ARG PYTORCH="1.12.1"
|
||||
ARG CUDA="11.3"
|
||||
ARG PYTORCH="2.0.1"
|
||||
ARG CUDA="11.7"
|
||||
ARG CUDNN="8"
|
||||
FROM pytorch/torchserve:latest-gpu
|
||||
|
||||
ARG MMPRE="1.0.2"
|
||||
ARG MMPRE="1.1.0"
|
||||
|
||||
ENV PYTHONUNBUFFERED TRUE
|
||||
|
||||
|
|
|
@ -1,5 +1,27 @@
|
|||
# Changelog (MMPreTrain)
|
||||
|
||||
## v1.1.0(12/10/2023)
|
||||
|
||||
### New Features
|
||||
|
||||
- [Feature] Implement of Zero-Shot CLIP Classifier ([#1737](https://github.com/open-mmlab/mmpretrain/pull/1737))
|
||||
- [Feature] Add minigpt4 gradio demo and training script. ([#1758](https://github.com/open-mmlab/mmpretrain/pull/1758))
|
||||
|
||||
### Improvements
|
||||
|
||||
- [Config] New Version of config Adapting MobileNet Algorithm ([#1774](https://github.com/open-mmlab/mmpretrain/pull/1774))
|
||||
- [Config] Support DINO self-supervised learning in project ([#1756](https://github.com/open-mmlab/mmpretrain/pull/1756))
|
||||
- [Config] New Version of config Adapting Swin Transformer Algorithm ([#1780](https://github.com/open-mmlab/mmpretrain/pull/1780))
|
||||
- [Enhance] Add iTPN Supports for Non-three channel image ([#1735](https://github.com/open-mmlab/mmpretrain/pull/1735))
|
||||
- [Docs] Update dataset download script from opendatalab to openXlab ([#1765](https://github.com/open-mmlab/mmpretrain/pull/1765))
|
||||
- [Docs] Update COCO-Retrieval dataset docs. ([#1806](https://github.com/open-mmlab/mmpretrain/pull/1806))
|
||||
|
||||
### Bug Fix
|
||||
|
||||
- Update `train.py` to compat with new config.
|
||||
- Update OFA module to compat with the latest huggingface.
|
||||
- Fix pipeline bug in ImageRetrievalInferencer.
|
||||
|
||||
## v1.0.2(15/08/2023)
|
||||
|
||||
### New Features
|
||||
|
|
|
@ -16,7 +16,7 @@ and make sure you fill in all required information in the template.
|
|||
|
||||
| MMPretrain version | MMEngine version | MMCV version |
|
||||
| :----------------: | :---------------: | :--------------: |
|
||||
| 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
|
||||
| 1.1.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
|
||||
| 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 |
|
||||
| 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 |
|
||||
| 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 |
|
||||
|
|
|
@ -144,15 +144,15 @@ ImageNet has multiple versions, but the most commonly used one is [ILSVRC 2012](
|
|||
|
||||
````{group-tab} Download by MIM
|
||||
|
||||
MIM supports downloading from [OpenDataLab](https://opendatalab.com/) and preprocessing ImageNet dataset with one command line.
|
||||
MIM supports downloading from [OpenXlab](https://openxlab.org.cn/datasets) and preprocessing ImageNet dataset with one command line.
|
||||
|
||||
_You need to register an account at [OpenDataLab official website](https://opendatalab.com/) and login by CLI._
|
||||
_You need to register an account at [OpenXlab official website](https://openxlab.org.cn/datasets) and login by CLI._
|
||||
|
||||
```Bash
|
||||
# install OpenDataLab CLI tools
|
||||
pip install -U opendatalab
|
||||
# log in OpenDataLab, register if you don't have an account.
|
||||
odl login
|
||||
# install OpenXlab CLI tools
|
||||
pip install -U openxlab
|
||||
# log in OpenXLab
|
||||
openxlab login
|
||||
# download and preprocess by MIM, better to execute in $MMPreTrain directory.
|
||||
mim download mmpretrain --dataset imagenet1k
|
||||
```
|
||||
|
@ -278,7 +278,7 @@ test_dataloader = val_dataloader
|
|||
| [`SUN397`](mmpretrain.datasets.SUN397)(data_root[, split, pipeline, ...]) | ["train", "test"] | [SUN397](https://vision.princeton.edu/projects/2010/SUN/) Dataset. |
|
||||
| [`VOC`](mmpretrain.datasets.VOC)(data_root[, image_set_path, pipeline, ...]) | ["train", "val", "tranval", "test"] | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) Dataset. |
|
||||
|
||||
Some dataset homepage links may be unavailable, and you can download datasets through [OpenDataLab](https://opendatalab.com/), such as [Stanford Cars](https://opendatalab.com/Stanford_Cars/download).
|
||||
Some dataset homepage links may be unavailable, and you can download datasets through [OpenXLab](https://openxlab.org.cn/datasets), such as [Stanford Cars](https://openxlab.org.cn/datasets/OpenDataLab/Stanford_Cars).
|
||||
|
||||
## Supported Multi-modality Datasets
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
| MMPretrain 版本 | MMEngine 版本 | MMCV 版本 |
|
||||
| :-------------: | :---------------: | :--------------: |
|
||||
| 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
|
||||
| 1.1.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 |
|
||||
| 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 |
|
||||
| 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 |
|
||||
| 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 |
|
||||
|
|
|
@ -142,15 +142,15 @@ ImageNet 有多个版本,但最常用的一个是 [ILSVRC 2012](http://www.ima
|
|||
|
||||
````{group-tab} MIM 下载
|
||||
|
||||
MIM支持使用一条命令行从 [OpenDataLab](https://opendatalab.com/) 下载并预处理 ImageNet 数据集。
|
||||
MIM支持使用一条命令行从 [OpenXLab](https://openxlab.org.cn/datasets?lang=zh-CN) 下载并预处理 ImageNet 数据集。
|
||||
|
||||
_需要在 [OpenDataLab 官网](https://opendatalab.com/) 注册账号并命令行登录_。
|
||||
_需要在 [OpenXLab 官网](https://openxlab.org.cn/datasets?lang=zh-CN) 注册账号并命令行登录_。
|
||||
|
||||
```Bash
|
||||
# 安装opendatalab库
|
||||
pip install -U opendatalab
|
||||
# 登录到 OpenDataLab, 如果还没有注册,请到官网注册一个
|
||||
odl login
|
||||
# 安装 OpenXLab CLI 工具
|
||||
pip install -U openxlab
|
||||
# 登录 OpenXLab
|
||||
openxlab login
|
||||
# 使用 MIM 下载数据集, 最好在 $MMPreTrain 目录执行
|
||||
mim download mmpretrain --dataset imagenet1k
|
||||
```
|
||||
|
@ -276,7 +276,7 @@ test_dataloader = val_dataloader
|
|||
| [`SUN397`](mmpretrain.datasets.SUN397)(data_root[, split, pipeline, ...]) | ["train", "test"] | [SUN397](https://vision.princeton.edu/projects/2010/SUN/) 数据集 |
|
||||
| [`VOC`](mmpretrain.datasets.VOC)(data_root[, image_set_path, pipeline, ...]) | ["train", "val", "tranval", "test"] | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) 数据集 |
|
||||
|
||||
有些数据集主页链接可能已经失效,您可以通过[OpenDataLab](https://opendatalab.com/)下载数据集,例如 [Stanford Cars](https://opendatalab.com/Stanford_Cars/download)数据集。
|
||||
有些数据集主页链接可能已经失效,您可以通过[OpenXLab](https://openxlab.org.cn/datasets?lang=zh-CN)下载数据集,例如 [Stanford Cars](https://openxlab.org.cn/datasets/OpenDataLab/Stanford_Cars)数据集。
|
||||
|
||||
## OpenMMLab 2.0 标准数据集
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from .apis import * # noqa: F401, F403
|
|||
from .version import __version__
|
||||
|
||||
mmcv_minimum_version = '2.0.0'
|
||||
mmcv_maximum_version = '2.1.0'
|
||||
mmcv_maximum_version = '2.2.0'
|
||||
mmcv_version = digit_version(mmcv.__version__)
|
||||
|
||||
mmengine_minimum_version = '0.8.3'
|
||||
|
|
|
@ -108,6 +108,7 @@ class ImageRetrievalInferencer(BaseInferencer):
|
|||
# A config of dataset
|
||||
from mmpretrain.registry import DATASETS
|
||||
test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
|
||||
prototype.setdefault('pipeline', test_pipeline)
|
||||
dataset = DATASETS.build(prototype)
|
||||
dataloader = build_dataloader(dataset)
|
||||
elif isinstance(prototype, DataLoader):
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile,
|
||||
PackInputs, RandomCrop, RandomFlip, Resize)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = CUB
|
||||
data_preprocessor = dict(
|
||||
num_classes=200,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=Resize, scale=510),
|
||||
dict(type=RandomCrop, crop_size=384),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=Resize, scale=510),
|
||||
dict(type=CenterCrop, crop_size=384),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, ))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
|
||||
PackInputs, RandAugment, RandomErasing,
|
||||
RandomFlip, RandomResizedCrop, ResizeEdge)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = ImageNet
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=RandomResizedCrop,
|
||||
scale=256,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type=RandAugment,
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type=RandomErasing,
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=ResizeEdge,
|
||||
scale=292, # ( 256 / 224 * 256 )
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=CenterCrop, crop_size=256),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
|
||||
ImageClassifier, LinearClsHead, SwinTransformer)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(
|
||||
type=SwinTransformer,
|
||||
arch='base',
|
||||
img_size=384,
|
||||
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (GlobalAveragePooling, ImageClassifier,
|
||||
LabelSmoothLoss, LinearClsHead,
|
||||
SwinTransformerV2)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(
|
||||
type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False))
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.optim import CosineAnnealingLR, LinearLR
|
||||
from torch.optim import SGD
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True))
|
||||
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
# warm up learning rate scheduler
|
||||
dict(
|
||||
type=LinearLR,
|
||||
start_factor=0.01,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=5,
|
||||
# update by iter
|
||||
convert_to_iter_based=True),
|
||||
# main learning rate scheduler
|
||||
dict(
|
||||
type=CosineAnnealingLR,
|
||||
T_max=95,
|
||||
by_epoch=True,
|
||||
begin=5,
|
||||
end=100,
|
||||
)
|
||||
]
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=64)
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(img_size=224, drop_path_rate=0.5, stage_cfgs=None),
|
||||
head=dict(
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='large', img_size=224, stage_cfgs=None),
|
||||
head=dict(in_channels=1536),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='large'),
|
||||
head=dict(in_channels=1536),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.hooks import CheckpointHook, LoggerHook
|
||||
from mmengine.model import PretrainedInit
|
||||
from torch.optim.adamw import AdamW
|
||||
|
||||
from mmpretrain.models import ImageClassifier
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.cub_bs8_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.cub_bs64 import *
|
||||
|
||||
# model settings
|
||||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa
|
||||
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
init_cfg=dict(
|
||||
type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')),
|
||||
head=dict(num_classes=200, in_channels=1536))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
_delete_=True,
|
||||
type=AdamW,
|
||||
lr=5e-6,
|
||||
weight_decay=0.0005,
|
||||
eps=1e-8,
|
||||
betas=(0.9, 0.999)),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0.0,
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={
|
||||
'.absolute_pos_embed': dict(decay_mult=0.0),
|
||||
'.relative_position_bias_table': dict(decay_mult=0.0)
|
||||
}),
|
||||
clip_grad=dict(max_norm=5.0),
|
||||
)
|
||||
|
||||
default_hooks = dict(
|
||||
# log every 20 intervals
|
||||
logger=dict(type=LoggerHook, interval=20),
|
||||
# save last three checkpoints
|
||||
checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3))
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='small', img_size=224, drop_path_rate=0.3, stage_cfgs=None),
|
||||
head=dict(
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='tiny', img_size=224, drop_path_rate=0.2, stage_cfgs=None),
|
||||
head=dict(
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet21k_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]),
|
||||
head=dict(num_classes=21841),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# dataset settings
|
||||
data_preprocessor = dict(num_classes=21841)
|
||||
|
||||
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
|
||||
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
|
||||
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=256, drop_path_rate=0.5, window_size=[16, 16, 16, 8]),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=256,
|
||||
window_size=[16, 16, 16, 8],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
window_size=[24, 24, 24, 12], pretrained_window_sizes=[12, 12, 12, 6]))
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(img_size=256, drop_path_rate=0.5),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet21k_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]),
|
||||
head=dict(num_classes=21841),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# dataset settings
|
||||
data_preprocessor = dict(num_classes=21841)
|
||||
|
||||
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
|
||||
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
|
||||
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop
|
|
@ -0,0 +1,24 @@
|
|||
# Only for evaluation
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
from mmpretrain.models import CrossEntropyLoss
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
img_size=256,
|
||||
window_size=[16, 16, 16, 8],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
head=dict(
|
||||
in_channels=1536,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,24 @@
|
|||
# Only for evaluation
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
from mmpretrain.models import CrossEntropyLoss
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
img_size=384,
|
||||
window_size=[24, 24, 24, 12],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
head=dict(
|
||||
in_channels=1536,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='small',
|
||||
img_size=256,
|
||||
drop_path_rate=0.3,
|
||||
window_size=[16, 16, 16, 8]),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='small', img_size=256, drop_path_rate=0.3),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='tiny',
|
||||
img_size=256,
|
||||
drop_path_rate=0.2,
|
||||
window_size=[16, 16, 16, 8]),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='tiny', img_size=256, drop_path_rate=0.2),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -43,6 +43,7 @@ if WITH_MULTIMODAL:
|
|||
from .gqa_dataset import GQA
|
||||
from .iconqa import IconQA
|
||||
from .infographic_vqa import InfographicVQA
|
||||
from .minigpt4_dataset import MiniGPT4Dataset
|
||||
from .nocaps import NoCaps
|
||||
from .ocr_vqa import OCRVQA
|
||||
from .refcoco import RefCOCO
|
||||
|
@ -56,5 +57,6 @@ if WITH_MULTIMODAL:
|
|||
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
|
||||
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
|
||||
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
|
||||
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
|
||||
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
|
||||
'MiniGPT4Dataset'
|
||||
])
|
||||
|
|
|
@ -1,18 +1,45 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
from os import PathLike
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from mmengine import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS, TRANSFORMS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
def expanduser(data_prefix):
|
||||
if isinstance(data_prefix, (str, PathLike)):
|
||||
return osp.expanduser(data_prefix)
|
||||
else:
|
||||
return data_prefix
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class COCORetrieval(BaseDataset):
|
||||
"""COCO Retrieval dataset.
|
||||
|
||||
COCO (Common Objects in Context): The COCO dataset contains more than
|
||||
330K images,each of which has approximately 5 descriptive annotations.
|
||||
This dataset was releasedin collaboration between Microsoft and Carnegie
|
||||
Mellon University
|
||||
|
||||
COCO_2014 dataset directory: ::
|
||||
|
||||
COCO_2014
|
||||
├── val2014
|
||||
├── train2014
|
||||
├── annotations
|
||||
├── instances_train2014.json
|
||||
├── instances_val2014.json
|
||||
├── person_keypoints_train2014.json
|
||||
├── person_keypoints_val2014.json
|
||||
├── captions_train2014.json
|
||||
├── captions_val2014.json
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path.
|
||||
test_mode (bool): Whether dataset is used for evaluation. This will
|
||||
|
@ -23,8 +50,52 @@ class COCORetrieval(BaseDataset):
|
|||
data_prefix (str | dict): Prefix for training data. Defaults to ''.
|
||||
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
|
||||
Examples:
|
||||
>>> from mmpretrain.datasets import COCORetrieval
|
||||
>>> train_dataset=COCORetrieval(data_root='coco2014/')
|
||||
>>> train_dataset
|
||||
Dataset COCORetrieval
|
||||
Number of samples: 414113
|
||||
Annotation file: /coco2014/annotations/captions_train2014.json
|
||||
Prefix of images: /coco2014/
|
||||
>>> from mmpretrain.datasets import COCORetrieval
|
||||
>>> val_dataset = COCORetrieval(data_root='coco2014/')
|
||||
>>> val_dataset
|
||||
Dataset COCORetrieval
|
||||
Number of samples: 202654
|
||||
Annotation file: /coco2014/annotations/captions_val2014.json
|
||||
Prefix of images: /coco2014/
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str,
|
||||
test_mode: bool = False,
|
||||
data_prefix: Union[str, dict] = '',
|
||||
data_root: str = '',
|
||||
pipeline: Sequence = (),
|
||||
**kwargs):
|
||||
|
||||
if isinstance(data_prefix, str):
|
||||
data_prefix = dict(img_path=expanduser(data_prefix))
|
||||
|
||||
ann_file = expanduser(ann_file)
|
||||
transforms = []
|
||||
for transform in pipeline:
|
||||
if isinstance(transform, dict):
|
||||
transforms.append(TRANSFORMS.build(transform))
|
||||
else:
|
||||
transforms.append(transform)
|
||||
|
||||
super().__init__(
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
test_mode=test_mode,
|
||||
pipeline=transforms,
|
||||
ann_file=ann_file,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load data list."""
|
||||
# get file backend
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import mmengine
|
||||
from mmengine.dataset import BaseDataset
|
||||
from mmengine.fileio import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MiniGPT4Dataset(BaseDataset):
|
||||
"""Dataset for training MiniGPT4.
|
||||
|
||||
MiniGPT4 dataset directory:
|
||||
|
||||
minigpt4_dataset
|
||||
├── image
|
||||
│ ├── id0.jpg
|
||||
│ │── id1.jpg
|
||||
│ │── id2.jpg
|
||||
│ └── ...
|
||||
└── conversation_data.json
|
||||
|
||||
The structure of conversation_data.json:
|
||||
|
||||
[
|
||||
// English data
|
||||
{
|
||||
"id": str(id0),
|
||||
"conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
|
||||
###Answer: [Answer content]"
|
||||
},
|
||||
|
||||
// Chinese data
|
||||
{
|
||||
"id": str(id1),
|
||||
"conversation": "###问:<Img><ImageHere></Img> [Ask content]
|
||||
###答:[Answer content]"
|
||||
},
|
||||
|
||||
...
|
||||
]
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for ``ann_file`` and ``image``.
|
||||
ann_file (str): Conversation file path.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
file_backend = get_file_backend(self.data_root)
|
||||
conversation_path = file_backend.join_path(self.data_root,
|
||||
self.ann_file)
|
||||
conversation = mmengine.load(conversation_path)
|
||||
img_ids = {}
|
||||
n = 0
|
||||
for conv in conversation:
|
||||
img_id = conv['id']
|
||||
if img_id not in img_ids.keys():
|
||||
img_ids[img_id] = n
|
||||
n += 1
|
||||
|
||||
img_root = file_backend.join_path(self.data_root, 'image')
|
||||
data_list = []
|
||||
for conv in conversation:
|
||||
img_file = '{}.jpg'.format(conv['id'])
|
||||
chat_content = conv['conversation']
|
||||
lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
|
||||
data_info = {
|
||||
'image_id': img_ids[conv['id']],
|
||||
'img_path': file_backend.join_path(img_root, img_file),
|
||||
'chat_content': chat_content,
|
||||
'lang': lang,
|
||||
}
|
||||
|
||||
data_list.append(data_info)
|
||||
|
||||
return data_list
|
|
@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
|
|||
norm_pix_loss (bool): Whether or not normalize target.
|
||||
Defaults to False.
|
||||
patch_size (int): Patch size. Defaults to 16.
|
||||
in_channels (int): Number of input channels. Defaults to 3.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss: dict,
|
||||
norm_pix: bool = False,
|
||||
patch_size: int = 16) -> None:
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3) -> None:
|
||||
super().__init__()
|
||||
self.norm_pix = norm_pix
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.loss_module = MODELS.build(loss)
|
||||
|
||||
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule):
|
|||
|
||||
Args:
|
||||
imgs (torch.Tensor): A batch of images. The shape should
|
||||
be :math:`(B, 3, H, W)`.
|
||||
be :math:`(B, C, H, W)`.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Patchified images. The shape is
|
||||
:math:`(B, L, \text{patch_size}^2 \times 3)`.
|
||||
:math:`(B, L, \text{patch_size}^2 \times C)`.
|
||||
"""
|
||||
p = self.patch_size
|
||||
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
||||
|
||||
h = w = imgs.shape[2] // p
|
||||
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
|
||||
x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
|
||||
x = torch.einsum('nchpwq->nhwpqc', x)
|
||||
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
|
||||
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
|
||||
return x
|
||||
|
||||
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule):
|
|||
|
||||
Args:
|
||||
x (torch.Tensor): The shape is
|
||||
:math:`(B, L, \text{patch_size}^2 \times 3)`.
|
||||
:math:`(B, L, \text{patch_size}^2 \times C)`.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The shape is :math:`(B, 3, H, W)`.
|
||||
torch.Tensor: The shape is :math:`(B, C, H, W)`.
|
||||
"""
|
||||
p = self.patch_size
|
||||
h = w = int(x.shape[1]**.5)
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
|
||||
imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
|
||||
return imgs
|
||||
|
||||
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule):
|
|||
normalize the image according to ``norm_pix``.
|
||||
|
||||
Args:
|
||||
target (torch.Tensor): Image with the shape of B x 3 x H x W
|
||||
target (torch.Tensor): Image with the shape of B x C x H x W
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tokenized images with the shape of B x L x C
|
||||
|
|
|
@ -31,12 +31,12 @@ class MiniGPT4(BaseModel):
|
|||
True.
|
||||
num_query_token (int): Number of query tokens of Qformer. Defaults to
|
||||
32.
|
||||
prompt_template (str): Prompt template of the model. Defaults to
|
||||
'###Human: {} ###Assistant: '.
|
||||
raw_prompts (list): Prompts for training. Defaults to None.
|
||||
prompt_template (dict): Multi-language prompt template of the model. Defaults to dict([ ('en', '###Ask: {} ###Answer: '),
|
||||
('zh', '###问:{} ###答:')])
|
||||
raw_prompts (dict): Prompts for training. Defaults to dict().
|
||||
max_txt_len (int): Max token length while doing tokenization. Defaults
|
||||
to 32.
|
||||
end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
|
||||
end_sym (str): Ended symbol of the sequence. Defaults to '###'.
|
||||
generation_cfg (dict): The config of text generation. Defaults to
|
||||
dict().
|
||||
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
|
||||
|
@ -54,10 +54,12 @@ class MiniGPT4(BaseModel):
|
|||
freeze_vit: bool = True,
|
||||
freeze_q_former: bool = True,
|
||||
num_query_token: int = 32,
|
||||
prompt_template: str = '###Human: {} ###Assistant: ',
|
||||
raw_prompts: Optional[list] = None,
|
||||
prompt_template: dict = dict([('en',
|
||||
'###Ask: {} ###Answer: '),
|
||||
('zh', '###问:{} ###答:')]),
|
||||
raw_prompts: dict = dict(),
|
||||
max_txt_len: int = 32,
|
||||
end_sym: str = '\n',
|
||||
end_sym: str = '###',
|
||||
generation_cfg: dict = dict(),
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
|
@ -135,16 +137,23 @@ class MiniGPT4(BaseModel):
|
|||
self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]
|
||||
|
||||
# set prompts
|
||||
if raw_prompts is not None:
|
||||
filted_prompts = [
|
||||
raw_prompt for raw_prompt in raw_prompts
|
||||
self.en_prompt_list, self.zh_prompt_list = [], []
|
||||
if raw_prompts.get('en') is not None:
|
||||
en_filted_prompts = [
|
||||
raw_prompt for raw_prompt in raw_prompts['en']
|
||||
if '<ImageHere>' in raw_prompt
|
||||
]
|
||||
self.prompt_list = [
|
||||
prompt_template.format(p) for p in filted_prompts
|
||||
self.en_prompt_list = [
|
||||
prompt_template['en'].format(p) for p in en_filted_prompts
|
||||
]
|
||||
if raw_prompts.get('zh') is not None:
|
||||
zh_filted_prompts = [
|
||||
raw_prompt for raw_prompt in raw_prompts['zh']
|
||||
if '<ImageHere>' in raw_prompt
|
||||
]
|
||||
self.zh_prompt_list = [
|
||||
prompt_template['zh'].format(p) for p in zh_filted_prompts
|
||||
]
|
||||
else:
|
||||
self.prompt_list = []
|
||||
|
||||
# update generation configs
|
||||
self.generation_cfg = dict(
|
||||
|
@ -153,7 +162,7 @@ class MiniGPT4(BaseModel):
|
|||
do_sample=True,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
repetition_penalty=1.1,
|
||||
length_penalty=1.0,
|
||||
temperature=1.0)
|
||||
self.generation_cfg.update(**generation_cfg)
|
||||
|
@ -161,6 +170,10 @@ class MiniGPT4(BaseModel):
|
|||
if hasattr(self, 'register_load_state_dict_post_hook'):
|
||||
self.register_load_state_dict_post_hook(self._load_llama_proj_hook)
|
||||
|
||||
def half(self):
|
||||
self.llama_model = self.llama_model.half()
|
||||
return self
|
||||
|
||||
def encode_img(self,
|
||||
images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The function to encode the images."""
|
||||
|
@ -184,33 +197,39 @@ class MiniGPT4(BaseModel):
|
|||
return inputs_llama, atts_llama
|
||||
|
||||
def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
|
||||
prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The function to wrap the image and prompt.
|
||||
|
||||
Currently, the function only supports applying one prompt to all input
|
||||
images in the one batch.
|
||||
Make sure that len(prompt) == img_embeds.shape[0].
|
||||
|
||||
Args:
|
||||
img_embeds (torch.Tensor): The embedding of the input images.
|
||||
atts_img (torch.Tensor): Attention map of the image embeddings.
|
||||
prompt (str): The prompt of the batch data.
|
||||
prompt (List[str]): The prompt of the batch data.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
|
||||
"""
|
||||
if prompt:
|
||||
batch_size = img_embeds.shape[0]
|
||||
p_before, p_after = prompt.split('<ImageHere>')
|
||||
if len(prompt) > 0:
|
||||
p_before_list, p_after_list = [], []
|
||||
for pro in prompt:
|
||||
p_before, p_after = pro.split('<ImageHere>')
|
||||
p_before_list.append(p_before)
|
||||
p_after_list.append(p_after)
|
||||
p_before_tokens = self.llama_tokenizer(
|
||||
p_before, return_tensors='pt',
|
||||
p_before_list,
|
||||
return_tensors='pt',
|
||||
padding='longest',
|
||||
add_special_tokens=False).to(img_embeds.device)
|
||||
p_after_tokens = self.llama_tokenizer(
|
||||
p_after, return_tensors='pt',
|
||||
p_after_list,
|
||||
return_tensors='pt',
|
||||
padding='longest',
|
||||
add_special_tokens=False).to(img_embeds.device)
|
||||
p_before_embeds = self.llama_model.model.embed_tokens(
|
||||
p_before_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
p_before_tokens.input_ids)
|
||||
p_after_embeds = self.llama_model.model.embed_tokens(
|
||||
p_after_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
p_after_tokens.input_ids)
|
||||
wrapped_img_embeds = torch.cat(
|
||||
[p_before_embeds, img_embeds, p_after_embeds], dim=1)
|
||||
wrapped_atts_img = atts_img[:, :1].expand(
|
||||
|
@ -234,17 +253,22 @@ class MiniGPT4(BaseModel):
|
|||
"""
|
||||
img_embeds, atts_img = self.encode_img(images)
|
||||
|
||||
if self.task == 'caption' and self.prompt_list:
|
||||
prompt = random.choice(self.prompt_list)
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
|
||||
prompt)
|
||||
|
||||
self.llama_tokenizer.padding_side = 'right'
|
||||
|
||||
text = [t + self.end_sym for t in data_samples['text_input']]
|
||||
prompts, texts = [], []
|
||||
for t in data_samples:
|
||||
chat_content = t.chat_content
|
||||
split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
|
||||
prompt, text = chat_content.split(split_mark)
|
||||
prompt += split_mark
|
||||
text += self.end_sym
|
||||
prompts.append(prompt)
|
||||
texts.append(text)
|
||||
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
|
||||
|
||||
to_regress_tokens = self.llama_tokenizer(
|
||||
text,
|
||||
texts,
|
||||
return_tensors='pt',
|
||||
padding='longest',
|
||||
truncation=True,
|
||||
|
@ -295,10 +319,12 @@ class MiniGPT4(BaseModel):
|
|||
with torch.no_grad():
|
||||
img_embeds, atts_img = self.encode_img(images)
|
||||
|
||||
if self.task == 'caption' and self.prompt_list:
|
||||
prompt = random.choice(self.prompt_list)
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
|
||||
prompt)
|
||||
prompts = [
|
||||
random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
|
||||
and t.lang == 'zh' else random.choice(self.en_prompt_list)
|
||||
for t in data_samples
|
||||
]
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
|
||||
|
||||
batch_size = img_embeds.shape[0]
|
||||
bos = torch.ones(
|
||||
|
@ -336,7 +362,6 @@ class MiniGPT4(BaseModel):
|
|||
for output, data_sample in zip(outputs, data_samples):
|
||||
if self.task == 'caption':
|
||||
output = output.split('###')[0]
|
||||
output = output.split('Assistant:')[-1].strip()
|
||||
data_sample.pred_caption = output
|
||||
else:
|
||||
# raw output
|
||||
|
|
|
@ -1301,6 +1301,7 @@ class OFAEncoderDecoder(BaseModule, GenerationMixin):
|
|||
Defaults to an empty dict.
|
||||
init_cfg (dict, optional): The initialization config. Defaults to None.
|
||||
"""
|
||||
base_model_prefix = ''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -64,6 +64,7 @@ class iTPNHiViT(HiViT):
|
|||
layer_scale_init_value: float = 0.0,
|
||||
mask_ratio: float = 0.75,
|
||||
reconstruction_type: str = 'pixel',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
|
@ -80,7 +81,9 @@ class iTPNHiViT(HiViT):
|
|||
norm_cfg=norm_cfg,
|
||||
ape=ape,
|
||||
rpe=rpe,
|
||||
layer_scale_init_value=layer_scale_init_value)
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.pos_embed.requires_grad = False
|
||||
self.mask_ratio = mask_ratio
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved
|
||||
|
||||
__version__ = '1.0.2'
|
||||
__version__ = '1.1.0'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
# Modified from
|
||||
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
|
||||
import dataclasses
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
sep: str = '###'
|
||||
|
||||
def get_prompt(self):
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ': ' + message + self.sep
|
||||
else:
|
||||
ret += role + ':'
|
||||
return ret
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=[role for role in self.roles],
|
||||
messages=[[y for y in x] for x in self.messages],
|
||||
sep=self.sep,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
'system': self.system,
|
||||
'roles': self.roles,
|
||||
'messages': self.messages,
|
||||
'offset': self.offset,
|
||||
'sep': self.sep,
|
||||
}
|
||||
|
||||
|
||||
EN_CONV_VISION = Conversation(
|
||||
system='Give the following image. '
|
||||
'You will be able to see the image once I provide it to you. '
|
||||
'Please answer my questions in detail.',
|
||||
roles=['Ask', 'Answer'],
|
||||
messages=[],
|
||||
sep='###',
|
||||
)
|
||||
|
||||
ZH_CONV_VISION = Conversation(
|
||||
system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
|
||||
roles=['问', '答'],
|
||||
messages=[],
|
||||
sep='###',
|
||||
)
|
||||
|
||||
|
||||
class Chat:
|
||||
|
||||
def __init__(self, inferencer, device, is_half=False):
|
||||
self.device = device
|
||||
self.inferencer = inferencer
|
||||
self.model = inferencer.model
|
||||
self.is_half = is_half
|
||||
if is_half:
|
||||
self.model = self.model.half()
|
||||
self.model = self.model.to(device)
|
||||
self.max_length = 2000
|
||||
|
||||
def upload_img(self, image, conv, img_list):
|
||||
img = next(self.inferencer.preprocess([image]))
|
||||
img = self.model.data_preprocessor(img, False)['images']
|
||||
img = img.to(self.device)
|
||||
image_emb, _ = self.model.encode_img(img)
|
||||
img_list.append(image_emb)
|
||||
conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')
|
||||
|
||||
def get_context_emb(self, conv, img_list):
|
||||
prompt = conv.get_prompt()
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
seg_tokens = [
|
||||
self.model.llama_tokenizer(
|
||||
seg, return_tensors='pt',
|
||||
add_special_tokens=(i == 0)).to(self.device).input_ids
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
seg_embs = [
|
||||
self.model.llama_model.model.embed_tokens(seg_token)
|
||||
for seg_token in seg_tokens
|
||||
]
|
||||
mixed_embs = [
|
||||
emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
|
||||
] + [seg_embs[-1]]
|
||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||
return mixed_embs
|
||||
|
||||
def ask(self, text, conv):
|
||||
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
|
||||
0] and conv.messages[-1][1][-6:] == '</Img>':
|
||||
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
|
||||
else:
|
||||
conv.append_message(conv.roles[0], text)
|
||||
|
||||
def answer(self, conv, img_list, generation_cfg):
|
||||
conv.append_message(conv.roles[1], None)
|
||||
embs = self.get_context_emb(conv, img_list)
|
||||
cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
|
||||
if cur_max_len > self.max_length:
|
||||
print('Warning: The number of tokens in current conversation'
|
||||
'exceeds the max length. '
|
||||
'The model will not see the contexts outside the range.')
|
||||
begin_idx = max(0, cur_max_len - self.max_length)
|
||||
embs = embs[:, begin_idx:]
|
||||
if self.is_half:
|
||||
embs = embs.half()
|
||||
outputs = self.model.llama_model.generate(
|
||||
inputs_embeds=embs,
|
||||
eos_token_id=self.model.end_token_id,
|
||||
**generation_cfg)
|
||||
|
||||
output_token = outputs[0]
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
elif output_token[0] == 1:
|
||||
output_token = output_token[1:]
|
||||
output_text = self.model.llama_tokenizer.decode(
|
||||
output_token,
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True)
|
||||
output_text = output_text.split('###')[0]
|
||||
conv.messages[-1][1] = output_text
|
||||
return output_text
|
|
@ -0,0 +1,144 @@
|
|||
import argparse
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat
|
||||
|
||||
from mmpretrain import ImageCaptionInferencer
|
||||
|
||||
parser = argparse.ArgumentParser(description='MiniGPT4 demo')
|
||||
parser.add_argument(
|
||||
'cfg', type=str, help='config file for minigpt4 (absolute path)')
|
||||
parser.add_argument(
|
||||
'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)')
|
||||
args = parser.parse_args()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
devices = [
|
||||
torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
|
||||
]
|
||||
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
devices = [torch.device('mps')]
|
||||
else:
|
||||
devices = [torch.device('cpu')]
|
||||
|
||||
|
||||
def get_free_device():
|
||||
if hasattr(torch.cuda, 'mem_get_info'):
|
||||
free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
|
||||
select = max(zip(free, range(len(free))))[1]
|
||||
else:
|
||||
import random
|
||||
select = random.randint(0, len(devices) - 1)
|
||||
return devices[select]
|
||||
|
||||
|
||||
device = get_free_device()
|
||||
inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt)
|
||||
model = inferencer.model
|
||||
chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))
|
||||
|
||||
|
||||
def reset(chat_state, img_list):
|
||||
if chat_state is not None:
|
||||
chat_state.messages = []
|
||||
if img_list is not None:
|
||||
img_list = []
|
||||
return (None, gr.update(value=None, interactive=True),
|
||||
gr.update(
|
||||
value=None,
|
||||
placeholder='Please upload your image first',
|
||||
interactive=False),
|
||||
gr.update(value='Upload & Start Chat',
|
||||
interactive=True), chat_state, img_list,
|
||||
gr.update(value='Restart', interactive=False),
|
||||
gr.update(value='English', interactive=True))
|
||||
|
||||
|
||||
def upload_img(gr_img, language, chat_state):
|
||||
if gr_img is None:
|
||||
return (None,
|
||||
gr.update(
|
||||
placeholder='Please upload your image first',
|
||||
interactive=False),
|
||||
gr.update(value='Upload & Start Chat',
|
||||
interactive=True), chat_state, None,
|
||||
gr.update(value='Restart', interactive=False),
|
||||
gr.update(value='English', interactive=True))
|
||||
|
||||
if (language == 'English'):
|
||||
chat_state = EN_CONV_VISION.copy()
|
||||
else:
|
||||
chat_state = ZH_CONV_VISION.copy()
|
||||
img_list = []
|
||||
gr_img_array = np.asarray(gr_img)
|
||||
chat.upload_img(gr_img_array, chat_state, img_list)
|
||||
return (gr.update(interactive=False),
|
||||
gr.update(placeholder='Type and press Enter', interactive=True),
|
||||
gr.update(value='Start Chatting',
|
||||
interactive=False), chat_state, img_list,
|
||||
gr.update(value='Restart',
|
||||
interactive=True), gr.update(interactive=False))
|
||||
|
||||
|
||||
def ask(user_message, chatbot, chat_state):
|
||||
if (len(user_message) == 0):
|
||||
return gr.update(
|
||||
value=None,
|
||||
placeholder='Input should not be empty!',
|
||||
interactive=True), chatbot, chat_state
|
||||
chat.ask(user_message, chat_state)
|
||||
chatbot = chatbot + [[user_message, None]]
|
||||
return '', chatbot, chat_state
|
||||
|
||||
|
||||
def answer(chatbot, chat_state, img_list):
|
||||
llm_message = chat.answer(
|
||||
conv=chat_state,
|
||||
img_list=img_list,
|
||||
generation_cfg=model.generation_cfg)
|
||||
chatbot[-1][1] = llm_message
|
||||
return chatbot, chat_state, img_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
title = 'MMPretrain MiniGPT-4 Inference Demo'
|
||||
with gr.Blocks(analytics_enabled=False, title=title) as demo:
|
||||
gr.Markdown(f'# {title}')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
image = gr.Image(type='pil')
|
||||
language = gr.Dropdown(['English', 'Chinese'],
|
||||
label='Language',
|
||||
info='Select chatbot\'s language',
|
||||
value='English',
|
||||
interactive=True)
|
||||
upload_button = gr.Button(
|
||||
value='Upload & Start Chat', interactive=True)
|
||||
clear = gr.Button(value='Restart', interactive=False)
|
||||
|
||||
with gr.Column():
|
||||
chat_state = gr.State()
|
||||
img_list = gr.State()
|
||||
chatbot = gr.Chatbot(
|
||||
label='MiniGPT-4', min_width=320, height=600)
|
||||
text_input = gr.Textbox(
|
||||
label='User',
|
||||
placeholder='Please upload your image first',
|
||||
interactive=False)
|
||||
|
||||
upload_button.click(upload_img, [image, language, chat_state], [
|
||||
image, text_input, upload_button, chat_state, img_list, clear,
|
||||
language
|
||||
])
|
||||
text_input.submit(ask, [text_input, chatbot, chat_state],
|
||||
[text_input, chatbot, chat_state]).then(
|
||||
answer, [chatbot, chat_state, img_list],
|
||||
[chatbot, chat_state, img_list])
|
||||
clear.click(reset, [chat_state, img_list], [
|
||||
chatbot, image, text_input, upload_button, chat_state, img_list,
|
||||
clear, language
|
||||
])
|
||||
|
||||
demo.launch(share=True)
|
Binary file not shown.
After Width: | Height: | Size: 220 KiB |
|
@ -91,10 +91,6 @@ def merge_args(cfg, args):
|
|||
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp is True:
|
||||
optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
|
||||
assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
|
||||
'`--amp` is not supported custom optimizer wrapper type ' \
|
||||
f'`{optim_wrapper}.'
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
|
||||
|
||||
|
|
Loading…
Reference in New Issue