diff --git a/README.md b/README.md index dc5c6cde..5318df5b 100644 --- a/README.md +++ b/README.md @@ -86,13 +86,15 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351 ## What's new -🌟 v1.0.2 was released in 15/08/2023 +🌟 v1.2.0 was released in 04/01/2023 -Support [MFF](./configs/mff/) self-supervised algorithm and enhance the codebase. More details can be found in the [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html). +- Support LLaVA 1.5. +- Implement of RAM with a gradio interface. -🌟 v1.0.1 was released in 28/07/2023 +🌟 v1.1.0 was released in 12/10/2023 -Fix some bugs and enhance the codebase. Please refer to [changelog](https://mmpretrain.readthedocs.io/en/latest/notes/changelog.html) for more details. +- Support Mini-GPT4 training and provide a Chinese model (based on Baichuan-7B) +- Support zero-shot classification based on CLIP. 🌟 v1.0.0 was released in 04/07/2023 diff --git a/README_zh-CN.md b/README_zh-CN.md index 6820dd64..9ee8dffc 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -84,13 +84,15 @@ https://github.com/open-mmlab/mmpretrain/assets/26739999/e4dcd3a2-f895-4d1b-a351 ## 更新日志 -🌟 2023/8/15 发布了 v1.0.2 版本 +🌟 2024/01/04 发布了 v1.2.0 版本 -支持了 [MFF](./configs/mff/) 自监督算法,增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。 +- 支持了 LLaVA 1.5 +- 实现了一个 RAM 模型的 gradio 推理例程 -🌟 2023/7/28 发布了 v1.0.1 版本 +🌟 2023/10/12 发布了 v1.1.0 版本 -修复部分 bug 和增强算法库功能。细节请参考 [更新日志](https://mmpretrain.readthedocs.io/zh_CN/latest/notes/changelog.html)。 +- 支持 Mini-GPT4 训练并提供一个基于 Baichuan-7B 的中文模型 +- 支持基于 CLIP 的零样本分类。 🌟 2023/7/4 发布了 v1.0.0 版本 @@ -333,10 +335,10 @@ MMPreTrain 是一款由不同学校和公司共同贡献的开源项目。我们 ## 欢迎加入 OpenMMLab 社区 -扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 OpenMMLab 团队的 [官方交流 QQ 群](https://jq.qq.com/?_wv=1027&k=aCvMxdr3) 或联络 OpenMMLab 官方微信小助手 +扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),扫描下方微信二维码添加喵喵好友,进入 MMPretrain 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】 <div align="center"> -<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/xiaozhushou_weixin_qrcode.jpeg" height="400"/> +<img src="./resources/zhihu_qrcode.jpg" height="400"/> <img src="./resources/miaomiao_qrcode.jpg" height="400"/> </div> 我们会在 OpenMMLab 社区为大家 diff --git a/configs/llava/README.md b/configs/llava/README.md index 7aaf57d7..581abfe5 100644 --- a/configs/llava/README.md +++ b/configs/llava/README.md @@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct <!-- [TABS-BEGIN] --> -**Prepare the checkpoint** - -According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the below -script to download and get the merged the checkpoint. - -```shell -python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth -``` - **Use the model** ```python import torch from mmpretrain import get_model, inference_model -model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda') -out = inference_model(model, 'demo/cat-dog.png') +out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda') print(out) # {'pred_caption': 'In the image, there are two cats sitting on a blanket.'} ``` -**Test Command** - -Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). 
- -Test: - -```shell -python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH -``` - <!-- [TABS-END] --> ## Models and results ### Image Caption on COCO -| Model | Params (M) | BLEU-4 | CIDER | Config | Download | -| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: | -| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial | +| Model | Params (M) | Config | Download | +| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: | +| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) | +| `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) | +| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) | ## Citation diff --git a/configs/llava/llava-7b-v1.5_caption.py b/configs/llava/llava-7b-v1.5_caption.py new file mode 100644 index 00000000..371c9b5f --- /dev/null +++ b/configs/llava/llava-7b-v1.5_caption.py @@ -0,0 +1,76 @@ +_base_ = '../_base_/default_runtime.py' + +meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501 +image_size = 336 +prompt_tmpl = f'''{meta_prompt} User: <image> +Describe the image in detail. 
ASSISTANT:''' + +# model settings +model = dict( + type='Llava', + tokenizer=dict( + type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'), + vision_encoder=dict( + type='VisionTransformer', + arch='l', + patch_size=14, + img_size=image_size, + pre_norm=True, + norm_cfg=dict(type='LN', eps=1e-5), + layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), + final_norm=False, + out_type='raw', + pretrained='https://download.openmmlab.com/mmclassification/v0/clip/' + 'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth', + ), + mm_hidden_size=1024, + use_im_patch=False, + use_im_start_end=False, + mm_proj_depth=2, + lang_encoder=dict( + type='AutoModelForCausalLM', + name_or_path='huggyllama/llama-7b', + ), + task='caption', + prompt_tmpl=prompt_tmpl, + generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0), +) + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[122.770938, 116.7460125, 104.09373615], + std=[68.5005327, 66.6321579, 70.32316305], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(image_size, image_size), + interpolation='bicubic', + backend='pillow'), + dict(type='PackInputs', meta_keys=['image_id']), +] + +test_dataloader = dict( + batch_size=8, + num_workers=5, + dataset=dict( + type='COCOCaption', + data_root='data/coco', + ann_file='annotations/coco_karpathy_val.json', + pipeline=test_pipeline, + ), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) + +test_evaluator = dict( + type='COCOCaption', + ann_file='data/coco/annotations/coco_karpathy_val_gt.json', +) + +# schedule settings +test_cfg = dict() diff --git a/configs/llava/llava-7b-v1.5_vqa.py b/configs/llava/llava-7b-v1.5_vqa.py new file mode 100644 index 00000000..5cb9812c --- /dev/null +++ b/configs/llava/llava-7b-v1.5_vqa.py @@ -0,0 +1,76 @@ +_base_ = '../_base_/default_runtime.py' + +meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 
# noqa: E501 +image_size = 336 +prompt_tmpl = f'''{meta_prompt} User: <image> +{{question}} ASSISTANT:''' + +# model settings +model = dict( + type='Llava', + tokenizer=dict( + type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'), + vision_encoder=dict( + type='VisionTransformer', + arch='l', + patch_size=14, + img_size=image_size, + pre_norm=True, + norm_cfg=dict(type='LN', eps=1e-5), + layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), + final_norm=False, + out_type='raw', + pretrained='https://download.openmmlab.com/mmclassification/v0/clip/' + 'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth', + ), + mm_hidden_size=1024, + use_im_patch=False, + use_im_start_end=False, + mm_proj_depth=2, + lang_encoder=dict( + type='AutoModelForCausalLM', + name_or_path='huggyllama/llama-7b', + ), + task='vqa', + prompt_tmpl=prompt_tmpl, + generation_cfg=dict(max_new_tokens=100), +) + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[122.770938, 116.7460125, 104.09373615], + std=[68.5005327, 66.6321579, 70.32316305], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(image_size, image_size), + interpolation='bicubic', + backend='pillow'), + dict(type='PackInputs', meta_keys=['image_id', 'question']), +] + +test_dataloader = dict( + batch_size=8, + num_workers=5, + dataset=dict( + type='COCOCaption', + data_root='data/coco', + ann_file='annotations/coco_karpathy_val.json', + pipeline=test_pipeline, + ), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) + +test_evaluator = dict( + type='COCOCaption', + ann_file='data/coco/annotations/coco_karpathy_val_gt.json', +) + +# schedule settings +test_cfg = dict() diff --git a/configs/llava/llava-7b-v1_caption.py b/configs/llava/llava-7b-v1_caption.py index f7558bed..92e2d1fb 100644 --- a/configs/llava/llava-7b-v1_caption.py +++ b/configs/llava/llava-7b-v1_caption.py @@ -1,16 +1,9 @@ _base_ = '../_base_/default_runtime.py' meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.' # noqa: E501 -im_patch_token = '<im_patch>' -patch_size = 14 image_size = 224 -num_patches = (image_size // patch_size)**2 -caption_prompt = ' '.join([ - meta_prompt, - 'User: a photo of\n', - im_patch_token * num_patches, - 'ASSISTANT:', -]) +prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end> +Describe the image in detail. 
ASSISTANT:''' # model settings model = dict( @@ -22,6 +15,7 @@ model = dict( type='VisionTransformer', arch='l', patch_size=14, + img_size=image_size, pre_norm=True, norm_cfg=dict(type='LN', eps=1e-5), layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), @@ -32,15 +26,16 @@ model = dict( 'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'), ), mm_hidden_size=1024, - use_im_start_end=False, - use_mm_proj=True, + use_im_patch=False, + use_im_start_end=True, + mm_proj_depth=1, lang_encoder=dict( type='AutoModelForCausalLM', name_or_path='huggyllama/llama-7b', ), task='caption', - prompt_tmpl=caption_prompt, - generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0), + prompt_tmpl=prompt_tmpl, + generation_cfg=dict(max_new_tokens=50), ) # data settings diff --git a/configs/llava/metafile.yml b/configs/llava/metafile.yml index 2b3cfc4d..406a214c 100644 --- a/configs/llava/metafile.yml +++ b/configs/llava/metafile.yml @@ -21,5 +21,31 @@ Models: Metrics: BLEU-4: null CIDER: null - Weights: null + Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth Config: configs/llava/llava-7b-v1_caption.py + - Name: llava-7b-v1.5_caption + Metadata: + FLOPs: null + Parameters: 7062900736 + In Collection: LLaVA + Results: + - Task: Image Caption + Dataset: COCO + Metrics: + BLEU-4: null + CIDER: null + Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth + Config: configs/llava/llava-7b-v1.5_caption.py + - Name: llava-7b-v1.5_vqa + Metadata: + FLOPs: null + Parameters: 7062900736 + In Collection: LLaVA + Results: + - Task: Visual Question Answering + Dataset: COCO + Metrics: + BLEU-4: null + CIDER: null + Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth + Config: configs/llava/llava-7b-v1.5_vqa.py diff --git a/configs/minigpt4/README.md b/configs/minigpt4/README.md index 01e53954..23666fc9 100644 --- a/configs/minigpt4/README.md +++ b/configs/minigpt4/README.md @@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI ### Pretrained models -| Model | Params (M) | Flops (G) | Config | Download | -| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: | -| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) | +| Model | Params (M) | Flops (G) | Config | Download | +| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: | +| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) | +| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) | *Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). 
The config files of these models are only for inference. We haven't reproduce the training results.* diff --git a/configs/minigpt4/metafile.yml b/configs/minigpt4/metafile.yml index a7879d98..f70cc9ba 100644 --- a/configs/minigpt4/metafile.yml +++ b/configs/minigpt4/metafile.yml @@ -19,8 +19,19 @@ Models: - Task: Image Caption Dataset: COCO Metrics: null - Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth + Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py Converted From: Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main + - Name: minigpt-4_baichuan-7b_caption + Metadata: + FLOPs: null + Parameters: 8094769024 + In Collection: MiniGPT4 + Results: + - Task: Image Caption + Dataset: COCO + Metrics: null + Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth + Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py diff --git a/configs/minigpt4/minigpt-4_baichuan-7b_caption.py b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py new file mode 100644 index 00000000..7e610a09 --- /dev/null +++ b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py @@ -0,0 +1,190 @@ +_base_ = [ + '../_base_/default_runtime.py', +] + +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[122.770938, 116.7460125, 104.09373615], + std=[68.5005327, 66.6321579, 70.32316305], + to_rgb=True, +) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(224, 224), + interpolation='bicubic', + backend='pillow'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='CleanCaption', + keys='chat_content', + remove_chars='', + lowercase=False), + dict( + type='PackInputs', + algorithm_keys=['chat_content', 'lang'], + meta_keys=['image_id']), +] + +train_dataloader = dict( + batch_size=2, + num_workers=4, + dataset=dict( + type='MiniGPT4Dataset', + data_root='YOUR_DATA_DIRECTORY', + ann_file='YOUR_DATA_FILE', + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + drop_last=False, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(224, 224), + interpolation='bicubic', + backend='pillow'), + dict(type='PackInputs', meta_keys=['image_id']), +] + +test_evaluator = dict( + type='COCOCaption', + ann_file='data/coco/annotations/coco_karpathy_val_gt.json', +) + +test_dataloader = dict( + batch_size=1, + dataset=dict( + type='COCOCaption', + data_root='data/coco', + ann_file='annotations/coco_karpathy_val.json', + pipeline=test_pipeline)) + +# model settings +model = dict( + type='MiniGPT4', + vision_encoder=dict( + type='BEiTViT', + # eva-g without the final layer + arch=dict( + embed_dims=1408, + num_layers=39, + num_heads=16, + feedforward_channels=6144, + ), + img_size=224, + patch_size=14, + layer_scale_init_value=0.0, + frozen_stages=39, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + final_norm=False, + use_shared_rel_pos_bias=False, + out_type='raw', + pretrained= # noqa + 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa + ), + q_former_model=dict( + type='Qformer', + model_style='bert-base-uncased', + vision_model_width=1408, + add_cross_attention=True, 
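+        # BLIP-2 style Q-Former: 32 learnable query tokens attend to the
+        # frozen EVA-ViT features through cross-attention layers inserted
+        # every 2 BERT blocks (see the two options below).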
+ cross_attention_freq=2, + num_query_token=32, + pretrained= # noqa + 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa + ), + lang_encoder=dict( + type='AutoModelForCausalLM', + name_or_path='baichuan-inc/baichuan-7B', + trust_remote_code=True), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='baichuan-inc/baichuan-7B', + trust_remote_code=True), + task='caption', + prompt_template=dict([('en', '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]), + raw_prompts=dict([ + ('en', [('<Img><ImageHere></Img> ' + 'Describe this image in detail.'), + ('<Img><ImageHere></Img> ' + 'Take a look at this image and describe what you notice.'), + ('<Img><ImageHere></Img> ' + 'Please provide a detailed description of the picture.'), + ('<Img><ImageHere></Img> ' + 'Could you describe the contents of this image for me?')]), + ('zh', [('<Img><ImageHere></Img> ' + '详细描述这张图片。'), ('<Img><ImageHere></Img> ' + '浏览这张图片并描述你注意到什么。'), + ('<Img><ImageHere></Img> ' + '请对这张图片进行详细的描述。'), + ('<Img><ImageHere></Img> ' + '你能为我描述这张图片的内容吗?')]) + ]), + max_txt_len=160, + end_sym='###') + +strategy = dict( + type='DeepSpeedStrategy', + fp16=dict( + enabled=True, + auto_cast=False, + fp16_master_weights_and_grads=False, + loss_scale=0, + loss_scale_window=1000, + hysteresis=1, + min_loss_scale=1, + initial_scale_power=16, + ), + inputs_to_half=[0], + zero_optimization=dict( + stage=2, + allgather_partitions=True, + allgather_bucket_size=2e8, + reduce_scatter=True, + reduce_bucket_size='auto', + overlap_comm=True, + contiguous_gradients=True, + ), +) + +# schedule settings +optim_wrapper = dict( + type='DeepSpeedOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05)) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-3 / 500, + by_epoch=False, + begin=0, + end=500, + ), + dict( + type='CosineAnnealingLR', + eta_min=2e-4, + by_epoch=False, + begin=500, + ), +] + +train_cfg = dict(by_epoch=True, max_epochs=6) +test_cfg = dict() + +runner_type = 'FlexibleRunner' + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=1, + by_epoch=True, + save_last=True, + max_keep_ckpts=1, + )) diff --git a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py index 704760af..f468e2d8 100644 --- a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py +++ b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py @@ -55,13 +55,25 @@ model = dict( type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'), tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'), task='caption', - prompt_template='###Human: {} ###Assistant: ', - raw_prompts=[ - '<Img><ImageHere></Img> Describe this image in detail.', - '<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa - '<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa - '<Img><ImageHere></Img> Could you describe the contents of this image for me?', # noqa - ], + prompt_template=dict([('en', '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]), + raw_prompts=dict([ + ('en', [('<Img><ImageHere></Img> ' + 'Describe this image in detail.'), + ('<Img><ImageHere></Img> ' + 'Take a look at this image and describe what you notice.'), + ('<Img><ImageHere></Img> ' + 'Please provide a detailed description of the picture.'), + ('<Img><ImageHere></Img> ' + 'Could you describe the contents of this image for me?')]), + ('zh', [('<Img><ImageHere></Img> ' + '详细描述这张图片。'), 
('<Img><ImageHere></Img> ' + '浏览这张图片并描述你注意到什么。'), + ('<Img><ImageHere></Img> ' + '请对这张图片进行详细的描述。'), + ('<Img><ImageHere></Img> ' + '你能为我描述这张图片的内容吗?')]) + ]), max_txt_len=160, end_sym='###') diff --git a/dataset-index.yml b/dataset-index.yml index ecf7f5b5..40ca6206 100644 --- a/dataset-index.yml +++ b/dataset-index.yml @@ -1,11 +1,11 @@ imagenet1k: - dataset: ImageNet-1K + dataset: OpenDataLab/ImageNet-1K download_root: data data_root: data/imagenet script: tools/dataset_converters/odl_imagenet1k_preprocess.sh cub: - dataset: CUB-200-2011 + dataset: OpenDataLab/CUB-200-2011 download_root: data data_root: data/CUB_200_2011 script: tools/dataset_converters/odl_cub_preprocess.sh diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile index bff871b7..c50c4e8e 100644 --- a/docker/serve/Dockerfile +++ b/docker/serve/Dockerfile @@ -1,9 +1,9 @@ -ARG PYTORCH="1.12.1" -ARG CUDA="11.3" +ARG PYTORCH="2.0.1" +ARG CUDA="11.7" ARG CUDNN="8" FROM pytorch/torchserve:latest-gpu -ARG MMPRE="1.0.2" +ARG MMPRE="1.2.0" ENV PYTHONUNBUFFERED TRUE diff --git a/docs/en/notes/changelog.md b/docs/en/notes/changelog.md index f84d691a..499ed24f 100644 --- a/docs/en/notes/changelog.md +++ b/docs/en/notes/changelog.md @@ -1,5 +1,38 @@ # Changelog (MMPreTrain) +## v1.2.0(04/01/2024) + +### New Features + +- [Feature] Support LLaVA 1.5 ([#1853](https://github.com/open-mmlab/mmpretrain/pull/1853)) +- [Feature] Implement of RAM with a gradio interface. ([#1802](https://github.com/open-mmlab/mmpretrain/pull/1802)) + +### Bug Fix + +- [Fix] Fix resize mix argument bug. + +## v1.1.0(12/10/2023) + +### New Features + +- [Feature] Implement of Zero-Shot CLIP Classifier ([#1737](https://github.com/open-mmlab/mmpretrain/pull/1737)) +- [Feature] Add minigpt4 gradio demo and training script. ([#1758](https://github.com/open-mmlab/mmpretrain/pull/1758)) + +### Improvements + +- [Config] New Version of config Adapting MobileNet Algorithm ([#1774](https://github.com/open-mmlab/mmpretrain/pull/1774)) +- [Config] Support DINO self-supervised learning in project ([#1756](https://github.com/open-mmlab/mmpretrain/pull/1756)) +- [Config] New Version of config Adapting Swin Transformer Algorithm ([#1780](https://github.com/open-mmlab/mmpretrain/pull/1780)) +- [Enhance] Add iTPN Supports for Non-three channel image ([#1735](https://github.com/open-mmlab/mmpretrain/pull/1735)) +- [Docs] Update dataset download script from opendatalab to openXlab ([#1765](https://github.com/open-mmlab/mmpretrain/pull/1765)) +- [Docs] Update COCO-Retrieval dataset docs. ([#1806](https://github.com/open-mmlab/mmpretrain/pull/1806)) + +### Bug Fix + +- Update `train.py` to compat with new config. +- Update OFA module to compat with the latest huggingface. +- Fix pipeline bug in ImageRetrievalInferencer. + ## v1.0.2(15/08/2023) ### New Features diff --git a/docs/en/notes/faq.md b/docs/en/notes/faq.md index 9f78a048..da45841b 100644 --- a/docs/en/notes/faq.md +++ b/docs/en/notes/faq.md @@ -16,7 +16,8 @@ and make sure you fill in all required information in the template. 
| MMPretrain version | MMEngine version | MMCV version | | :----------------: | :---------------: | :--------------: | - | 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | + | 1.2.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | + | 1.1.1 | mmengine >= 0.8.3 | mmcv >= 2.0.0 | | 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 | | 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 | | 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 | diff --git a/docs/en/user_guides/dataset_prepare.md b/docs/en/user_guides/dataset_prepare.md index 7421be22..17ec229b 100644 --- a/docs/en/user_guides/dataset_prepare.md +++ b/docs/en/user_guides/dataset_prepare.md @@ -144,15 +144,15 @@ ImageNet has multiple versions, but the most commonly used one is [ILSVRC 2012]( ````{group-tab} Download by MIM -MIM supports downloading from [OpenDataLab](https://opendatalab.com/) and preprocessing ImageNet dataset with one command line. +MIM supports downloading from [OpenXlab](https://openxlab.org.cn/datasets) and preprocessing ImageNet dataset with one command line. -_You need to register an account at [OpenDataLab official website](https://opendatalab.com/) and login by CLI._ +_You need to register an account at [OpenXlab official website](https://openxlab.org.cn/datasets) and login by CLI._ ```Bash -# install OpenDataLab CLI tools -pip install -U opendatalab -# log in OpenDataLab, register if you don't have an account. -odl login +# install OpenXlab CLI tools +pip install -U openxlab +# log in OpenXLab +openxlab login # download and preprocess by MIM, better to execute in $MMPreTrain directory. mim download mmpretrain --dataset imagenet1k ``` @@ -278,7 +278,7 @@ test_dataloader = val_dataloader | [`SUN397`](mmpretrain.datasets.SUN397)(data_root[, split, pipeline, ...]) | ["train", "test"] | [SUN397](https://vision.princeton.edu/projects/2010/SUN/) Dataset. | | [`VOC`](mmpretrain.datasets.VOC)(data_root[, image_set_path, pipeline, ...]) | ["train", "val", "tranval", "test"] | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) Dataset. | -Some dataset homepage links may be unavailable, and you can download datasets through [OpenDataLab](https://opendatalab.com/), such as [Stanford Cars](https://opendatalab.com/Stanford_Cars/download). +Some dataset homepage links may be unavailable, and you can download datasets through [OpenXLab](https://openxlab.org.cn/datasets), such as [Stanford Cars](https://openxlab.org.cn/datasets/OpenDataLab/Stanford_Cars). 
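+The same MIM workflow applies to the other datasets indexed in this repository's `dataset-index.yml`, such as the `cub` entry for CUB-200-2011. A minimal sketch, assuming MIM is installed and you have already logged in with the OpenXLab CLI as shown above:
+
+```Bash
+# download and preprocess CUB-200-2011 into data/CUB_200_2011;
+# the `cub` key and the target directory are defined in dataset-index.yml
+mim download mmpretrain --dataset cub
+```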
## Supported Multi-modality Datasets diff --git a/docs/zh_CN/notes/faq.md b/docs/zh_CN/notes/faq.md index efd2ff5e..9e94cd8b 100644 --- a/docs/zh_CN/notes/faq.md +++ b/docs/zh_CN/notes/faq.md @@ -13,7 +13,8 @@ | MMPretrain 版本 | MMEngine 版本 | MMCV 版本 | | :-------------: | :---------------: | :--------------: | - | 1.0.2 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | + | 1.2.0 (main) | mmengine >= 0.8.3 | mmcv >= 2.0.0 | + | 1.1.1 | mmengine >= 0.8.3 | mmcv >= 2.0.0 | | 1.0.0 | mmengine >= 0.8.0 | mmcv >= 2.0.0 | | 1.0.0rc8 | mmengine >= 0.7.1 | mmcv >= 2.0.0rc4 | | 1.0.0rc7 | mmengine >= 0.5.0 | mmcv >= 2.0.0rc4 | diff --git a/docs/zh_CN/user_guides/dataset_prepare.md b/docs/zh_CN/user_guides/dataset_prepare.md index 59a0d0af..aa1e1fde 100644 --- a/docs/zh_CN/user_guides/dataset_prepare.md +++ b/docs/zh_CN/user_guides/dataset_prepare.md @@ -142,15 +142,15 @@ ImageNet 有多个版本,但最常用的一个是 [ILSVRC 2012](http://www.ima ````{group-tab} MIM 下载 -MIM支持使用一条命令行从 [OpenDataLab](https://opendatalab.com/) 下载并预处理 ImageNet 数据集。 +MIM支持使用一条命令行从 [OpenXLab](https://openxlab.org.cn/datasets?lang=zh-CN) 下载并预处理 ImageNet 数据集。 -_需要在 [OpenDataLab 官网](https://opendatalab.com/) 注册账号并命令行登录_。 +_需要在 [OpenXLab 官网](https://openxlab.org.cn/datasets?lang=zh-CN) 注册账号并命令行登录_。 ```Bash -# 安装opendatalab库 -pip install -U opendatalab -# 登录到 OpenDataLab, 如果还没有注册,请到官网注册一个 -odl login +# 安装 OpenXLab CLI 工具 +pip install -U openxlab +# 登录 OpenXLab +openxlab login # 使用 MIM 下载数据集, 最好在 $MMPreTrain 目录执行 mim download mmpretrain --dataset imagenet1k ``` @@ -276,7 +276,7 @@ test_dataloader = val_dataloader | [`SUN397`](mmpretrain.datasets.SUN397)(data_root[, split, pipeline, ...]) | ["train", "test"] | [SUN397](https://vision.princeton.edu/projects/2010/SUN/) 数据集 | | [`VOC`](mmpretrain.datasets.VOC)(data_root[, image_set_path, pipeline, ...]) | ["train", "val", "tranval", "test"] | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) 数据集 | -有些数据集主页链接可能已经失效,您可以通过[OpenDataLab](https://opendatalab.com/)下载数据集,例如 [Stanford Cars](https://opendatalab.com/Stanford_Cars/download)数据集。 +有些数据集主页链接可能已经失效,您可以通过[OpenXLab](https://openxlab.org.cn/datasets?lang=zh-CN)下载数据集,例如 [Stanford Cars](https://openxlab.org.cn/datasets/OpenDataLab/Stanford_Cars)数据集。 ## OpenMMLab 2.0 标准数据集 diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py index 0b0f573f..66866a86 100644 --- a/mmpretrain/__init__.py +++ b/mmpretrain/__init__.py @@ -7,7 +7,7 @@ from .apis import * # noqa: F401, F403 from .version import __version__ mmcv_minimum_version = '2.0.0' -mmcv_maximum_version = '2.1.0' +mmcv_maximum_version = '2.4.0' mmcv_version = digit_version(mmcv.__version__) mmengine_minimum_version = '0.8.3' diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py index deae1de7..27919b20 100644 --- a/mmpretrain/apis/image_retrieval.py +++ b/mmpretrain/apis/image_retrieval.py @@ -108,6 +108,7 @@ class ImageRetrievalInferencer(BaseInferencer): # A config of dataset from mmpretrain.registry import DATASETS test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + prototype.setdefault('pipeline', test_pipeline) dataset = DATASETS.build(prototype) dataloader = build_dataloader(dataset) elif isinstance(prototype, DataLoader): diff --git a/mmpretrain/configs/_base_/datasets/cub_bs8_384.py b/mmpretrain/configs/_base_/datasets/cub_bs8_384.py new file mode 100644 index 00000000..b193bf83 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cub_bs8_384.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile, + PackInputs, RandomCrop, RandomFlip, Resize) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CUB +data_preprocessor = dict( + num_classes=200, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=510), + dict(type=RandomCrop, crop_size=384), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=510), + dict(type=CenterCrop, crop_size=384), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=8, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/CUB_200_2011', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=8, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/CUB_200_2011', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py new file mode 100644 index 00000000..9690ff84 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_256.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
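+# ImageNet-1k dataloaders for 256x256 inputs (used by the SwinV2 256px
+# configs): RandomResizedCrop + RandAugment + RandomErasing for training,
+# ResizeEdge(292) + CenterCrop(256) for evaluation, 64 images per GPU.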
+from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandAugment, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=256, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=RandAugment, + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), + dict( + type=RandomErasing, + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=292, # ( 256 / 224 * 256 ) + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=256), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/models/swin_transformer_base.py b/mmpretrain/configs/_base_/models/swin_transformer_base.py new file mode 100644 index 00000000..c73c254d --- /dev/null +++ b/mmpretrain/configs/_base_/models/swin_transformer_base.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, SwinTransformer) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=SwinTransformer, + arch='base', + img_size=384, + stage_cfgs=dict(block_cfgs=dict(window_size=12))), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py b/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py new file mode 100644 index 00000000..c7566b5e --- /dev/null +++ b/mmpretrain/configs/_base_/models/swin_transformer_v2_base.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
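+# Shared SwinTransformerV2-base image classifier: global-average-pooling neck
+# and a 1000-class linear head with label smoothing. The swinv2_* configs
+# override arch, img_size, window_size and drop_path_rate as needed.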
+from mmpretrain.models import (GlobalAveragePooling, ImageClassifier, + LabelSmoothLoss, LinearClsHead, + SwinTransformerV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict( + type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1024, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'), + cal_acc=False)) diff --git a/mmpretrain/configs/_base_/schedules/cub_bs64.py b/mmpretrain/configs/_base_/schedules/cub_bs64.py new file mode 100644 index 00000000..2ca40bfe --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cub_bs64.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import CosineAnnealingLR, LinearLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict( + type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True)) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type=LinearLR, + start_factor=0.01, + by_epoch=True, + begin=0, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type=CosineAnnealingLR, + T_max=95, + by_epoch=True, + begin=5, + end=100, + ) +] + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=64) diff --git a/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py new file mode 100644 index 00000000..09af3d01 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(img_size=224, drop_path_rate=0.5, stage_cfgs=None), + head=dict( + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py new file mode 100644 index 00000000..aacdc327 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_base_16xb64_in1k_384px.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py new file mode 100644 index 00000000..b8fc2793 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='large', img_size=224, stage_cfgs=None), + head=dict(in_channels=1536), +) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py new file mode 100644 index 00000000..9a449aa6 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_16xb64_in1k_384px.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='large'), + head=dict(in_channels=1536), +) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py b/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py new file mode 100644 index 00000000..2003cd3a --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_large_8xb8_cub_384px.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
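+# Fine-tune SwinTransformer-large at 384x384 on CUB-200-2011 starting from an
+# ImageNet-21k pre-trained checkpoint: AdamW with a small learning rate (5e-6),
+# no weight decay on norms, biases and position tables, and gradient clipping.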
+from mmengine.config import read_base +from mmengine.hooks import CheckpointHook, LoggerHook +from mmengine.model import PretrainedInit +from torch.optim.adamw import AdamW + +from mmpretrain.models import ImageClassifier + +with read_base(): + from .._base_.datasets.cub_bs8_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.cub_bs64 import * + +# model settings +checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa + +model.update( + backbone=dict( + arch='large', + init_cfg=dict( + type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')), + head=dict(num_classes=200, in_channels=1536)) + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type=AdamW, + lr=5e-6, + weight_decay=0.0005, + eps=1e-8, + betas=(0.9, 0.999)), + paramwise_cfg=dict( + norm_decay_mult=0.0, + bias_decay_mult=0.0, + custom_keys={ + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }), + clip_grad=dict(max_norm=5.0), +) + +default_hooks = dict( + # log every 20 intervals + logger=dict(type=LoggerHook, interval=20), + # save last three checkpoints + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3)) diff --git a/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py new file mode 100644 index 00000000..59792528 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_small_16xb64_in1k.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='small', img_size=224, drop_path_rate=0.3, stage_cfgs=None), + head=dict( + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py b/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py new file mode 100644 index 00000000..733e1ef0 --- /dev/null +++ b/mmpretrain/configs/swin_transformer/swin_tiny_16xb64_in1k.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
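+# SwinTransformer-tiny on ImageNet-1k at 224x224: overrides the shared
+# Swin-base model with arch='tiny', drop_path_rate=0.2 and a 768-channel head,
+# trained with Mixup/CutMix and label smoothing.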
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_224 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='tiny', img_size=224, drop_path_rate=0.2, stage_cfgs=None), + head=dict( + in_channels=768, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type=LabelSmoothLoss, + label_smooth_val=0.1, + mode='original', + loss_weight=0), + topk=None, + cal_acc=False), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# schedule settings +optim_wrapper = dict(clip_grad=dict(max_norm=5.0)) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py new file mode 100644 index 00000000..1ecc4363 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w12_8xb128_in21k_192px.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]), + head=dict(num_classes=21841), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +data_preprocessor = dict(num_classes=21841) + +_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop +_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge +_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py new file mode 100644 index 00000000..103afb42 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
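+# SwinTransformerV2-base trained from scratch on ImageNet-1k at 256x256 with
+# 16x16 windows (8x8 in the last stage), drop_path_rate=0.5 and Mixup/CutMix.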
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=256, drop_path_rate=0.5, window_size=[16, 16, 16, 8]), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py new file mode 100644 index 00000000..6588f50f --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w16_in21k_pre_16xb64_in1k_256px.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=256, + window_size=[16, 16, 16, 8], + pretrained_window_sizes=[12, 12, 12, 6]), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py new file mode 100644 index 00000000..118c085e --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w24_in21k_pre_16xb64_in1k_384px.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + window_size=[24, 24, 24, 12], pretrained_window_sizes=[12, 12, 12, 6])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py new file mode 100644 index 00000000..d40144cb --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_base_w8_16xb64_in1k_256px.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
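+# SwinTransformerV2-base on ImageNet-1k at 256x256 keeping the backbone's
+# default 8x8 window size, with drop_path_rate=0.5 and Mixup/CutMix.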
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(img_size=256, drop_path_rate=0.5), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py new file mode 100644 index 00000000..1ecc4363 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w12_8xb128_in21k_192px.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet21k_bs128 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]), + head=dict(num_classes=21841), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) + +# dataset settings +data_preprocessor = dict(num_classes=21841) + +_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop +_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge +_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py new file mode 100644 index 00000000..0a1b59df --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w16_in21k_pre_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Only for evaluation +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
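+# Evaluation-only config for SwinTransformerV2-large pre-trained on
+# ImageNet-21k: 256x256 input, 16x16 windows resized from a pretrained window
+# size of 12, a 1536-channel head and top-1/top-5 accuracy.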
+from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='large', + img_size=256, + window_size=[16, 16, 16, 8], + pretrained_window_sizes=[12, 12, 12, 6]), + head=dict( + in_channels=1536, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py new file mode 100644 index 00000000..b20bcead --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_large_w24_in21k_pre_16xb64_in1k_384px.py @@ -0,0 +1,24 @@ +# Only for evaluation +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +from mmpretrain.models import CrossEntropyLoss + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_384 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='large', + img_size=384, + window_size=[24, 24, 24, 12], + pretrained_window_sizes=[12, 12, 12, 6]), + head=dict( + in_channels=1536, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5))) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py new file mode 100644 index 00000000..dfd15c31 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w16_16xb64_in1k_256px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='small', + img_size=256, + drop_path_rate=0.3, + window_size=[16, 16, 16, 8]), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py new file mode 100644 index 00000000..bfec3466 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_small_w8_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. 
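+# SwinTransformerV2-small on ImageNet-1k at 256x256 with the backbone's default
+# 8x8 windows, drop_path_rate=0.3, a 768-channel head and Mixup/CutMix.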
+from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='small', img_size=256, drop_path_rate=0.3), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py new file mode 100644 index 00000000..f2fa1609 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w16_16xb64_in1k_256px.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict( + arch='tiny', + img_size=256, + drop_path_rate=0.2, + window_size=[16, 16, 16, 8]), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict( + augments=[dict(type=Mixup, alpha=0.8), + dict(type=CutMix, alpha=1.0)])) diff --git a/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py new file mode 100644 index 00000000..8cca2b38 --- /dev/null +++ b/mmpretrain/configs/swin_transformer_v2/swinv2_tiny_w8_16xb64_in1k_256px.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base +from mmengine.model import ConstantInit, TruncNormalInit + +from mmpretrain.models import CutMix, Mixup + +with read_base(): + from .._base_.datasets.imagenet_bs64_swin_256 import * + from .._base_.default_runtime import * + from .._base_.models.swin_transformer_v2_base import * + from .._base_.schedules.imagenet_bs1024_adamw_swin import * + +# model settings +model.update( + backbone=dict(arch='tiny', img_size=256, drop_path_rate=0.2), + head=dict(in_channels=768), + init_cfg=[ + dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.), + dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.) 
+    ],
+    train_cfg=dict(
+        augments=[dict(type=Mixup, alpha=0.8),
+                  dict(type=CutMix, alpha=1.0)]))
diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index 29753d70..e621e157 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -43,6 +43,7 @@ if WITH_MULTIMODAL:
     from .gqa_dataset import GQA
     from .iconqa import IconQA
     from .infographic_vqa import InfographicVQA
+    from .minigpt4_dataset import MiniGPT4Dataset
     from .nocaps import NoCaps
     from .ocr_vqa import OCRVQA
     from .refcoco import RefCOCO
@@ -56,5 +57,6 @@ if WITH_MULTIMODAL:
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
         'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
         'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
-        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
+        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
+        'MiniGPT4Dataset'
     ])
diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py
index 60d1586a..be8a0bcb 100644
--- a/mmpretrain/datasets/coco_retrieval.py
+++ b/mmpretrain/datasets/coco_retrieval.py
@@ -1,18 +1,45 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
+import os.path as osp
 from collections import OrderedDict
-from typing import List
+from os import PathLike
+from typing import List, Sequence, Union
 
 from mmengine import get_file_backend
 
-from mmpretrain.registry import DATASETS
+from mmpretrain.registry import DATASETS, TRANSFORMS
 from .base_dataset import BaseDataset
 
 
+def expanduser(data_prefix):
+    if isinstance(data_prefix, (str, PathLike)):
+        return osp.expanduser(data_prefix)
+    else:
+        return data_prefix
+
+
 @DATASETS.register_module()
 class COCORetrieval(BaseDataset):
     """COCO Retrieval dataset.
 
+    COCO (Common Objects in Context): The COCO dataset contains more than
+    330K images, each of which has approximately 5 descriptive annotations.
+    This dataset was released in collaboration between Microsoft and
+    Carnegie Mellon University.
+
+    COCO_2014 dataset directory: ::
+
+        COCO_2014
+        ├── val2014
+        ├── train2014
+        ├── annotations
+        │   ├── instances_train2014.json
+        │   ├── instances_val2014.json
+        │   ├── person_keypoints_train2014.json
+        │   ├── person_keypoints_val2014.json
+        │   ├── captions_train2014.json
+        │   └── captions_val2014.json
+
    Args:
        ann_file (str): Annotation file path.
        test_mode (bool): Whether dataset is used for evaluation. This will
@@ -23,8 +50,52 @@ class COCORetrieval(BaseDataset):
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ + Examples: + >>> from mmpretrain.datasets import COCORetrieval + >>> train_dataset=COCORetrieval(data_root='coco2014/') + >>> train_dataset + Dataset COCORetrieval + Number of samples: 414113 + Annotation file: /coco2014/annotations/captions_train2014.json + Prefix of images: /coco2014/ + >>> from mmpretrain.datasets import COCORetrieval + >>> val_dataset = COCORetrieval(data_root='coco2014/') + >>> val_dataset + Dataset COCORetrieval + Number of samples: 202654 + Annotation file: /coco2014/annotations/captions_val2014.json + Prefix of images: /coco2014/ """ + def __init__(self, + ann_file: str, + test_mode: bool = False, + data_prefix: Union[str, dict] = '', + data_root: str = '', + pipeline: Sequence = (), + **kwargs): + + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + pipeline=transforms, + ann_file=ann_file, + **kwargs, + ) + def load_data_list(self) -> List[dict]: """Load data list.""" # get file backend diff --git a/mmpretrain/datasets/minigpt4_dataset.py b/mmpretrain/datasets/minigpt4_dataset.py new file mode 100644 index 00000000..e14e5c35 --- /dev/null +++ b/mmpretrain/datasets/minigpt4_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset +from mmengine.fileio import get_file_backend + +from mmpretrain.registry import DATASETS + + +@DATASETS.register_module() +class MiniGPT4Dataset(BaseDataset): + """Dataset for training MiniGPT4. + + MiniGPT4 dataset directory: + + minigpt4_dataset + ├── image + │ ├── id0.jpg + │ │── id1.jpg + │ │── id2.jpg + │ └── ... + └── conversation_data.json + + The structure of conversation_data.json: + + [ + // English data + { + "id": str(id0), + "conversation": "###Ask: <Img><ImageHere></Img> [Ask content] + ###Answer: [Answer content]" + }, + + // Chinese data + { + "id": str(id1), + "conversation": "###问:<Img><ImageHere></Img> [Ask content] + ###答:[Answer content]" + }, + + ... + ] + + Args: + data_root (str): The root directory for ``ann_file`` and ``image``. + ann_file (str): Conversation file path. + **kwargs: Other keyword arguments in :class:`BaseDataset`. 
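+
+    Examples:
+        A minimal sketch that follows the directory layout above (the paths
+        are placeholders for your own data)::
+
+            >>> from mmpretrain.datasets import MiniGPT4Dataset
+            >>> dataset = MiniGPT4Dataset(
+            ...     data_root='minigpt4_dataset',
+            ...     ann_file='conversation_data.json')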
+ """ + + def load_data_list(self) -> List[dict]: + file_backend = get_file_backend(self.data_root) + conversation_path = file_backend.join_path(self.data_root, + self.ann_file) + conversation = mmengine.load(conversation_path) + img_ids = {} + n = 0 + for conv in conversation: + img_id = conv['id'] + if img_id not in img_ids.keys(): + img_ids[img_id] = n + n += 1 + + img_root = file_backend.join_path(self.data_root, 'image') + data_list = [] + for conv in conversation: + img_file = '{}.jpg'.format(conv['id']) + chat_content = conv['conversation'] + lang = 'en' if chat_content.startswith('###Ask: ') else 'zh' + data_info = { + 'image_id': img_ids[conv['id']], + 'img_path': file_backend.join_path(img_root, img_file), + 'chat_content': chat_content, + 'lang': lang, + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py index 1a5366d1..b76ecedd 100644 --- a/mmpretrain/models/heads/mae_head.py +++ b/mmpretrain/models/heads/mae_head.py @@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule): norm_pix_loss (bool): Whether or not normalize target. Defaults to False. patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. """ def __init__(self, loss: dict, norm_pix: bool = False, - patch_size: int = 16) -> None: + patch_size: int = 16, + in_channels: int = 3) -> None: super().__init__() self.norm_pix = norm_pix self.patch_size = patch_size + self.in_channels = in_channels self.loss_module = MODELS.build(loss) def patchify(self, imgs: torch.Tensor) -> torch.Tensor: @@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule): Args: imgs (torch.Tensor): A batch of images. The shape should - be :math:`(B, 3, H, W)`. + be :math:`(B, C, H, W)`. Returns: torch.Tensor: Patchified images. The shape is - :math:`(B, L, \text{patch_size}^2 \times 3)`. + :math:`(B, L, \text{patch_size}^2 \times C)`. """ p = self.patch_size assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 h = w = imgs.shape[2] // p - x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p)) x = torch.einsum('nchpwq->nhwpqc', x) - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels)) return x def unpatchify(self, x: torch.Tensor) -> torch.Tensor: @@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule): Args: x (torch.Tensor): The shape is - :math:`(B, L, \text{patch_size}^2 \times 3)`. + :math:`(B, L, \text{patch_size}^2 \times C)`. Returns: - torch.Tensor: The shape is :math:`(B, 3, H, W)`. + torch.Tensor: The shape is :math:`(B, C, H, W)`. """ p = self.patch_size h = w = int(x.shape[1]**.5) assert h * w == x.shape[1] - x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p)) return imgs def construct_target(self, target: torch.Tensor) -> torch.Tensor: @@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule): normalize the image according to ``norm_pix``. 
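+
+        For example, with the default ``patch_size=16`` and ``in_channels=3``,
+        a target of shape ``(B, 3, 224, 224)`` becomes ``(B, 196, 768)``:
+        ``196 = (224 / 16)**2`` patches, each flattened to
+        ``768 = 16 * 16 * 3`` values.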
Args: - target (torch.Tensor): Image with the shape of B x 3 x H x W + target (torch.Tensor): Image with the shape of B x C x H x W Returns: torch.Tensor: Tokenized images with the shape of B x L x C diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index 73645f0f..e68504c6 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -11,6 +11,7 @@ if WITH_MULTIMODAL: from .minigpt4 import * # noqa: F401, F403 from .ofa import * # noqa: F401, F403 from .otter import * # noqa: F401, F403 + from .ram import * # noqa: F401, F403 else: from mmpretrain.registry import MODELS from mmpretrain.utils.dependency import register_multimodal_placeholder @@ -19,5 +20,5 @@ else: 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', - 'CLIPZeroShot' + 'CLIPZeroShot', 'RAM', 'RAMNormal', 'RAMOpenset' ], MODELS) diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py index 103d8129..f829b092 100644 --- a/mmpretrain/models/multimodal/llava/llava.py +++ b/mmpretrain/models/multimodal/llava/llava.py @@ -24,8 +24,8 @@ class Llava(BaseModel): use_im_start_end (bool): Whether to use the im_start and im_end tokens mm_vision_select_layer (int): The index from vision encoder output. Defaults to -1. - use_mm_proj (bool): Whether to enable multi-modal projection. - Defaults to True. + mm_proj_depth (int): The number of linear layers for multi-modal + projection. Defaults to 1. load_lang_pretrained (bool): Whether to load the pretrained model of language encoder. Defaults to False. generation_cfg (dict): The extra generation config, accept the keyword @@ -51,9 +51,10 @@ class Llava(BaseModel): mm_hidden_size: int, prompt_tmpl: str, task: str = 'caption', + use_im_patch: bool = True, use_im_start_end: bool = False, mm_vision_select_layer: int = -1, - use_mm_proj: bool = True, + mm_proj_depth: int = 1, generation_cfg: dict = dict(), load_lang_pretrained: bool = False, data_preprocessor: Optional[dict] = None, @@ -75,7 +76,9 @@ class Llava(BaseModel): # init tokenizer self.tokenizer = TOKENIZER.build(tokenizer) # add Llava special tokens to the tokenizer - self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True) + if use_im_patch: + self.tokenizer.add_tokens([self.im_patch_token], + special_tokens=True) if use_im_start_end: self.tokenizer.add_tokens([self.im_start_token, self.im_end_token], special_tokens=True) @@ -108,14 +111,12 @@ class Llava(BaseModel): vision_encoder=vision_encoder, lang_encoder=lang_encoder, mm_hidden_size=mm_hidden_size, - use_mm_proj=use_mm_proj, + mm_proj_depth=mm_proj_depth, use_im_start_end=use_im_start_end, im_start_token=self.tokenizer.convert_tokens_to_ids( self.im_start_token), im_end_token=self.tokenizer.convert_tokens_to_ids( self.im_end_token), - im_patch_token=self.tokenizer.convert_tokens_to_ids( - self.im_patch_token), mm_vision_select_layer=mm_vision_select_layer) self.generation_cfg = generation_cfg @@ -207,16 +208,24 @@ class Llava(BaseModel): Returns: List[DataSample]: Return list of data samples. 
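+
+        Note:
+            The prompt template is split on the literal ``<image>`` marker:
+            each text piece is tokenized separately (without special tokens)
+            and the placeholder id ``-200`` is inserted in its place, e.g.
+            ``'User: <image> Describe the image.'`` roughly becomes
+            ``tokenize('User: ') + [-200] + tokenize(' Describe the image.')``.
+            ``forward_vision_tower`` later splices the projected image
+            features into the embedding sequence at that position.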
""" - prompts = [] + tokens = [] for sample in data_samples: - final_prompt = self.prompt_tmpl.format(**sample.to_dict()) - prompts.append(final_prompt) + prompt = self.prompt_tmpl.format(**sample.to_dict()) + input_ids = [] + while '<image>' in prompt: + prefix, _, prompt = prompt.partition('<image>') + input_ids.extend( + self.tokenizer(prefix, add_special_tokens=False).input_ids) + input_ids.append(-200) + if prompt: + input_ids.extend( + self.tokenizer(prompt, add_special_tokens=False).input_ids) + tokens.append(dict(input_ids=input_ids)) self.tokenizer.padding_side = 'left' - input_text = self.tokenizer( - prompts, + input_text = self.tokenizer.pad( + tokens, padding='longest', - truncation=True, return_tensors='pt', max_length=2000, ).to(device) diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py index afa6eefa..fa3c6bbb 100644 --- a/mmpretrain/models/multimodal/llava/modules.py +++ b/mmpretrain/models/multimodal/llava/modules.py @@ -31,10 +31,10 @@ class LlavaLlamaForCausalLM(PreTrainedModel): lang_encoder, mm_hidden_size, use_im_start_end=True, - use_mm_proj=True, + mm_proj_depth=1, im_start_token: Optional[int] = None, im_end_token: Optional[int] = None, - im_patch_token: Optional[int] = None, + im_token_index: int = -200, mm_vision_select_layer: int = -1): super().__init__(lang_encoder.config) self.vision_tower = vision_encoder @@ -43,16 +43,26 @@ class LlavaLlamaForCausalLM(PreTrainedModel): self.use_im_start_end = use_im_start_end self.im_start_token = im_start_token self.im_end_token = im_end_token - self.im_patch_token = im_patch_token self.mm_hidden_size = mm_hidden_size self.mm_vision_select_layer = mm_vision_select_layer + self.im_token_index = im_token_index self.lang_hidden_size = lang_encoder.config.hidden_size - if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'): + if mm_proj_depth == 1: + # Llava V1 mm_projector = nn.Linear(self.mm_hidden_size, self.lang_hidden_size) self.lang_encoder.model.add_module('mm_projector', mm_projector) - elif not use_mm_proj: + elif mm_proj_depth > 1: + # Llava V1.5 + modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)] + for _ in range(1, mm_proj_depth): + modules.append(nn.GELU()) + modules.append( + nn.Linear(self.lang_hidden_size, self.lang_hidden_size)) + mm_projector = nn.Sequential(*modules) + self.lang_encoder.model.add_module('mm_projector', mm_projector) + elif mm_proj_depth == 0: self.lang_encoder.model.add_module('mm_projector', nn.Identity()) self.post_init() @@ -80,16 +90,12 @@ class LlavaLlamaForCausalLM(PreTrainedModel): return_dict if return_dict is not None else self.config.use_return_dict) - # decoder outputs consists of - # (dec_features, layer_state, dec_hidden, dec_attn) - if inputs_embeds is None: - inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids) - - inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds, - images) + (input_ids, attention_mask, past_key_values, inputs_embeds, + labels) = self.forward_vision_tower(input_ids, attention_mask, + past_key_values, labels, images) return self.lang_encoder( - input_ids=None, + input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -127,106 +133,93 @@ class LlavaLlamaForCausalLM(PreTrainedModel): def forward_vision_tower( self, input_ids: torch.LongTensor, - inputs_embeds: torch.FloatTensor, - images: Union[torch.FloatTensor, list, None] = None, + attention_mask: torch.LongTensor, + past_key_values: 
torch.FloatTensor, + labels: torch.LongTensor, + images: Union[torch.FloatTensor, None] = None, ): - if self.use_im_start_end: - assert self.im_start_token is not None - assert self.im_end_token is not None - if images is not None: - assert self.im_patch_token is not None - - if self.vision_tower is None or images is None or ( - input_ids.shape[1] == 1 and not self.training): - return inputs_embeds + if self.vision_tower is None or images is None or input_ids.shape[ + 1] == 1: + if (past_key_values is not None and self.vision_tower is not None + and images is not None and input_ids.shape[1] == 1): + attention_mask = torch.ones( + (attention_mask.shape[0], + past_key_values[-1][-1].shape[-2] + 1), + dtype=attention_mask.dtype, + device=attention_mask.device) + return input_ids, attention_mask, past_key_values, None, labels with torch.no_grad(): - if isinstance(images, (list, tuple)): - # variable length images - image_features = [] - for image in images: - feats = self.vision_tower(image.unsqueeze(0)) - image_feature = feats[self.mm_vision_select_layer][:, 1:] - image_features.append(image_feature) - else: - feats = self.vision_tower(images) - image_features = feats[self.mm_vision_select_layer][:, 1:] + # TODO: support variable number of images (single now) + feats = self.vision_tower(images) + image_features = feats[-1][:, 1:] - mm_projector = self.lang_encoder.model.mm_projector - if isinstance(images, (list, tuple)): - image_features = [ - mm_projector(image_feature)[0] - for image_feature in image_features - ] - else: - image_features = mm_projector(image_features) - - dummy_image_features = torch.zeros( - 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) - dummy_image_features = mm_projector(dummy_image_features) + image_features = self.lang_encoder.model.mm_projector(image_features) new_input_embeds = [] - cur_image_idx = 0 - for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): - if (cur_input_ids != self.im_patch_token).all(): - # multimodal LLM, but the current sample is not multimodal - cur_input_embeds = cur_input_embeds + ( - 0. 
* dummy_image_features).sum() - new_input_embeds.append(cur_input_embeds) - cur_image_idx += 1 - continue - if self.use_im_start_end: - cur_image_features = image_features[cur_image_idx] - num_patches = cur_image_features.shape[0] - if (cur_input_ids == self.im_start_token).sum() != ( - cur_input_ids == self.im_end_token).sum(): - raise ValueError('The number of image start tokens and ' - 'image end tokens should be the same.') - image_start_tokens = torch.where( - cur_input_ids == self.im_start_token)[0] - for image_start_token_pos in image_start_tokens: - cur_image_features = image_features[cur_image_idx].to( - device=cur_input_embeds.device) - num_patches = cur_image_features.shape[0] - if cur_input_ids[image_start_token_pos + num_patches + - 1] != self.im_end_token: - raise ValueError('The image end token should follow ' - 'the image start token.') - cur_new_input_embeds = torch.cat( - (cur_input_embeds[:image_start_token_pos + 1], - cur_image_features, - cur_input_embeds[image_start_token_pos + num_patches + - 1:]), - dim=0) - cur_image_idx += 1 - new_input_embeds.append(cur_new_input_embeds) - else: - cur_image_features = image_features[cur_image_idx] - num_patches = cur_image_features.shape[0] - if (cur_input_ids == self.im_patch_token).sum() != num_patches: - print(f'Debug: num_patches: {num_patches}') - raise ValueError( - 'The number of image patch tokens should ' - 'be the same as the number of image patches.') - masked_indices = torch.where( - cur_input_ids == self.im_patch_token)[0] - mask_index_start = masked_indices[0] - if (masked_indices != torch.arange( - mask_index_start, - mask_index_start + num_patches, - device=masked_indices.device, - dtype=masked_indices.dtype)).any(): - raise ValueError( - 'The image patch tokens should be consecutive.') - cur_new_input_embeds = torch.cat( - (cur_input_embeds[:mask_index_start], cur_image_features, - cur_input_embeds[mask_index_start + num_patches:]), - dim=0) - new_input_embeds.append(cur_new_input_embeds) - cur_image_idx += 1 - inputs_embeds = torch.stack(new_input_embeds, dim=0) + new_labels = [] if labels is not None else None + new_attn_mask = [] if attention_mask is not None else None + for batch_idx, cur_input_ids in enumerate(input_ids): + cur_img = image_features[batch_idx] - return inputs_embeds + if (cur_input_ids != self.im_token_index).all(): + # multimodal LLM, but the current sample is not multimodal + new_input_embeds.append(self.embed_tokens(cur_input_ids)) + if labels is not None: + new_labels.append(labels[batch_idx]) + if attention_mask is not None: + new_attn_mask.append(attention_mask[batch_idx]) + continue + + img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0] + if self.use_im_start_end: + cur_new_input_embeds = torch.cat( + [ + self.embed_tokens(cur_input_ids[:img_idx - 1]), + self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]), + cur_img, + self.embed_tokens( + cur_input_ids[img_idx + 1:img_idx + 2]), + self.embed_tokens(cur_input_ids[img_idx + 2:]), + ], + dim=0, + ) + else: + cur_new_input_embeds = torch.cat( + [ + self.embed_tokens(cur_input_ids[:img_idx]), + cur_img, + self.embed_tokens(cur_input_ids[img_idx + 1:]), + ], + dim=0, + ) + new_input_embeds.append(cur_new_input_embeds) + + if labels is not None: + cur_new_labels = torch.cat([ + labels[batch_idx, :img_idx], + labels.new_full((cur_img.size(0), ), -100), + labels[batch_idx, img_idx + 1:], + ], + dim=0) + new_labels.append(cur_new_labels) + + if attention_mask is not None: + cur_attn_mask = torch.cat([ + 
attention_mask[batch_idx, :img_idx], + attention_mask.new_full((cur_img.size(0), ), True), + attention_mask[batch_idx, img_idx + 1:], + ], + dim=0) + new_attn_mask.append(cur_attn_mask) + + inputs_embeds = torch.stack(new_input_embeds, dim=0) + if labels is not None: + labels = torch.stack(new_labels, dim=0) + if attention_mask is not None: + attention_mask = torch.stack(new_attn_mask, dim=0) + + return None, attention_mask, past_key_values, inputs_embeds, labels @staticmethod def _reorder_cache(past_key_values, beam_idx): @@ -236,3 +229,6 @@ class LlavaLlamaForCausalLM(PreTrainedModel): past_state.index_select(0, beam_idx) for past_state in layer_past), ) return reordered_past + + def embed_tokens(self, input_ids): + return self.lang_encoder.model.embed_tokens(input_ids) diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py index eccbb27e..d25d0b6b 100644 --- a/mmpretrain/models/multimodal/minigpt4/minigpt4.py +++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py @@ -31,12 +31,12 @@ class MiniGPT4(BaseModel): True. num_query_token (int): Number of query tokens of Qformer. Defaults to 32. - prompt_template (str): Prompt template of the model. Defaults to - '###Human: {} ###Assistant: '. - raw_prompts (list): Prompts for training. Defaults to None. + prompt_template (dict): Multi-language prompt template of the model. Defaults to dict([ ('en', '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]) + raw_prompts (dict): Prompts for training. Defaults to dict(). max_txt_len (int): Max token length while doing tokenization. Defaults to 32. - end_sym (str): Ended symbol of the sequence. Defaults to '\\n'. + end_sym (str): Ended symbol of the sequence. Defaults to '###'. generation_cfg (dict): The config of text generation. Defaults to dict(). 
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for @@ -54,10 +54,12 @@ class MiniGPT4(BaseModel): freeze_vit: bool = True, freeze_q_former: bool = True, num_query_token: int = 32, - prompt_template: str = '###Human: {} ###Assistant: ', - raw_prompts: Optional[list] = None, + prompt_template: dict = dict([('en', + '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]), + raw_prompts: dict = dict(), max_txt_len: int = 32, - end_sym: str = '\n', + end_sym: str = '###', generation_cfg: dict = dict(), data_preprocessor: Optional[dict] = None, init_cfg: Optional[dict] = None): @@ -135,16 +137,23 @@ class MiniGPT4(BaseModel): self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1] # set prompts - if raw_prompts is not None: - filted_prompts = [ - raw_prompt for raw_prompt in raw_prompts + self.en_prompt_list, self.zh_prompt_list = [], [] + if raw_prompts.get('en') is not None: + en_filted_prompts = [ + raw_prompt for raw_prompt in raw_prompts['en'] if '<ImageHere>' in raw_prompt ] - self.prompt_list = [ - prompt_template.format(p) for p in filted_prompts + self.en_prompt_list = [ + prompt_template['en'].format(p) for p in en_filted_prompts + ] + if raw_prompts.get('zh') is not None: + zh_filted_prompts = [ + raw_prompt for raw_prompt in raw_prompts['zh'] + if '<ImageHere>' in raw_prompt + ] + self.zh_prompt_list = [ + prompt_template['zh'].format(p) for p in zh_filted_prompts ] - else: - self.prompt_list = [] # update generation configs self.generation_cfg = dict( @@ -153,7 +162,7 @@ class MiniGPT4(BaseModel): do_sample=True, min_length=1, top_p=0.9, - repetition_penalty=1.0, + repetition_penalty=1.1, length_penalty=1.0, temperature=1.0) self.generation_cfg.update(**generation_cfg) @@ -161,6 +170,10 @@ class MiniGPT4(BaseModel): if hasattr(self, 'register_load_state_dict_post_hook'): self.register_load_state_dict_post_hook(self._load_llama_proj_hook) + def half(self): + self.llama_model = self.llama_model.half() + return self + def encode_img(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """The function to encode the images.""" @@ -184,33 +197,39 @@ class MiniGPT4(BaseModel): return inputs_llama, atts_llama def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor, - prompt: str) -> Tuple[torch.Tensor, torch.Tensor]: + prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: """The function to wrap the image and prompt. - Currently, the function only supports applying one prompt to all input - images in the one batch. + Make sure that len(prompt) == img_embeds.shape[0]. Args: img_embeds (torch.Tensor): The embedding of the input images. atts_img (torch.Tensor): Attention map of the image embeddings. - prompt (str): The prompt of the batch data. + prompt (List[str]): The prompt of the batch data. Returns: Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map. 
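+
+        Note:
+            Each image ``i`` in the batch is wrapped with its own
+            ``prompt[i]``: the text before and after ``<ImageHere>`` is
+            batch-tokenized with ``padding='longest'`` and its embeddings are
+            concatenated on either side of the image embeddings.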
""" - if prompt: - batch_size = img_embeds.shape[0] - p_before, p_after = prompt.split('<ImageHere>') + if len(prompt) > 0: + p_before_list, p_after_list = [], [] + for pro in prompt: + p_before, p_after = pro.split('<ImageHere>') + p_before_list.append(p_before) + p_after_list.append(p_after) p_before_tokens = self.llama_tokenizer( - p_before, return_tensors='pt', + p_before_list, + return_tensors='pt', + padding='longest', add_special_tokens=False).to(img_embeds.device) p_after_tokens = self.llama_tokenizer( - p_after, return_tensors='pt', + p_after_list, + return_tensors='pt', + padding='longest', add_special_tokens=False).to(img_embeds.device) p_before_embeds = self.llama_model.model.embed_tokens( - p_before_tokens.input_ids).expand(batch_size, -1, -1) + p_before_tokens.input_ids) p_after_embeds = self.llama_model.model.embed_tokens( - p_after_tokens.input_ids).expand(batch_size, -1, -1) + p_after_tokens.input_ids) wrapped_img_embeds = torch.cat( [p_before_embeds, img_embeds, p_after_embeds], dim=1) wrapped_atts_img = atts_img[:, :1].expand( @@ -234,17 +253,22 @@ class MiniGPT4(BaseModel): """ img_embeds, atts_img = self.encode_img(images) - if self.task == 'caption' and self.prompt_list: - prompt = random.choice(self.prompt_list) - img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, - prompt) - self.llama_tokenizer.padding_side = 'right' - text = [t + self.end_sym for t in data_samples['text_input']] + prompts, texts = [], [] + for t in data_samples: + chat_content = t.chat_content + split_mark = '###Answer: ' if t.lang == 'en' else '###答:' + prompt, text = chat_content.split(split_mark) + prompt += split_mark + text += self.end_sym + prompts.append(prompt) + texts.append(text) + + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts) to_regress_tokens = self.llama_tokenizer( - text, + texts, return_tensors='pt', padding='longest', truncation=True, @@ -295,10 +319,12 @@ class MiniGPT4(BaseModel): with torch.no_grad(): img_embeds, atts_img = self.encode_img(images) - if self.task == 'caption' and self.prompt_list: - prompt = random.choice(self.prompt_list) - img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, - prompt) + prompts = [ + random.choice(self.zh_prompt_list) if hasattr(t, 'lang') + and t.lang == 'zh' else random.choice(self.en_prompt_list) + for t in data_samples + ] + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts) batch_size = img_embeds.shape[0] bos = torch.ones( @@ -336,7 +362,6 @@ class MiniGPT4(BaseModel): for output, data_sample in zip(outputs, data_samples): if self.task == 'caption': output = output.split('###')[0] - output = output.split('Assistant:')[-1].strip() data_sample.pred_caption = output else: # raw output diff --git a/mmpretrain/models/multimodal/ofa/ofa_modules.py b/mmpretrain/models/multimodal/ofa/ofa_modules.py index 1c79049b..ef5c8533 100644 --- a/mmpretrain/models/multimodal/ofa/ofa_modules.py +++ b/mmpretrain/models/multimodal/ofa/ofa_modules.py @@ -1301,6 +1301,7 @@ class OFAEncoderDecoder(BaseModule, GenerationMixin): Defaults to an empty dict. init_cfg (dict, optional): The initialization config. Defaults to None. """ + base_model_prefix = '' def __init__( self, diff --git a/mmpretrain/models/multimodal/ram/__init__.py b/mmpretrain/models/multimodal/ram/__init__.py new file mode 100644 index 00000000..35619d88 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .ram import RAM, RAMNormal, RAMOpenset + +__all__ = ['RAM', 'RAMNormal', 'RAMOpenset'] diff --git a/mmpretrain/models/multimodal/ram/bert.py b/mmpretrain/models/multimodal/ram/bert.py new file mode 100644 index 00000000..f54b2ce8 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/bert.py @@ -0,0 +1,1197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modify from: +# https://github.com/xinyu1205/recognize-anything/blob/main/ram/models/bert.py + +import math +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BertEmbeddings_nopos(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + # self.position_embeddings = nn.Embedding( + # config.max_position_embeddings, config.hidden_size) + '''self.LayerNorm is not snake-cased to stick with + TensorFlow model variable name and be able to load''' + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + # self.register_buffer("position_ids", + # torch.arange(config.max_position_embeddings).expand((1, -1))) + # self.position_embedding_type = \ + # getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] # noqa: F841 + + # if position_ids is None: + # position_ids = self.position_ids[:, \ + # past_key_values_length : seq_length + \ + # past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + # if self.position_embedding_type == "absolute": + # position_embeddings = self.position_embeddings(position_ids) + # # print('add position_embeddings!!!!') + # embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with + # TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) 
+ self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous + # in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + # print('add position_embeddings!!!!') + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and \ + not hasattr(config, 'embedding_size'): + raise ValueError('''The hidden size (%d) is not a multiple of + the number of attention heads (%d)''' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * \ + self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended 
to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # compatible with higher versions of transformers + if key_layer.shape[0] > query_layer.shape[0]: + key_layer = key_layer[:query_layer.shape[0], :, :, :] + attention_mask = attention_mask[:query_layer.shape[0], :, :] + value_layer = value_layer[:query_layer.shape[0], :, :, :] + + # Take the dot product between "query" and "key" + # to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for + # all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * \ + self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = 
self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + + if mode == 'tagging': + + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be given + for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + present_key_value = cross_attention_outputs[-1] + + else: + # decoder uni-directional self-attention + # cached key/values tuple is at positions 1,2 + self_attn_past_key_value = \ + (past_key_value[:2] + if past_key_value is not None else None) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == 'multimodal': + assert encoder_hidden_states is not None, \ + '''encoder_hidden_states must be + given for cross-attention layers''' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) 
if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn('''`use_cache=True` is incompatible with + gradient checkpointing. Setting `use_cache=False`...''' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that + # the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version + # which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: + dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, + zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, + with a the same dtype as :obj:`attention_mask.dtype`. 
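+
+        Example (illustrative; ``model`` stands for any ``BertModel``
+        instance)::
+
+            >>> mask = torch.ones(1, 4)
+            >>> ext = model.get_extended_attention_mask(
+            ...     mask, (1, 4), mask.device, is_decoder=True)
+            >>> ext.shape   # broadcastable over heads, causal over positions
+            torch.Size([1, 1, 4, 4])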
+ """ + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it + # broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask + # in addition to the padding mask + # - if the model is an encoder, make the mask + # broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type + # with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + '''Wrong shape for input_ids (shape {}) or attention_mask + (shape {})'''.format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 + # for positions we want to attend and 0.0 + # for masked positions, this operation will + # create a tensor which is 0.0 for positions + # we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores + # before the softmax, this is effectively + # the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as + a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length : + obj:`config.n_layers` with each tuple having 4 tensors of shape : + obj:`(batch_size, num_heads, sequence_length - 1, + embed_size_per_head)`): + Contains precomputed key and value hidden states of the + attention blocks. 
Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value + states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + + if is_decoder: + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache) + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError('''You cannot specify both + input_ids and inputs_embeds at the same time''') + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError('''You have to specify either + input_ids or inputs_embeds or encoder_embeds''') + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to + # make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = \ + (self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder)) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states[0].size()) + else: + encoder_batch_size, encoder_sequence_length, _ = \ + (encoder_hidden_states.size()) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape + # [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape + # [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + 
output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer + of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token + indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj: + `(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right + language modeling loss (next word prediction). + Indices should be in + ``[-100, 0, ..., config.vocab_size]`` + (see ``input_ids`` docstring) Tokens with indices set to + ``-100`` are ignored (masked), the loss is only computed + for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj: + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally + input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to + this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj: + `(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states + are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import (BertTokenizer, + BertLMHeadModel, BertConfig) + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained( + 'bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", + return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict + if return_dict is not None else self.config.use_return_dict) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + # sequence_output.shape torch.Size([85, 30, 768]) + # prediction_scores.shape torch.Size([85, 30, 30524]) + # labels.shape torch.Size([85, 30]) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift + # prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, + # the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/mmpretrain/models/multimodal/ram/config/__init__.py b/mmpretrain/models/multimodal/ram/config/__init__.py new file mode 100644 index 00000000..ef101fec --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
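For readers following the new `bert.py`, the decoder mask construction in `get_extended_attention_mask` is easier to trace in isolation. The snippet below is a minimal standalone sketch (plain PyTorch, not the repository's API) of the same idea: a 2-D padding mask is combined with a lower-triangular causal mask and turned into the additive form that is added to the raw attention scores; it indexes `causal_mask[:, None, :, :]` so the product broadcasts to `(batch, 1, seq, seq)`.

```python
import torch


def extend_decoder_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Expand a (batch, seq) padding mask (1 = keep, 0 = pad) into the
    additive (batch, 1, seq, seq) mask consumed by a causal decoder."""
    batch_size, seq_length = attention_mask.shape
    seq_ids = torch.arange(seq_length, device=attention_mask.device)
    # position i may attend to positions j <= i -> (batch, seq, seq)
    causal_mask = seq_ids[None, None, :].repeat(
        batch_size, seq_length, 1) <= seq_ids[None, :, None]
    causal_mask = causal_mask.to(attention_mask.dtype)
    # combine causal and padding masks, broadcast over attention heads
    extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    # 0.0 where attention is allowed, -10000.0 where it is masked
    return (1.0 - extended) * -10000.0


mask = torch.tensor([[1, 1, 1, 0]])       # last position is padding
print(extend_decoder_mask(mask)[0, 0])    # 4x4 additive mask
```

The tensor returned here corresponds to the `extended_attention_mask` that the module casts to the model dtype and adds to the attention scores before the softmax.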
diff --git a/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py new file mode 100644 index 00000000..e4b88653 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/config/ram_swin_large_14m.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# data settings +test_transforms_cfg = [ + dict(type='Resize', scale=(384, 384), interpolation='bicubic'), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + + +def get_ram_cfg(mode='normal'): + assert mode in ['normal', 'openset'], 'mode must "normal" or "openset"' + model_type = 'RAMNormal' if mode == 'normal' else 'RAMOpenset' + model_cfg = dict( + type=model_type, + tokenizer=dict( + type='BertTokenizer', + name_or_path='/public/DATA/qbw/ckpt/bert-base-uncased', + use_fast=False), + vision_backbone=dict( + type='SwinTransformer', + arch='large', + img_size=384, + window_size=12, + ), + tag_encoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 512, + 'add_cross_attention': True + }, + text_decoder={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 12, + 'num_hidden_layers': 12, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30524, + 'encoder_width': 768, + 'add_cross_attention': True + }, + tagging_head={ + 'architectures': ['BertModel'], + 'attention_probs_dropout_prob': 0.1, + 'hidden_act': 'gelu', + 'hidden_dropout_prob': 0.1, + 'hidden_size': 768, + 'initializer_range': 0.02, + 'intermediate_size': 3072, + 'layer_norm_eps': 1e-12, + 'max_position_embeddings': 512, + 'model_type': 'bert', + 'num_attention_heads': 4, + 'num_hidden_layers': 2, + 'pad_token_id': 0, + 'type_vocab_size': 2, + 'vocab_size': 30522, + 'encoder_width': 512, + 'add_cross_attention': True, + 'add_tag_cross_attention': False + }, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return model_cfg diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle new file mode 100644 index 00000000..0519d1ee Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list.pickle differ diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle new file mode 100644 index 00000000..4abe105e Binary files /dev/null and b/mmpretrain/models/multimodal/ram/data/ram_tag_list_chinese.pickle differ diff --git a/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle b/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle new file mode 100644 index 00000000..2be681d6 Binary files /dev/null and 
b/mmpretrain/models/multimodal/ram/data/ram_tag_list_threshold.pickle differ diff --git a/mmpretrain/models/multimodal/ram/gradio_demo.py b/mmpretrain/models/multimodal/ram/gradio_demo.py new file mode 100644 index 00000000..206e6b40 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/gradio_demo.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import gradio as gr +import torch + +from mmpretrain.registry import MODELS, TRANSFORMS +from .config.ram_swin_large_14m import get_ram_cfg, test_transforms_cfg +from .run.inference import inference + +parser = argparse.ArgumentParser( + description='RAM(Recognize Anything Model) demo') +parser.add_argument( + 'ram_ckpt', type=str, help='pretrained file for ram (absolute path)') +parser.add_argument( + 'clip_ckpt', + type=str, + help='clip vit-base-p16 pretrained file (absolute path)') +args = parser.parse_args() + +if torch.cuda.is_available(): + devices = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + devices = [torch.device('mps')] +else: + devices = [torch.device('cpu')] + + +def get_free_device(): + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(devices) - 1) + return devices[select] + + +device = get_free_device() + + +def ram_inference(image, tag_list, mode, threshold): + test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg) + model = MODELS.build(get_ram_cfg(mode=mode)) + model.load_state_dict(torch.load(args.ram_ckpt)) + model.device = device + + if mode == 'openset': + categories = tag_list + if categories != '': + categories = categories.strip().split() + else: + categories = None + model.set_openset( + categories=categories, + clip_ckpt=args.clip_ckpt, + threshold=threshold) + + sample = dict(img=image) + result = inference(sample, model, test_transforms, mode=mode) + tag, tag_chinese, logits = \ + result.get('tag_output')[0][0], result.get('tag_output')[1][0],\ + result.get('logits_output')[0] + + def wrap(tags, logits): + if tags is None: + return 'Openset mode has no tag_en' + tag_lst = tags.split('|') + rt_lst = [] + for i, tag in enumerate(tag_lst): + tag = tag.strip() + rt_lst.append(tag + f': {logits[i]:.2f}') + return ' | '.join(rt_lst) + + return [wrap(tag, logits), wrap(tag_chinese, logits)] + + +def build_gradio(): + inputs = [ + gr.components.Image(label='image'), + gr.components.Textbox( + lines=2, + label='tag_list', + placeholder= + 'please input the categories split by keyboard "blank": ', + value=''), + gr.components.Radio(['normal', 'openset'], + label='mode', + value='normal'), + gr.components.Slider( + minimum=0, maximum=1, value=0.68, step=0.01, label='threshold') + ] + return gr.Interface( + fn=ram_inference, + inputs=inputs, + outputs=[ + gr.components.Textbox(), + gr.components.Textbox(info="it's translated from the english tags") + ]) + + +def main(): + build_gradio().launch() + + +if __name__ == '__main__': + main() diff --git a/mmpretrain/models/multimodal/ram/openset_utils.py b/mmpretrain/models/multimodal/ram/openset_utils.py new file mode 100644 index 00000000..5fa0f52e --- /dev/null +++ b/mmpretrain/models/multimodal/ram/openset_utils.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
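+# Open-set helpers for RAM: ``build_openset_label_embedding`` formats every
+# category name with the CLIP prompt templates defined below, encodes the
+# prompts with a CLIP text encoder, then L2-normalises and averages them into
+# a single embedding per category. If no categories are supplied, the
+# ``openimages_rare_unseen`` list serves as the default open-set vocabulary.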
+import torch + +from mmpretrain.registry import MODELS + + +def article(name): + return 'an' if name[0] in 'aeiou' else 'a' + + +def processed_name(name, rm_dot=False): + # _ for lvis + # / for obj365 + res = name.replace('_', ' ').replace('/', ' or ').lower() + if rm_dot: + res = res.rstrip('.') + return res + + +single_template = ['a photo of a {}.'] + +multiple_templates = [ + 'There is {article} {} in the scene.', + 'There is the {} in the scene.', + 'a photo of {article} {} in the scene.', + 'a photo of the {} in the scene.', + 'a photo of one {} in the scene.', + 'itap of {article} {}.', + 'itap of my {}.', # itap: I took a picture of + 'itap of the {}.', + 'a photo of {article} {}.', + 'a photo of my {}.', + 'a photo of the {}.', + 'a photo of one {}.', + 'a photo of many {}.', + 'a good photo of {article} {}.', + 'a good photo of the {}.', + 'a bad photo of {article} {}.', + 'a bad photo of the {}.', + 'a photo of a nice {}.', + 'a photo of the nice {}.', + 'a photo of a cool {}.', + 'a photo of the cool {}.', + 'a photo of a weird {}.', + 'a photo of the weird {}.', + 'a photo of a small {}.', + 'a photo of the small {}.', + 'a photo of a large {}.', + 'a photo of the large {}.', + 'a photo of a clean {}.', + 'a photo of the clean {}.', + 'a photo of a dirty {}.', + 'a photo of the dirty {}.', + 'a bright photo of {article} {}.', + 'a bright photo of the {}.', + 'a dark photo of {article} {}.', + 'a dark photo of the {}.', + 'a photo of a hard to see {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of {article} {}.', + 'a low resolution photo of the {}.', + 'a cropped photo of {article} {}.', + 'a cropped photo of the {}.', + 'a close-up photo of {article} {}.', + 'a close-up photo of the {}.', + 'a jpeg corrupted photo of {article} {}.', + 'a jpeg corrupted photo of the {}.', + 'a blurry photo of {article} {}.', + 'a blurry photo of the {}.', + 'a pixelated photo of {article} {}.', + 'a pixelated photo of the {}.', + 'a black and white photo of the {}.', + 'a black and white photo of {article} {}.', + 'a plastic {}.', + 'the plastic {}.', + 'a toy {}.', + 'the toy {}.', + 'a plushie {}.', + 'the plushie {}.', + 'a cartoon {}.', + 'the cartoon {}.', + 'an embroidered {}.', + 'the embroidered {}.', + 'a painting of the {}.', + 'a painting of a {}.', +] + +openimages_rare_unseen = [ + 'Aerial photography', 'Aircraft engine', 'Ale', 'Aloe', 'Amphibian', + 'Angling', 'Anole', 'Antique car', 'Arcade game', 'Arthropod', + 'Assault rifle', 'Athletic shoe', 'Auto racing', 'Backlighting', + 'Bagpipes', 'Ball game', 'Barbecue chicken', 'Barechested', 'Barquentine', + 'Beef tenderloin', 'Billiard room', 'Billiards', 'Bird of prey', + 'Black swan', 'Black-and-white', 'Blond', 'Boating', 'Bonbon', + 'Bottled water', 'Bouldering', 'Bovine', 'Bratwurst', 'Breadboard', + 'Briefs', 'Brisket', 'Brochette', 'Calabaza', 'Camera operator', 'Canola', + 'Childbirth', 'Chordophone', 'Church bell', 'Classical sculpture', + 'Close-up', 'Cobblestone', 'Coca-cola', 'Combat sport', 'Comics', + 'Compact car', 'Computer speaker', 'Cookies and crackers', + 'Coral reef fish', 'Corn on the cob', 'Cosmetics', 'Crocodilia', + 'Digital camera', 'Dishware', 'Divemaster', 'Dobermann', 'Dog walking', + 'Domestic rabbit', 'Domestic short-haired cat', 'Double-decker bus', + 'Drums', 'Electric guitar', 'Electric piano', 'Electronic instrument', + 'Equestrianism', 'Equitation', 'Erinaceidae', 'Extreme sport', 'Falafel', + 'Figure skating', 'Filling station', 'Fire apparatus', 'Firearm', + 'Flatbread', 
'Floristry', 'Forklift truck', 'Freight transport', + 'Fried food', 'Fried noodles', 'Frigate', 'Frozen yogurt', 'Frying', + 'Full moon', 'Galleon', 'Glacial landform', 'Gliding', 'Go-kart', 'Goats', + 'Grappling', 'Great white shark', 'Gumbo', 'Gun turret', 'Hair coloring', + 'Halter', 'Headphones', 'Heavy cruiser', 'Herding', 'High-speed rail', + 'Holding hands', 'Horse and buggy', 'Horse racing', 'Hound', + 'Hunting knife', 'Hurdling', 'Inflatable', 'Jackfruit', 'Jeans', 'Jiaozi', + 'Junk food', 'Khinkali', 'Kitesurfing', 'Lawn game', 'Leaf vegetable', + 'Lechon', 'Lifebuoy', 'Locust', 'Lumpia', 'Luxury vehicle', 'Machine tool', + 'Medical imaging', 'Melee weapon', 'Microcontroller', 'Middle ages', + 'Military person', 'Military vehicle', 'Milky way', 'Miniature Poodle', + 'Modern dance', 'Molluscs', 'Monoplane', 'Motorcycling', 'Musical theatre', + 'Narcissus', 'Nest box', 'Newsagent\'s shop', 'Nile crocodile', + 'Nordic skiing', 'Nuclear power plant', 'Orator', 'Outdoor shoe', + 'Parachuting', 'Pasta salad', 'Peafowl', 'Pelmeni', 'Perching bird', + 'Performance car', 'Personal water craft', 'Pit bull', 'Plant stem', + 'Pork chop', 'Portrait photography', 'Primate', 'Procyonidae', + 'Prosciutto', 'Public speaking', 'Racewalking', 'Ramen', + 'Rear-view mirror', 'Residential area', 'Ribs', 'Rice ball', + 'Road cycling', 'Roller skating', 'Roman temple', 'Rowing', 'Rural area', + 'Sailboat racing', 'Scaled reptile', 'Scuba diving', 'Senior citizen', + 'Shallot', 'Shinto shrine', 'Shooting range', 'Siberian husky', 'Sledding', + 'Soba', 'Solar energy', 'Sport climbing', 'Sport utility vehicle', + 'Steamed rice', 'Stemware', 'Sumo', 'Surfing Equipment', 'Team sport', + 'Touring car', 'Toy block', 'Trampolining', 'Underwater diving', + 'Vegetarian food', 'Wallaby', 'Water polo', 'Watercolor paint', 'Whiskers', + 'Wind wave', 'Woodwind instrument', 'Yakitori', 'Zeppelin' +] + + +def get_clip_model(): + model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')), + pre_norm=True, + ), + projection=dict( + type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + context_length=77, + data_preprocessor=dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, + ), + ) + return MODELS.build(model) + + +def build_openset_label_embedding(categories=None, clip_ckpt_path=''): + if categories is None: + print('Categories is None, so using rare_unseen categories') + categories = openimages_rare_unseen + model = get_clip_model() + model.load_state_dict(torch.load(clip_ckpt_path)) + templates = multiple_templates + + run_on_gpu = torch.cuda.is_available() + + with torch.no_grad(): + openset_label_embedding = [] + for category in categories: + texts = [ + template.format( + processed_name(category, rm_dot=True), + article=article(category)) for template in templates + ] + texts = [ + 'This is ' + text + if text.startswith('a') or text.startswith('the') else text + for text in texts + ] + texts = model.tokenize(texts) # tokenize + if run_on_gpu: + texts = 
texts.cuda() + model = model.cuda() + text_embeddings = model.extract_text_feat(texts) + text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) + text_embedding = text_embeddings.mean(dim=0) + text_embedding /= text_embedding.norm() + openset_label_embedding.append(text_embedding) + openset_label_embedding = torch.stack(openset_label_embedding, dim=1) + if run_on_gpu: + openset_label_embedding = openset_label_embedding.cuda() + + openset_label_embedding = openset_label_embedding.t() + return openset_label_embedding, categories diff --git a/mmpretrain/models/multimodal/ram/ram.py b/mmpretrain/models/multimodal/ram/ram.py new file mode 100644 index 00000000..c5d22f07 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/ram.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import pickle +from abc import abstractmethod +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from mmengine.model import BaseModel + +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from .bert import BertConfig, BertLMHeadModel, BertModel +from .openset_utils import build_openset_label_embedding +from .utils import tie_encoder_decoder_weights + + +def get_path(path): + file_path = os.path.abspath(os.path.dirname(__file__)) + if not os.path.isabs(path): + return os.path.join(file_path, path) + + +class RAM(BaseModel): + """The implementation of `RAM <https://arxiv.org/abs/2306.03514>`_.""" + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.device = device + # build the visual encoder + self.visual_encoder = MODELS.build(vision_backbone) + + # build the tokenizer + self.tokenizer = TOKENIZER.build(tokenizer) + self.tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + self.tokenizer.add_special_tokens( + {'additional_special_tokens': ['[ENC]']}) + self.tokenizer.enc_token_id = \ + self.tokenizer.additional_special_tokens_ids[0] + + # build the tag encoder + # encoder_config = BertConfig.from_json_file(med_config) + # encoder_config.encoder_width = 512 + encoder_config = BertConfig.from_dict(tag_encoder) + self.tag_encoder = BertModel( + config=encoder_config, add_pooling_layer=False) + + # build image-tag-text decoder + # decoder_config = BertConfig.from_json_file(med_config) + decoder_config = BertConfig.from_dict(text_decoder) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.delete_tag_index = delete_tag_index + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + # load tag list + self.tag_list = self.load_tag_list(get_path(tag_list)) + self.tag_list_chinese = self.load_tag_list(get_path(tag_list_chinese)) + + # create image-tag recognition decoder + self.threshold = threshold + self.num_class = len(self.tag_list) + # q2l_config = \ + # 
BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') + # q2l_config.encoder_width = 512 + q2l_config = BertConfig.from_dict(tagging_head) + self.tagging_head = BertModel( + config=q2l_config, add_pooling_layer=False) + self.tagging_head.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = nn.Parameter( + torch.zeros(self.num_class, q2l_config.encoder_width)) + + if q2l_config.hidden_size != 512: + self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) + else: + self.wordvec_proj = nn.Identity() + + self.fc = nn.Linear(q2l_config.hidden_size, 1) + + self.del_selfattention() + + # share weights of the lowest 2-layer of + # "image-tag interaction encoder" with + # the "image-tag recogntion decoder" + tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', + ' ') + self.image_proj = nn.Linear(vision_width, 512) + # self.label_embed = nn.Parameter(torch.load( + # f'{CONFIG_PATH}/data/textual_label_embedding.pth', + # map_location='cpu').float()) + + # adjust thresholds for some tags + self.class_threshold = torch.ones(self.num_class) * self.threshold + ram_class_threshold_path = get_path( + './data/ram_tag_list_threshold.pickle') + with open(ram_class_threshold_path, 'rb') as f: + ram_class_threshold = pickle.load(f) + for key, value in enumerate(ram_class_threshold): + self.class_threshold[key] = value + + def load_tag_list(self, tag_list_file): + with open(tag_list_file, 'rb') as f: + tag_list = pickle.load(f) + tag_list = np.array(tag_list) + return tag_list + + # delete self-attention layer of image-tag recognition decoder + # to reduce computation, follower Query2Label + def del_selfattention(self): + del self.tagging_head.embeddings + for layer in self.tagging_head.encoder.layer: + del layer.attention + + def get_label_embed(self): + return torch.nn.functional.relu(self.wordvec_proj(self.label_embed)) + + def extract_visual_feature(self, images): + image_embeds = self.visual_encoder(images)[0] + image_embeds = image_embeds.flatten(2, 3) + attn_pool = nn.AdaptiveAvgPool1d(1) + cls_token = attn_pool(image_embeds).permute(0, 2, 1).contiguous() + image_embeds = image_embeds.permute(0, 2, 1).contiguous() + image_embeds = torch.cat([cls_token, image_embeds], dim=1) + image_embeds = self.image_proj(image_embeds) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(images.device) + return image_embeds, image_atts + + def image2tag(self, label_embed, image_embeds, image_atts): + # recognized image tags using image-tag recogntiion decoder + # image_cls_embeds = image_embeds[:, 0, :] + image_spatial_embeds = image_embeds[:, 1:, :] + + bs = image_spatial_embeds.shape[0] + label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1) + tagging_embed = self.tagging_head( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='tagging', + ) + + logits = self.fc(tagging_embed[0]).squeeze(-1) + return logits + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + +@MODELS.register_module() +class RAMNormal(RAM): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, 
+ text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + tag_output = [] + tag_output_chinese = [] + logits_output = [] + + bs = logits.shape[0] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | '.join(token)) + token_chinese = self.tag_list_chinese[index].squeeze(axis=1) + tag_output_chinese.append(' | '.join(token_chinese)) + + return [(tag_output, tag_output_chinese), logits_output] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + self.eval() + self.to(self.device) + images = images.to(self.device) + label_embed = self.get_label_embed() + image_embeds, image_atts = self.extract_visual_feature(images) + logits = self.image2tag(label_embed, image_embeds, image_atts) + tag_output, logits_output = self.tag_process(logits) + data_samples.set_field(logits_output, 'logits_output') + data_samples.set_field(tag_output, 'tag_output') + return data_samples + + +@MODELS.register_module() +class RAMOpenset(RAMNormal): + + def __init__(self, + tokenizer: dict, + vision_backbone: dict, + tag_encoder: dict, + tagging_head: dict, + text_decoder: dict, + device: str = 'cpu', + vision_width: int = 1536, + prompt='a picture of ', + threshold=0.68, + delete_tag_index=[], + tag_list='./data/ram_tag_list.pickle', + tag_list_chinese='./data/ram_tag_list_chinese.pickle', + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super().__init__( + tokenizer, + vision_backbone, + tag_encoder, + tagging_head, + text_decoder, + device, + vision_width, + prompt, + threshold, + delete_tag_index, + tag_list, + tag_list_chinese, + data_preprocessor, + init_cfg, + ) + + def set_openset(self, + categories: List[str] = None, + clip_ckpt: str = '', + threshold: float = 0.68): + openset_label_embedding, openset_categories = \ + build_openset_label_embedding( + categories, clip_ckpt + ) + self.tag_list = np.array(openset_categories) + self.label_embed = nn.Parameter(openset_label_embedding.float()) + self.num_class = len(openset_categories) + + # the threshold for unseen categories is often lower + self.class_threshold = torch.ones(self.num_class) * threshold + + def tag_process(self, logits): + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(logits.device), + torch.tensor(1.0).to(logits.device), + torch.zeros(self.num_class).to(logits.device)) + + tag = targets.cpu().numpy() + tag[:, self.delete_tag_index] = 0 + + bs = logits.shape[0] + tag_output = [] + logits_output = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_list[index].squeeze(axis=1) + logits_output.append( + 
torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy()) + tag_output.append(' | '.join(token)) + + return [(tag_output, [None]), logits_output] diff --git a/mmpretrain/models/multimodal/ram/run/__init__.py b/mmpretrain/models/multimodal/ram/run/__init__.py new file mode 100644 index 00000000..ef101fec --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/multimodal/ram/run/inference.py b/mmpretrain/models/multimodal/ram/run/inference.py new file mode 100644 index 00000000..da5afcf5 --- /dev/null +++ b/mmpretrain/models/multimodal/ram/run/inference.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def inference_ram(sample, model): + + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference_ram_openset(sample, model): + with torch.no_grad(): + result = model.test_step(sample) + + return result + + +def inference(sample, model, transforms, mode='normal'): + sample = transforms(sample) + if sample['inputs'].ndim == 3: + sample['inputs'] = sample['inputs'].unsqueeze(dim=0) + assert mode in ['normal', 'openset' + ], 'mode of inference must be "normal" or "openset"' + if mode == 'normal': + return inference_ram(sample, model) + else: + return inference_ram_openset(sample, model) diff --git a/mmpretrain/models/multimodal/ram/utils.py b/mmpretrain/models/multimodal/ram/utils.py new file mode 100644 index 00000000..32cb115b --- /dev/null +++ b/mmpretrain/models/multimodal/ram/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from torch import nn + + +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, + base_model_prefix: str, skip_key: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + print(f'''{decoder.__class__} and {encoder.__class__} are not equal. 
+ In this case make sure that + all encoder weights are correctly initialized.''') + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f'{decoder_pointer} and {encoder_pointer}' + \ + 'have to be of type torch.nn.Module' + if hasattr(decoder_pointer, 'weight') and skip_key not in module_name: + assert hasattr(encoder_pointer, 'weight') + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, 'bias'): + assert hasattr(encoder_pointer, 'bias') + encoder_pointer.bias = decoder_pointer.bias + print(module_name + ' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert (len(encoder_modules) > + 0), f'''Encoder module {encoder_pointer} + does not match decoder module {decoder_pointer}''' + + all_encoder_weights = set([ + module_name + '/' + sub_name + for sub_name in encoder_modules.keys() + ]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name])) and len( + encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to + # the position in a list module list of layers + # in this case the decoder has added a + # cross-attention that the encoder doesn't have + # thus skip this step and + # subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + '''Max depth of recursive function `tie_encoder_to_decoder` reached. 
+ It seems that there is a circular dependency + between two or more `nn.Modules` of your model.''') + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + '/' + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + '/' + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, + uninitialized_encoder_weights, skip_key) diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py index 85efd254..488a9963 100644 --- a/mmpretrain/models/selfsup/itpn.py +++ b/mmpretrain/models/selfsup/itpn.py @@ -64,6 +64,7 @@ class iTPNHiViT(HiViT): layer_scale_init_value: float = 0.0, mask_ratio: float = 0.75, reconstruction_type: str = 'pixel', + **kwargs, ): super().__init__( arch=arch, @@ -80,7 +81,9 @@ class iTPNHiViT(HiViT): norm_cfg=norm_cfg, ape=ape, rpe=rpe, - layer_scale_init_value=layer_scale_init_value) + layer_scale_init_value=layer_scale_init_value, + **kwargs, + ) self.pos_embed.requires_grad = False self.mask_ratio = mask_ratio diff --git a/mmpretrain/models/utils/batch_augments/resizemix.py b/mmpretrain/models/utils/batch_augments/resizemix.py index 89cfb720..c70f81b3 100644 --- a/mmpretrain/models/utils/batch_augments/resizemix.py +++ b/mmpretrain/models/utils/batch_augments/resizemix.py @@ -87,7 +87,7 @@ class ResizeMix(CutMix): (y1, y2, x1, x2), lam = self.cutmix_bbox_and_lam(img_shape, lam) batch_inputs[:, :, y1:y2, x1:x2] = F.interpolate( batch_inputs[index], - size=(y2 - y1, x2 - x1), + size=(int(y2 - y1), int(x2 - x1)), mode=self.interpolation, align_corners=False) mixed_scores = lam * batch_scores + (1 - lam) * batch_scores[index, :] diff --git a/mmpretrain/models/utils/tokenizer.py b/mmpretrain/models/utils/tokenizer.py index 5b8a324b..fddda432 100644 --- a/mmpretrain/models/utils/tokenizer.py +++ b/mmpretrain/models/utils/tokenizer.py @@ -12,6 +12,7 @@ from .huggingface import register_hf_tokenizer register_hf_tokenizer(AutoTokenizer) register_hf_tokenizer(LlamaTokenizer) +register_hf_tokenizer(BertTokenizer) @register_hf_tokenizer() diff --git a/mmpretrain/version.py b/mmpretrain/version.py index 24b33124..1822b7f2 100644 --- a/mmpretrain/version.py +++ b/mmpretrain/version.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. 
All rights reserved -__version__ = '1.0.2' +__version__ = '1.2.0' def parse_version_info(version_str): diff --git a/projects/gradio_demo/conversation.py b/projects/gradio_demo/conversation.py new file mode 100644 index 00000000..3c594690 --- /dev/null +++ b/projects/gradio_demo/conversation.py @@ -0,0 +1,137 @@ +# Modified from +# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py +import dataclasses +from typing import List + +import torch + + +@dataclasses.dataclass +class Conversation: + system: str + roles: List[str] + messages: List[List[str]] + sep: str = '###' + + def get_prompt(self): + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def copy(self): + return Conversation( + system=self.system, + roles=[role for role in self.roles], + messages=[[y for y in x] for x in self.messages], + sep=self.sep, + ) + + def dict(self): + return { + 'system': self.system, + 'roles': self.roles, + 'messages': self.messages, + 'offset': self.offset, + 'sep': self.sep, + } + + +EN_CONV_VISION = Conversation( + system='Give the following image. ' + 'You will be able to see the image once I provide it to you. ' + 'Please answer my questions in detail.', + roles=['Ask', 'Answer'], + messages=[], + sep='###', +) + +ZH_CONV_VISION = Conversation( + system='给定一张图片,请仔细观察这张图片,并回答我的问题。', + roles=['问', '答'], + messages=[], + sep='###', +) + + +class Chat: + + def __init__(self, inferencer, device, is_half=False): + self.device = device + self.inferencer = inferencer + self.model = inferencer.model + self.is_half = is_half + if is_half: + self.model = self.model.half() + self.model = self.model.to(device) + self.max_length = 2000 + + def upload_img(self, image, conv, img_list): + img = next(self.inferencer.preprocess([image])) + img = self.model.data_preprocessor(img, False)['images'] + img = img.to(self.device) + image_emb, _ = self.model.encode_img(img) + img_list.append(image_emb) + conv.append_message(conv.roles[0], '<Img><ImageHere></Img>') + + def get_context_emb(self, conv, img_list): + prompt = conv.get_prompt() + prompt_segs = prompt.split('<ImageHere>') + seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors='pt', + add_special_tokens=(i == 0)).to(self.device).input_ids + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [ + self.model.llama_model.model.embed_tokens(seg_token) + for seg_token in seg_tokens + ] + mixed_embs = [ + emb for pair in zip(seg_embs[:-1], img_list) for emb in pair + ] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def ask(self, text, conv): + if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[ + 0] and conv.messages[-1][1][-6:] == '</Img>': + conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) + else: + conv.append_message(conv.roles[0], text) + + def answer(self, conv, img_list, generation_cfg): + conv.append_message(conv.roles[1], None) + embs = self.get_context_emb(conv, img_list) + cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1] + if cur_max_len > self.max_length: + print('Warning: The number of tokens in current conversation' + 'exceeds the max length. 
' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, cur_max_len - self.max_length) + embs = embs[:, begin_idx:] + if self.is_half: + embs = embs.half() + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + eos_token_id=self.model.end_token_id, + **generation_cfg) + + output_token = outputs[0] + if output_token[0] == 0: + output_token = output_token[1:] + elif output_token[0] == 1: + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode( + output_token, + add_special_tokens=False, + skip_special_tokens=True) + output_text = output_text.split('###')[0] + conv.messages[-1][1] = output_text + return output_text diff --git a/projects/gradio_demo/minigpt4_demo.py b/projects/gradio_demo/minigpt4_demo.py new file mode 100644 index 00000000..e4d61426 --- /dev/null +++ b/projects/gradio_demo/minigpt4_demo.py @@ -0,0 +1,144 @@ +import argparse + +import gradio as gr +import numpy as np +import torch +from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat + +from mmpretrain import ImageCaptionInferencer + +parser = argparse.ArgumentParser(description='MiniGPT4 demo') +parser.add_argument( + 'cfg', type=str, help='config file for minigpt4 (absolute path)') +parser.add_argument( + 'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)') +args = parser.parse_args() + +if torch.cuda.is_available(): + devices = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + devices = [torch.device('mps')] +else: + devices = [torch.device('cpu')] + + +def get_free_device(): + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(devices) - 1) + return devices[select] + + +device = get_free_device() +inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt) +model = inferencer.model +chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu')) + + +def reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return (None, gr.update(value=None, interactive=True), + gr.update( + value=None, + placeholder='Please upload your image first', + interactive=False), + gr.update(value='Upload & Start Chat', + interactive=True), chat_state, img_list, + gr.update(value='Restart', interactive=False), + gr.update(value='English', interactive=True)) + + +def upload_img(gr_img, language, chat_state): + if gr_img is None: + return (None, + gr.update( + placeholder='Please upload your image first', + interactive=False), + gr.update(value='Upload & Start Chat', + interactive=True), chat_state, None, + gr.update(value='Restart', interactive=False), + gr.update(value='English', interactive=True)) + + if (language == 'English'): + chat_state = EN_CONV_VISION.copy() + else: + chat_state = ZH_CONV_VISION.copy() + img_list = [] + gr_img_array = np.asarray(gr_img) + chat.upload_img(gr_img_array, chat_state, img_list) + return (gr.update(interactive=False), + gr.update(placeholder='Type and press Enter', interactive=True), + gr.update(value='Start Chatting', + interactive=False), chat_state, img_list, + gr.update(value='Restart', + interactive=True), gr.update(interactive=False)) + + +def ask(user_message, chatbot, chat_state): + if (len(user_message) == 0): + return gr.update( + value=None, + 
placeholder='Input should not be empty!', + interactive=True), chatbot, chat_state + chat.ask(user_message, chat_state) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state + + +def answer(chatbot, chat_state, img_list): + llm_message = chat.answer( + conv=chat_state, + img_list=img_list, + generation_cfg=model.generation_cfg) + chatbot[-1][1] = llm_message + return chatbot, chat_state, img_list + + +if __name__ == '__main__': + title = 'MMPretrain MiniGPT-4 Inference Demo' + with gr.Blocks(analytics_enabled=False, title=title) as demo: + gr.Markdown(f'# {title}') + with gr.Row(): + with gr.Column(): + image = gr.Image(type='pil') + language = gr.Dropdown(['English', 'Chinese'], + label='Language', + info='Select chatbot\'s language', + value='English', + interactive=True) + upload_button = gr.Button( + value='Upload & Start Chat', interactive=True) + clear = gr.Button(value='Restart', interactive=False) + + with gr.Column(): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot( + label='MiniGPT-4', min_width=320, height=600) + text_input = gr.Textbox( + label='User', + placeholder='Please upload your image first', + interactive=False) + + upload_button.click(upload_img, [image, language, chat_state], [ + image, text_input, upload_button, chat_state, img_list, clear, + language + ]) + text_input.submit(ask, [text_input, chatbot, chat_state], + [text_input, chatbot, chat_state]).then( + answer, [chatbot, chat_state, img_list], + [chatbot, chat_state, img_list]) + clear.click(reset, [chat_state, img_list], [ + chatbot, image, text_input, upload_button, chat_state, img_list, + clear, language + ]) + + demo.launch(share=True) diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt index d23d0ac7..9b736b02 100644 --- a/requirements/mminstall.txt +++ b/requirements/mminstall.txt @@ -1,2 +1,2 @@ -mmcv>=2.0.0,<2.1.0 +mmcv>=2.0.0,<2.4.0 mmengine>=0.8.3,<1.0.0 diff --git a/requirements/optional.txt b/requirements/optional.txt index 85853cda..5f31808f 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,4 +1,4 @@ albumentations>=0.3.2 --no-binary qudida,albumentations # For Albumentations data transform -grad-cam >= 1.3.7 # For CAM visualization +grad-cam >= 1.3.7,<1.5.0 # For CAM visualization requests # For torchserve scikit-learn # For t-SNE visualization and unit tests. diff --git a/resources/miaomiao_qrcode.jpg b/resources/miaomiao_qrcode.jpg new file mode 100644 index 00000000..d34cbae6 Binary files /dev/null and b/resources/miaomiao_qrcode.jpg differ diff --git a/tests/test_models/test_backbones/test_repmlp.py b/tests/test_models/test_backbones/test_repmlp.py index bfcb5dfc..f03fce4e 100644 --- a/tests/test_models/test_backbones/test_repmlp.py +++ b/tests/test_models/test_backbones/test_repmlp.py @@ -169,4 +169,5 @@ class TestRepMLP(TestCase): assert len(feats_) == len(feats__) for i in range(len(feats)): - self.assertTrue(torch.allclose(feats__[i], feats_[i])) + self.assertTrue( + torch.allclose(feats__[i], feats_[i], rtol=0.01, atol=0.01)) diff --git a/tools/model_converters/llava-delta2mmpre.py b/tools/model_converters/llava-delta2mmpre.py index bc51b19d..104ed07d 100644 --- a/tools/model_converters/llava-delta2mmpre.py +++ b/tools/model_converters/llava-delta2mmpre.py @@ -9,23 +9,21 @@ from huggingface_hub import snapshot_download from transformers.modeling_utils import load_state_dict prog_description = """\ -Merge Llava delta weights and original weights, -and save as MMPreTrain checkpoint. 
+Convert Llava weights and original weights. """ def parse_args(): parser = argparse.ArgumentParser(description=prog_description) - parser.add_argument( - 'src_path', type=str, help='The original checkpoint dir') - parser.add_argument( - 'delta_path', type=str, help='The delta checkpoint dir') - parser.add_argument('dst_path', type=str, help='The saved checkpoint path') + parser.add_argument('src', type=str, help='The original checkpoint dir') + parser.add_argument('dst', type=str, help='The saved checkpoint path') + parser.add_argument('--delta', type=str, help='The delta checkpoint dir') args = parser.parse_args() return args def load_checkpoint(path: Path): + path = Path(path) if path.is_file(): return torch.load(path) @@ -41,19 +39,23 @@ def load_checkpoint(path: Path): def main(): args = parse_args() - if Path(args.src_path).exists(): - src_path = Path(args.src_path) + if Path(args.src).exists(): + src_path = args.src else: - src_path = Path(snapshot_download(args.src_path)) + src_path = snapshot_download( + args.src, allow_patterns='pytorch_model*.bin') src_state_dict = load_checkpoint(src_path) - if Path(args.delta_path).exists(): - delta_path = Path(args.delta_path) + if args.delta is None: + delta_state_dict = {} + elif Path(args.delta).exists(): + delta_state_dict = load_checkpoint(args.delta) else: - delta_path = Path(snapshot_download(args.delta_path)) - delta_state_dict = load_checkpoint(delta_path) + delta_path = snapshot_download( + args.delta, allow_patterns='pytorch_model*.bin') + delta_state_dict = load_checkpoint(delta_path) - merged_state_dict = OrderedDict() + new_state_dict = OrderedDict() for k, v in src_state_dict.items(): if k in delta_state_dict: delta_v = delta_state_dict.pop(k) @@ -63,12 +65,13 @@ def main(): v = delta_v else: v += delta_v - merged_state_dict['model.lang_encoder.' + k] = v + if 'rotary_emb.inv_freq' not in k: + new_state_dict['model.lang_encoder.' + k] = v for k, v in delta_state_dict.items(): - merged_state_dict['model.lang_encoder.' + k] = v + new_state_dict['model.lang_encoder.' + k] = v - torch.save(merged_state_dict, args.dst_path) + torch.save(new_state_dict, args.dst) print('Done!!') diff --git a/tools/model_converters/ram2mmpretrain.py b/tools/model_converters/ram2mmpretrain.py new file mode 100644 index 00000000..5ee3b476 --- /dev/null +++ b/tools/model_converters/ram2mmpretrain.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict +from copy import deepcopy + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_swin(ckpt): + new_ckpt = OrderedDict() + convert_mapping = dict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if 'attn_mask' in k: + continue + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' 
in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + elif k.startswith('norm'): + new_v = v + new_k = k.replace('norm', 'norm3') + else: + new_v = v + new_k = k + + new_ckpt[new_k] = new_v + convert_mapping[k] = new_k + + return new_ckpt, convert_mapping + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained RAM models to' + 'MMPretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + visual_ckpt = OrderedDict() + for key in state_dict: + if key.startswith('visual_encoder.'): + new_key = key.replace('visual_encoder.', '') + visual_ckpt[new_key] = state_dict[key] + + new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt) + new_ckpt = deepcopy(state_dict) + for key in state_dict: + if key.startswith('visual_encoder.'): + if 'attn_mask' in key: + del new_ckpt[key] + continue + del new_ckpt[key] + old_key = key.replace('visual_encoder.', '') + new_ckpt[key.replace(old_key, + convert_mapping[old_key])] = deepcopy( + new_visual_ckpt[key.replace( + old_key, + convert_mapping[old_key]).replace( + 'visual_encoder.', '')]) + + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(new_ckpt, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/train.py b/tools/train.py index 84c1eec9..89c8548f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -91,10 +91,6 @@ def merge_args(cfg, args): # enable automatic-mixed-precision training if args.amp is True: - optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') - assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ - '`--amp` is not supported custom optimizer wrapper type ' \ - f'`{optim_wrapper}.' cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
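Stepping back to the checkpoint converter added in `tools/model_converters/ram2mmpretrain.py`: `convert_swin` re-orders the input channels of each PatchMerging `downsample.reduction` weight (and applies the matching re-ordering to `downsample.norm`) so the 2x2-patch grouping of the source checkpoint lines up with the unfold order used by MMPretrain's `SwinTransformer`. The toy example below is a standalone sketch with values chosen only to make the column movement visible; it applies the same `correct_unfold_reduction_order` transform to a 1x8 weight.

```python
import torch


def correct_unfold_reduction_order(x: torch.Tensor) -> torch.Tensor:
    # regroup the input channels of a PatchMerging reduction weight
    # (same transform as in the converter above)
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, 4, in_channel // 4)
    x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
    return x


w = torch.arange(8.0).reshape(1, 8)        # toy weight: out_channel=1, in_channel=4*2
print(correct_unfold_reduction_order(w))   # tensor([[0., 4., 2., 6., 1., 5., 3., 7.]])
```

The converter runs this transform on every matching weight before the renamed keys are written into the new checkpoint.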