[Feature] Support otter (#1651)

* [Feature] Support Otter

* Update docs
This commit is contained in:
Ma Zerun 2023-06-17 16:03:21 +08:00 committed by GitHub
parent 9d3fc43073
commit e69bace03f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 540 additions and 3 deletions

View File

@ -256,6 +256,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>

View File

@ -252,6 +252,7 @@ mim install -e ".[multimodal]"
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>

78
configs/otter/README.md Normal file
View File

@ -0,0 +1,78 @@
# Otter
> [Otter: A Multi-Modal Model with In-Context Instruction Tuning](https://arxiv.org/abs/2305.03726)
<!-- [ALGORITHM] -->
## Abstract
Large language models (LLMs) have demonstrated significant universal capabilities as few/zero-shot learners in various tasks due to their pre-training on vast amounts of text data, as exemplified by GPT-3, which boosted to InstrctGPT and ChatGPT, effectively following natural language instructions to accomplish real-world tasks. In this paper, we propose to introduce instruction tuning into multi-modal models, motivated by the Flamingo model's upstream interleaved format pretraining dataset. We adopt a similar approach to construct our MultI-Modal In-Context Instruction Tuning (MIMIC-IT) dataset. We then introduce Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following ability and in-context learning. We also optimize OpenFlamingo's implementation for researchers, democratizing the required training resources from 1$\times$ A100 GPU to 4$\times$ RTX-3090 GPUs, and integrate both OpenFlamingo and Otter into Huggingface Transformers for more researchers to incorporate the models into their customized training and inference pipelines.
<div align=center>
<img src="https://camo.githubusercontent.com/70613ab882a7827808148a2c577029d544371e707b0832a0b01151c54ce553c3/68747470733a2f2f692e706f7374696d672e63632f5477315a304243572f6f7474657276302d322d64656d6f2e706e67" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
**Use the model**
```python
import torch
from mmpretrain import get_model, inference_model
model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
print(out)
```
**Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Test:
```shell
python tools/test.py configs/otter/otter-9b_caption.py https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
```
<!-- [TABS-END] -->
## Models and results
### Image Caption on COCO
| Model | Pretrain | Params (M) | BLEU-4 | CIDER | Config | Download |
| :---------------------------- | :----------: | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: |
| `otter-9b_3rdparty_caption`\* | From scratch | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
*Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reprodcue the training results.*
### Visual Question Answering on VQAv2
| Model | Pretrain | Params (M) | Accuracy | Config | Download |
| :------------------------ | :----------: | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: |
| `otter-9b_3rdparty_vqa`\* | From scratch | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
*Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reprodcue the training results.*
## Citation
```bibtex
@article{li2023otter,
title={Otter: A Multi-Modal Model with In-Context Instruction Tuning},
author={Li, Bo and Zhang, Yuanhan and Chen, Liangyu and Wang, Jinghao and Yang, Jingkang and Liu, Ziwei},
journal={arXiv preprint arXiv:2305.03726},
year={2023}
}
@article{li2023mimicit,
title={MIMIC-IT: Multi-Modal In-Context Instruction Tuning},
author={Bo Li and Yuanhan Zhang and Liangyu Chen and Jinghao Wang and Fanyi Pu and Jingkang Yang and Chunyuan Li and Ziwei Liu},
year={2023},
eprint={2306.05425},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

View File

@ -0,0 +1,43 @@
Collections:
- Name: Otter
Metadata:
Architecture:
- Transformer
- Gated Cross-Attention Dense
Paper:
Title: 'Otter: A Multi-Modal Model with In-Context Instruction Tuning'
URL: https://arxiv.org/abs/2305.03726
README: configs/otter/README.md
Models:
- Name: otter-9b_3rdparty_caption
Metadata:
FLOPs: null
Parameters: 8220452880
In Collection: Otter
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
Config: configs/otter/otter-9b_caption.py
Converted From:
Weights: https://huggingface.co/luodian/otter-9b-hf
Code: https://github.com/Luodian/Otter/tree/main
- Name: otter-9b_3rdparty_vqa
Metadata:
FLOPs: null
Parameters: 8220452880
In Collection: Otter
Results:
- Task: Visual Question Answering
Dataset: VQAv2
Metrics:
Accuracy: null
Weights: https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
Config: configs/otter/otter-9b_vqa.py
Converted From:
Weights: https://huggingface.co/luodian/otter-9b-hf
Code: https://github.com/Luodian/Otter/tree/main

View File

@ -0,0 +1,91 @@
_base_ = [
'../_base_/default_runtime.py',
]
# model settings
model = dict(
type='Otter',
tokenizer=dict(type='LlamaTokenizer', name_or_path='huggyllama/llama-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained=(
'https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
lang_encoder=dict(
base=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
local_files_only=True),
adapter=dict(
type='FlamingoLMAdapter',
vis_hidden_size=1024,
cross_attn_every_n_layers=4,
use_media_placement_augmentation=False,
only_attend_previous=True,
),
),
task='caption',
final_prompt_tmpl='<image>User:Please describe the image. GPT:<answer>',
generation_cfg=dict(
num_beams=3, max_new_tokens=24, no_repeat_ngram_size=3),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=224,
interpolation='bicubic',
backend='pillow'),
dict(type='CenterCrop', crop_size=(224, 224)),
dict(
type='PackInputs',
algorithm_keys=['gt_caption'],
meta_keys=['image_id'],
),
]
val_dataloader = dict(
batch_size=8,
num_workers=8,
dataset=dict(
type='FlamingoEvalCOCOCaption',
data_root='data/coco',
ann_file='annotations/captions_train2014.json',
data_prefix=dict(img_path='train2014'),
pipeline=test_pipeline,
num_shots=0,
num_support_examples=2048,
num_query_examples=5000,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/captions_train2014.json')
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# schedule settings
val_cfg = dict()
test_cfg = dict()

View File

@ -0,0 +1,104 @@
_base_ = [
'../_base_/default_runtime.py',
]
# model settings
model = dict(
type='Otter',
tokenizer=dict(type='LlamaTokenizer', name_or_path='huggyllama/llama-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
final_norm=False,
out_type='raw',
pretrained=(
'https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
lang_encoder=dict(
base=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
local_files_only=True),
adapter=dict(
type='FlamingoLMAdapter',
vis_hidden_size=1024,
cross_attn_every_n_layers=4,
use_media_placement_augmentation=False,
only_attend_previous=True,
),
),
task='vqa',
final_prompt_tmpl='<image>User:{question} GPT:<answer>',
generation_cfg=dict(
num_beams=3, max_new_tokens=24, no_repeat_ngram_size=3),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=224,
interpolation='bicubic',
backend='pillow'),
dict(type='CenterCrop', crop_size=(224, 224)),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight', 'shots'],
meta_keys=['image_id'],
),
]
val_dataloader = dict(
batch_size=8,
num_workers=8,
dataset=dict(
type='FlamingoEvalCOCOVQA',
data_root='data/coco',
data_prefix='val2014',
question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
ann_file='annotations/v2_mscoco_val2014_annotations.json',
pipeline=test_pipeline,
num_shots=0,
num_support_examples=2048,
num_query_examples=5000,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')
test_dataloader = dict(
batch_size=8,
num_workers=8,
dataset=dict(
type='FlamingoEvalCOCOVQA',
data_root='data/coco',
data_prefix='test2015',
question_file=
'annotations/v2_OpenEnded_mscoco_test-dev2015_questions.json',
pipeline=test_pipeline,
num_shots=0,
num_support_examples=2048,
num_query_examples=5000,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test-dev.json')
# schedule settings
val_cfg = dict()
test_cfg = dict()

View File

@ -144,6 +144,7 @@ Multi-Modality Algorithms
Flamingo
OFA
MiniGPT4
Otter
.. module:: mmpretrain.models.backbones

View File

@ -8,6 +8,7 @@ if WITH_MULTIMODAL:
from .flamingo import * # noqa: F401, F403
from .minigpt4 import * # noqa: F401, F403
from .ofa import * # noqa: F401, F403
from .otter import * # noqa: F401, F403
else:
from mmpretrain.registry import MODELS
from mmpretrain.utils.dependency import register_multimodal_placeholder
@ -15,5 +16,5 @@ else:
register_multimodal_placeholder([
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
'OFA', 'ChineseCLIP', 'MiniGPT4'
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Otter'
], MODELS)

View File

@ -19,6 +19,7 @@ class FlamingoLMAdapter:
vis_hidden_size: int,
cross_attn_every_n_layers: int,
use_media_placement_augmentation: bool,
only_attend_previous: bool = False,
):
"""Initialize Flamingo by adding a new gated cross attn to the decoder.
@ -48,6 +49,7 @@ class FlamingoLMAdapter:
]))
base.use_media_placement_augmentation = use_media_placement_augmentation # noqa
base.initialized_flamingo = True
base.only_attend_previous = only_attend_previous
return base
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
@ -67,8 +69,12 @@ class FlamingoLMAdapter:
function."""
input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else input[0]
media_locations = input_ids == self.media_token_id
attend_previous = ((random.random() < 0.5)
if self.use_media_placement_augmentation else False)
if self.only_attend_previous:
attend_previous = True
elif self.use_media_placement_augmentation:
attend_previous = (random.random() < 0.5)
else:
attend_previous = False
for layer in self.get_decoder().layers:
layer.condition_media_locations(media_locations)

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .otter import Otter
__all__ = ['Otter']

View File

@ -0,0 +1,140 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler
@MODELS.register_module()
class Otter(Flamingo):
"""The Open Flamingo model for multiple tasks.
Args:
vision_encoder (dict): The config of the vision encoder.
lang_encoder (dict): The config of the language encoder.
tokenizer (dict): The tokenizer to encode the text.
task (int): The task to perform prediction.
shot_prompt_tmpl (str): Prompt used for few-shot inference.
Defaults to '<image>User:Please describe the image.
GPT:<answer>{caption}<|endofchunk|>'.
final_prompt_tmpl (str): Final part of prompt used for inference.
Defaults to '<image>User:Please describe the image. GPT:<answer>'.
generation_cfg (dict): The extra generation config, accept the keyword
arguments of [~`transformers.GenerationConfig`].
Defaults to an empty dict.
data_preprocessor (Optional[dict]): The config for preprocessing input
data. If None or no specified type, it will use
"MutimodalDataPreprocessor" as type.
See :class:`MutimodalDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): The initialization config. Defaults to None.
"""
support_tasks = {'caption', 'vqa'}
_no_split_modules = [
'TransformerEncoderLayer', 'PerceiverAttention',
'GatedCrossAttentionBlock', 'FlamingoLayer'
]
def __init__(
self,
vision_encoder: dict,
lang_encoder: dict,
tokenizer: dict,
task: str = 'caption',
zeroshot_prompt: str = '',
shot_prompt_tmpl: str = ('<image>User:Please describe the image. '
'GPT:<answer>{caption}<|endofchunk|>'),
final_prompt_tmpl: str = ('<image>User:Please describe the image. '
'GPT:<answer>'),
generation_cfg: dict = dict(),
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
if isinstance(data_preprocessor, dict):
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super(Flamingo, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
if task not in self.support_tasks:
raise ValueError(f'Unsupported task {task}, please select '
f'the task from {self.support_tasks}.')
self.task = task
# init tokenizer
self.tokenizer = TOKENIZER.build(tokenizer)
# add Flamingo special tokens to the tokenizer
self.tokenizer.add_special_tokens({
'additional_special_tokens':
['<|endofchunk|>', '<image>', '<answer>']
})
self.tokenizer.bos_token_id = 1
if self.tokenizer.pad_token is None:
# Issue: GPT models don't have a pad token, which we use to
# modify labels for the loss.
self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})
# Template to format the prompt input
self.zeroshot_prompt = zeroshot_prompt
self.shot_prompt_tmpl = shot_prompt_tmpl
self.final_prompt_tmpl = final_prompt_tmpl
# init vision encoder related modules
vision_encoder_weight = vision_encoder.pop('pretrained', None)
self.vision_encoder = MODELS.build(vision_encoder)
if vision_encoder_weight is not None:
from mmengine.runner.checkpoint import load_checkpoint
load_checkpoint(
self.vision_encoder,
vision_encoder_weight,
map_location='cpu',
revise_keys=[(r'^backbone\.', '')],
)
self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)
# init language encoder related modules
self.lang_encoder = ExtendModule(**lang_encoder)
self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
self.lang_encoder.media_token_id = self.tokenizer.encode('<image>')[-1]
# other necessary parameters
self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
self.generation_cfg = generation_cfg
if hasattr(self, 'register_load_state_dict_post_hook'):
self.register_load_state_dict_post_hook(self._load_adapter_hook)
def post_process(
self, outputs: torch.Tensor,
data_samples: Optional[List[DataSample]]) -> List[DataSample]:
"""Perform post process for outputs for different task.
Args:
outputs (torch.Tensor): The generated outputs.
data_samples (List[DataSample], optional): The annotation
data of every samples.
Returns:
List[DataSample]: Return list of data samples.
"""
outputs = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True)
if data_samples is None:
data_samples = [DataSample() for _ in range(len(outputs))]
for output, data_sample in zip(outputs, data_samples):
# remove text pattern
if self.task == 'caption':
data_sample.pred_caption = output
elif self.task == 'vqa':
data_sample.pred_answer = output
return data_samples

View File

@ -79,3 +79,4 @@ Import:
- configs/itpn/metafile.yml
- configs/hivit/metafile.yml
- configs/minigpt4/metafile.yml
- configs/otter/metafile.yml

View File

@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict
from itertools import chain
from pathlib import Path
import torch
prog_description = """\
Convert Official Otter HF models to MMPreTrain format.
"""
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument(
'name_or_dir', type=str, help='The Otter HF model name or directory.')
args = parser.parse_args()
return args
def main():
args = parse_args()
if not Path(args.name_or_dir).is_dir():
from huggingface_hub import snapshot_download
ckpt_dir = Path(
snapshot_download(args.name_or_dir, allow_patterns='*.bin'))
name = args.name_or_dir.replace('/', '_')
else:
ckpt_dir = Path(args.name_or_dir)
name = ckpt_dir.name
state_dict = OrderedDict()
for k, v in chain.from_iterable(
torch.load(ckpt).items() for ckpt in ckpt_dir.glob('*.bin')):
adapter_patterns = [
r'^perceiver',
r'lang_encoder.*embed_tokens',
r'lang_encoder.*gated_cross_attn_layer',
r'lang_encoder.*rotary_emb',
]
if not any(re.match(pattern, k) for pattern in adapter_patterns):
# Drop encoder parameters to decrease the size.
continue
# The keys are different between Open-Flamingo and Otter
if 'gated_cross_attn_layer.feed_forward' in k:
k = k.replace('feed_forward', 'ff')
if 'perceiver.layers' in k:
prefix_match = re.match(r'perceiver.layers.\d+.', k)
prefix = k[:prefix_match.end()]
suffix = k[prefix_match.end():]
if 'feed_forward' in k:
k = prefix + '1.' + suffix.replace('feed_forward.', '')
else:
k = prefix + '0.' + suffix
state_dict[k] = v
if len(state_dict) == 0:
raise RuntimeError('No checkpoint found in the specified directory.')
torch.save(state_dict, name + '.pth')
if __name__ == '__main__':
main()