diff --git a/configs/minigpt4/minigpt-4_baichuan-7b_caption.py b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py
new file mode 100644
index 00000000..5a737d04
--- /dev/null
+++ b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py
@@ -0,0 +1,185 @@
+_base_ = [
+ '../_base_/default_runtime.py',
+]
+
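+# CLIP image normalization constants on a 0-255 scale, matching the
+# preprocessing expected by the pretrained EVA-G vision encoder.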
+data_preprocessor = dict(
+ type='MultiModalDataPreprocessor',
+ mean=[122.770938, 116.7460125, 104.09373615],
+ std=[68.5005327, 66.6321579, 70.32316305],
+ to_rgb=True,
+)
+
+# dataset settings
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='Resize',
+ scale=(224, 224),
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='CleanCaption',
+ keys='chat_content',
+ remove_chars='',
+ lowercase=False),
+ dict(
+ type='PackInputs',
+ algorithm_keys=['chat_content', 'lang'],
+ meta_keys=['image_id']),
+]
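+# `chat_content` carries the caption text and `lang` tags each sample as
+# English or Chinese, letting the model pick the matching prompt template.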
+
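+# Replace YOUR_DATA_DIRECTORY / YOUR_DATA_FILE with the data root and
+# annotation file of your own caption data before launching training.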
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+ dataset=dict(
+ type='MiniGPT4Dataset',
+ data_root='YOUR_DATA_DIRECTORY',
+ ann_file='YOUR_DATA_FILE',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='default_collate'),
+ drop_last=False,
+)
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='Resize',
+ scale=(224, 224),
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='PackInputs', meta_keys=['image_id']),
+]
+
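+# Generated captions are scored against the COCO Karpathy val ground truth
+# with standard caption metrics (e.g. BLEU, CIDEr).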
+test_evaluator = dict(
+ type='COCOCaption',
+ ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+test_dataloader = dict(
+ batch_size=1,
+ dataset=dict(
+ type='COCOCaption',
+ data_root='data/coco',
+ ann_file='annotations/coco_karpathy_val.json',
+ pipeline=test_pipeline))
+
+# model settings
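+# MiniGPT-4 keeps the vision encoder, Q-Former and language model frozen;
+# only the linear projection from Q-Former queries into the LLM embedding
+# space is trained. This variant swaps Vicuna for Baichuan-7B.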
+model = dict(
+ type='MiniGPT4',
+ vision_encoder=dict(
+ type='BEiTViT',
+ # eva-g without the final layer
+ arch=dict(
+ embed_dims=1408,
+ num_layers=39,
+ num_heads=16,
+ feedforward_channels=6144,
+ ),
+ img_size=224,
+ patch_size=14,
+ layer_scale_init_value=0.0,
+ frozen_stages=39,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
+ final_norm=False,
+ use_shared_rel_pos_bias=False,
+ out_type='raw',
+ pretrained= # noqa
+ 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
+ ),
+ q_former_model=dict(
+ type='Qformer',
+ model_style='bert-base-uncased',
+ vision_model_width=1408,
+ add_cross_attention=True,
+ cross_attention_freq=2,
+ num_query_token=32,
+ pretrained= # noqa
+ 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
+ ),
+ lang_encoder=dict(
+ type='AutoModelForCausalLM',
+ name_or_path='YOUR_PATH_TO_BAICHUAN',
+ trust_remote_code=True),
+ tokenizer=dict(
+ type='AutoTokenizer',
+ name_or_path='YOUR_PATH_TO_BAICHUAN',
+ trust_remote_code=True),
+ task='caption',
+ en_prompt_template='###Ask: {} ###Answer: ',
+ zh_prompt_template='###问:{} ###答:',
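+    # One raw prompt is sampled per training sample and wrapped with the
+    # template matching its language; '<Img><ImageHere></Img>' marks where
+    # the projected image tokens are spliced into the prompt.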
+    raw_prompts=[
+        [
+            '<Img><ImageHere></Img> Describe this image in detail.',
+            '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
+            '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
+            '<Img><ImageHere></Img> Could you describe the contents of this image for me?'  # noqa
+        ],
+        [
+            '<Img><ImageHere></Img> 详细描述这张图片。',
+            '<Img><ImageHere></Img> 浏览这张图片并描述你注意到什么。',
+            '<Img><ImageHere></Img> 请对这张图片进行详细的描述。',
+            '<Img><ImageHere></Img> 你能为我描述这张图片的内容吗?'
+        ]
+    ],
+ max_txt_len=160,
+ end_sym='###')
+
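+# DeepSpeed ZeRO stage 2 partitions optimizer states and gradients across
+# ranks; fp16 training uses dynamic loss scaling (loss_scale=0).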
+strategy = dict(
+ type='DeepSpeedStrategy',
+ fp16=dict(
+ enabled=True,
+ auto_cast=False,
+ fp16_master_weights_and_grads=False,
+ loss_scale=0,
+ loss_scale_window=1000,
+ hysteresis=1,
+ min_loss_scale=1,
+ initial_scale_power=16,
+ ),
+ inputs_to_half=[0],
+ zero_optimization=dict(
+ stage=2,
+ allgather_partitions=True,
+ allgather_bucket_size=2e8,
+ reduce_scatter=True,
+ reduce_bucket_size='auto',
+ overlap_comm=True,
+ contiguous_gradients=True,
+ ),
+)
+
+# schedule settings
+optim_wrapper = dict(
+ type='DeepSpeedOptimWrapper',
+ optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
+
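+# Linear warmup over the first 500 iterations, then cosine decay to 2e-4.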
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-3 / 500,
+ by_epoch=False,
+ begin=0,
+ end=500,
+ ),
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=2e-4,
+ by_epoch=False,
+ begin=500,
+ ),
+]
+
+train_cfg = dict(by_epoch=True, max_epochs=6)
+test_cfg = dict()
+
+default_hooks = dict(
+ checkpoint=dict(
+ type='CheckpointHook',
+ interval=1,
+ by_epoch=True,
+ save_last=True,
+ max_keep_ckpts=1,
+ ))
diff --git a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
index 704760af..3130c25f 100644
--- a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
+++ b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
@@ -57,10 +57,18 @@ model = dict(
task='caption',
prompt_template='###Human: {} ###Assistant: ',
raw_prompts=[
-        '<Img><ImageHere></Img> Describe this image in detail.',
-        '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
-        '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
-        '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
+        [
+            '<Img><ImageHere></Img> Describe this image in detail.',
+            '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
+            '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
+            '<Img><ImageHere></Img> Could you describe the contents of this image for me?'  # noqa
+        ],
+        [
+            '<Img><ImageHere></Img> 详细描述这张图片。',
+            '<Img><ImageHere></Img> 浏览这张图片并描述你注意到什么。',
+            '<Img><ImageHere></Img> 请对这张图片进行详细的描述。',
+            '<Img><ImageHere></Img> 你能为我描述这张图片的内容吗?'
+        ]
],
max_txt_len=160,
end_sym='###')