diff --git a/configs/minigpt4/minigpt-4_baichuan-7b_caption.py b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py
new file mode 100644
index 00000000..5a737d04
--- /dev/null
+++ b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py
@@ -0,0 +1,185 @@
+_base_ = [
+    '../_base_/default_runtime.py',
+]
+
+data_preprocessor = dict(
+    type='MultiModalDataPreprocessor',
+    mean=[122.770938, 116.7460125, 104.09373615],
+    std=[68.5005327, 66.6321579, 70.32316305],
+    to_rgb=True,
+)
+
+# dataset settings
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(224, 224),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(
+        type='CleanCaption',
+        keys='chat_content',
+        remove_chars='',
+        lowercase=False),
+    dict(
+        type='PackInputs',
+        algorithm_keys=['chat_content', 'lang'],
+        meta_keys=['image_id']),
+]
+
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=4,
+    dataset=dict(
+        type='MiniGPT4Dataset',
+        data_root='YOUR_DATA_DIRECTORY',
+        ann_file='YOUR_DATA_FILE',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    collate_fn=dict(type='default_collate'),
+    drop_last=False,
+)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(224, 224),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+test_evaluator = dict(
+    type='COCOCaption',
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+test_dataloader = dict(
+    batch_size=1,
+    dataset=dict(
+        type='COCOCaption',
+        data_root='data/coco',
+        ann_file='annotations/coco_karpathy_val.json',
+        pipeline=test_pipeline))
+
+# model settings
+model = dict(
+    type='MiniGPT4',
+    vision_encoder=dict(
+        type='BEiTViT',
+        # eva-g without the final layer
+        arch=dict(
+            embed_dims=1408,
+            num_layers=39,
+            num_heads=16,
+            feedforward_channels=6144,
+        ),
+        img_size=224,
+        patch_size=14,
+        layer_scale_init_value=0.0,
+        frozen_stages=39,
+        use_abs_pos_emb=True,
+        use_rel_pos_bias=False,
+        final_norm=False,
+        use_shared_rel_pos_bias=False,
+        out_type='raw',
+        pretrained=  # noqa
+        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth'  # noqa
+    ),
+    q_former_model=dict(
+        type='Qformer',
+        model_style='bert-base-uncased',
+        vision_model_width=1408,
+        add_cross_attention=True,
+        cross_attention_freq=2,
+        num_query_token=32,
+        pretrained=  # noqa
+        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth'  # noqa
+    ),
+    lang_encoder=dict(
+        type='AutoModelForCausalLM',
+        name_or_path='YOUR_PATH_TO_BAICHUAN',
+        trust_remote_code=True),
+    tokenizer=dict(
+        type='AutoTokenizer',
+        name_or_path='YOUR_PATH_TO_BAICHUAN',
+        trust_remote_code=True),
+    task='caption',
+    en_prompt_template='###Ask: {} ###Answer: ',
+    zh_prompt_template='###问:{} ###答:',
+    raw_prompts=[
+        [
+            '<Img><ImageHere></Img> Describe this image in detail.',
+            '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
+            '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
+            '<Img><ImageHere></Img> Could you describe the contents of this image for me?'  # noqa
+        ],
+        [
+            '<Img><ImageHere></Img> 详细描述这张图片。',
+            '<Img><ImageHere></Img> 浏览这张图片并描述你注意到什么。',
+            '<Img><ImageHere></Img> 请对这张图片进行详细的描述。',
+            '<Img><ImageHere></Img> 你能为我描述这张图片的内容吗?'
+        ]
+    ],
+    max_txt_len=160,
+    end_sym='###')
+
+strategy = dict(
+    type='DeepSpeedStrategy',
+    fp16=dict(
+        enabled=True,
+        auto_cast=False,
+        fp16_master_weights_and_grads=False,
+        loss_scale=0,
+        loss_scale_window=1000,
+        hysteresis=1,
+        min_loss_scale=1,
+        initial_scale_power=16,
+    ),
+    inputs_to_half=[0],
+    zero_optimization=dict(
+        stage=2,
+        allgather_partitions=True,
+        allgather_bucket_size=2e8,
+        reduce_scatter=True,
+        reduce_bucket_size='auto',
+        overlap_comm=True,
+        contiguous_gradients=True,
+    ),
+)
+
+# schedule settings
+optim_wrapper = dict(
+    type='DeepSpeedOptimWrapper',
+    optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
+
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-3 / 500,
+        by_epoch=False,
+        begin=0,
+        end=500,
+    ),
+    dict(
+        type='CosineAnnealingLR',
+        eta_min=2e-4,
+        by_epoch=False,
+        begin=500,
+    ),
+]
+
+train_cfg = dict(by_epoch=True, max_epochs=6)
+test_cfg = dict()
+
+default_hooks = dict(
+    checkpoint=dict(
+        type='CheckpointHook',
+        interval=1,
+        by_epoch=True,
+        save_last=True,
+        max_keep_ckpts=1,
+    ))
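Unlike the Vicuna config's single prompt_template, the Baichuan config carries one template per language; PackInputs forwards each sample's lang key so the model can pick the matching template and a random raw prompt at train time. The sketch below illustrates that assembly; build_prompt and the constants are illustrative names, not mmpretrain API, and the real logic lives inside the MiniGPT4 model class.

import random

# Illustrative sketch of bilingual prompt assembly (not mmpretrain API);
# templates and prompts are copied from the config above.
EN_TEMPLATE = '###Ask: {} ###Answer: '
ZH_TEMPLATE = '###问:{} ###答:'
RAW_PROMPTS = {
    'en': ['<Img><ImageHere></Img> Describe this image in detail.'],
    'zh': ['<Img><ImageHere></Img> 详细描述这张图片。'],
}


def build_prompt(lang: str) -> str:
    """Wrap a randomly chosen raw prompt in the language-matched template.

    `<ImageHere>` marks where the visual embeddings (32 Q-Former query
    tokens in this config) are spliced into the token sequence; `###`
    (end_sym) terminates each answer.
    """
    template = ZH_TEMPLATE if lang == 'zh' else EN_TEMPLATE
    return template.format(random.choice(RAW_PROMPTS[lang]))


print(build_prompt('zh'))
# -> ###问:<Img><ImageHere></Img> 详细描述这张图片。 ###答: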
diff --git a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
index 704760af..3130c25f 100644
--- a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
+++ b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
@@ -57,10 +57,18 @@ model = dict(
     task='caption',
     prompt_template='###Human: {} ###Assistant: ',
     raw_prompts=[
-        '<Img><ImageHere></Img> Describe this image in detail.',
-        '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
-        '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
-        '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
+        [
+            '<Img><ImageHere></Img> Describe this image in detail.',
+            '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
+            '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
+            '<Img><ImageHere></Img> Could you describe the contents of this image for me?'  # noqa
+        ],
+        [
+            '<Img><ImageHere></Img> 详细描述这张图片。',
+            '<Img><ImageHere></Img> 浏览这张图片并描述你注意到什么。',
+            '<Img><ImageHere></Img> 请对这张图片进行详细的描述。',
+            '<Img><ImageHere></Img> 你能为我描述这张图片的内容吗?'
+        ]
     ],
     max_txt_len=160,
     end_sym='###')
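Because the Baichuan config declares a DeepSpeed strategy block, it targets mmengine's FlexibleRunner rather than the plain Runner; mmpretrain's tools/train.py is the supported entry point and performs that dispatch itself. Below is a minimal programmatic sketch, assuming mmengine >= 0.8 (where FlexibleRunner is available, still experimental) and placeholder paths; it is not the official launch recipe.

from mmengine.config import Config
from mmengine.runner import FlexibleRunner

cfg = Config.fromfile('configs/minigpt4/minigpt-4_baichuan-7b_caption.py')
# The top-level data_preprocessor is normally merged into the model by the
# training script; mirror that step here.
cfg.model.setdefault('data_preprocessor', cfg.get('data_preprocessor'))

runner = FlexibleRunner(
    model=cfg.model,
    work_dir='work_dirs/minigpt-4_baichuan-7b_caption',  # placeholder
    strategy=cfg.strategy,  # DeepSpeed ZeRO-2 + fp16 settings from above
    train_dataloader=cfg.train_dataloader,
    optim_wrapper=cfg.optim_wrapper,
    param_scheduler=cfg.param_scheduler,
    train_cfg=cfg.train_cfg,
    default_hooks=cfg.default_hooks,
)
runner.train()  # run under torchrun/slurm so DeepSpeed sees all GPUs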