mirror of https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00

[Feature] Support otter (#1651)

* [Feature] Support Otter
* Update docs

parent 9d3fc43073
commit e69bace03f
@@ -256,6 +256,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>
@@ -252,6 +252,7 @@ mim install -e ".[multimodal]"
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>
78  configs/otter/README.md  Normal file
@@ -0,0 +1,78 @@
# Otter

> [Otter: A Multi-Modal Model with In-Context Instruction Tuning](https://arxiv.org/abs/2305.03726)

<!-- [ALGORITHM] -->

## Abstract

Large language models (LLMs) have demonstrated significant universal capabilities as few/zero-shot learners in various tasks due to their pre-training on vast amounts of text data, as exemplified by GPT-3, which boosted to InstructGPT and ChatGPT, effectively following natural language instructions to accomplish real-world tasks. In this paper, we propose to introduce instruction tuning into multi-modal models, motivated by the Flamingo model's upstream interleaved format pretraining dataset. We adopt a similar approach to construct our MultI-Modal In-Context Instruction Tuning (MIMIC-IT) dataset. We then introduce Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following ability and in-context learning. We also optimize OpenFlamingo's implementation for researchers, democratizing the required training resources from 1$\times$ A100 GPU to 4$\times$ RTX-3090 GPUs, and integrate both OpenFlamingo and Otter into Huggingface Transformers for more researchers to incorporate the models into their customized training and inference pipelines.

<div align=center>
<img src="https://camo.githubusercontent.com/70613ab882a7827808148a2c577029d544371e707b0832a0b01151c54ce553c3/68747470733a2f2f692e706f7374696d672e63632f5477315a304243572f6f7474657276302d322d64656d6f2e706e67" width="80%"/>
</div>

## How to use it?

<!-- [TABS-BEGIN] -->

**Use the model**

```python
import torch
from mmpretrain import get_model, inference_model

model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
print(out)
```
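
Usage for the VQA model is similar; below is a minimal sketch, assuming the question is passed to `inference_model` as an extra argument (as with the other VQA models in MMPreTrain) and using an illustrative question:

```python
from mmpretrain import get_model, inference_model

# Sketch only: 'otter-9b_3rdparty_vqa' is the model name listed in the
# metafile below; the question string is just an example.
model = get_model('otter-9b_3rdparty_vqa', pretrained=True, device='cuda')
out = inference_model(model, 'demo/cat-dog.png', 'What animals are in the picture?')
print(out)
```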

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/otter/otter-9b_caption.py https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
```

<!-- [TABS-END] -->

## Models and results

### Image Caption on COCO

| Model | Pretrain | Params (M) | BLEU-4 | CIDER | Config | Download |
| :---------------------------- | :----------: | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: |
| `otter-9b_3rdparty_caption`\* | From scratch | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |

*Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*

### Visual Question Answering on VQAv2

| Model | Pretrain | Params (M) | Accuracy | Config | Download |
| :------------------------ | :----------: | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: |
| `otter-9b_3rdparty_vqa`\* | From scratch | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |

*Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*

## Citation

```bibtex
@article{li2023otter,
  title={Otter: A Multi-Modal Model with In-Context Instruction Tuning},
  author={Li, Bo and Zhang, Yuanhan and Chen, Liangyu and Wang, Jinghao and Yang, Jingkang and Liu, Ziwei},
  journal={arXiv preprint arXiv:2305.03726},
  year={2023}
}

@article{li2023mimicit,
  title={MIMIC-IT: Multi-Modal In-Context Instruction Tuning},
  author={Bo Li and Yuanhan Zhang and Liangyu Chen and Jinghao Wang and Fanyi Pu and Jingkang Yang and Chunyuan Li and Ziwei Liu},
  year={2023},
  eprint={2306.05425},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
43  configs/otter/metafile.yml  Normal file
@@ -0,0 +1,43 @@
Collections:
  - Name: Otter
    Metadata:
      Architecture:
        - Transformer
        - Gated Cross-Attention Dense
    Paper:
      Title: 'Otter: A Multi-Modal Model with In-Context Instruction Tuning'
      URL: https://arxiv.org/abs/2305.03726
    README: configs/otter/README.md

Models:
  - Name: otter-9b_3rdparty_caption
    Metadata:
      FLOPs: null
      Parameters: 8220452880
    In Collection: Otter
    Results:
      - Task: Image Caption
        Dataset: COCO
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
    Config: configs/otter/otter-9b_caption.py
    Converted From:
      Weights: https://huggingface.co/luodian/otter-9b-hf
      Code: https://github.com/Luodian/Otter/tree/main
  - Name: otter-9b_3rdparty_vqa
    Metadata:
      FLOPs: null
      Parameters: 8220452880
    In Collection: Otter
    Results:
      - Task: Visual Question Answering
        Dataset: VQAv2
        Metrics:
          Accuracy: null
    Weights: https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
    Config: configs/otter/otter-9b_vqa.py
    Converted From:
      Weights: https://huggingface.co/luodian/otter-9b-hf
      Code: https://github.com/Luodian/Otter/tree/main
91  configs/otter/otter-9b_caption.py  Normal file
@@ -0,0 +1,91 @@
_base_ = [
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='Otter',
    tokenizer=dict(type='LlamaTokenizer', name_or_path='huggyllama/llama-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained=(
            'https://download.openmmlab.com/mmclassification/v0/clip/'
            'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
    ),
    lang_encoder=dict(
        base=dict(
            type='AutoModelForCausalLM',
            name_or_path='huggyllama/llama-7b',
            local_files_only=True),
        adapter=dict(
            type='FlamingoLMAdapter',
            vis_hidden_size=1024,
            cross_attn_every_n_layers=4,
            use_media_placement_augmentation=False,
            only_attend_previous=True,
        ),
    ),
    task='caption',
    final_prompt_tmpl='<image>User:Please describe the image. GPT:<answer>',
    generation_cfg=dict(
        num_beams=3, max_new_tokens=24, no_repeat_ngram_size=3),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=224,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(
        type='PackInputs',
        algorithm_keys=['gt_caption'],
        meta_keys=['image_id'],
    ),
]

val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOCaption',
        data_root='data/coco',
        ann_file='annotations/captions_train2014.json',
        data_prefix=dict(img_path='train2014'),
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

val_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/captions_train2014.json')

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# schedule settings
val_cfg = dict()
test_cfg = dict()
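The config above evaluates on a subset of the COCO `train2014` split and simply aliases `test_dataloader = val_dataloader`. To run a standard test as the inline comment suggests, the test dataset can be configured explicitly; a minimal sketch, assuming the usual COCO 2014 layout under `data/coco` and reusing the dataset and evaluator types already defined in this config:

```python
# Hypothetical override for a standard test run on the val2014 split;
# drop-in replacement for the test section of otter-9b_caption.py.
test_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOCaption',
        data_root='data/coco',
        ann_file='annotations/captions_val2014.json',
        data_prefix=dict(img_path='val2014'),
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/captions_val2014.json')
```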
104  configs/otter/otter-9b_vqa.py  Normal file
@@ -0,0 +1,104 @@
_base_ = [
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='Otter',
    tokenizer=dict(type='LlamaTokenizer', name_or_path='huggyllama/llama-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained=(
            'https://download.openmmlab.com/mmclassification/v0/clip/'
            'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
    ),
    lang_encoder=dict(
        base=dict(
            type='AutoModelForCausalLM',
            name_or_path='huggyllama/llama-7b',
            local_files_only=True),
        adapter=dict(
            type='FlamingoLMAdapter',
            vis_hidden_size=1024,
            cross_attn_every_n_layers=4,
            use_media_placement_augmentation=False,
            only_attend_previous=True,
        ),
    ),
    task='vqa',
    final_prompt_tmpl='<image>User:{question} GPT:<answer>',
    generation_cfg=dict(
        num_beams=3, max_new_tokens=24, no_repeat_ngram_size=3),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=224,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight', 'shots'],
        meta_keys=['image_id'],
    ),
]

val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOVQA',
        data_root='data/coco',
        data_prefix='val2014',
        question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
        ann_file='annotations/v2_mscoco_val2014_annotations.json',
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')

test_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOVQA',
        data_root='data/coco',
        data_prefix='test2015',
        question_file=
        'annotations/v2_OpenEnded_mscoco_test-dev2015_questions.json',
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test-dev.json')

# schedule settings
val_cfg = dict()
test_cfg = dict()
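In both configs, `final_prompt_tmpl` is a plain Python format string; for the VQA task the question is substituted into the prompt before generation, roughly as illustrated below (the question string is only an example):

```python
# Illustration of the prompt expansion used by the VQA config.
final_prompt_tmpl = '<image>User:{question} GPT:<answer>'
prompt = final_prompt_tmpl.format(question='What animals are in the picture?')
print(prompt)
# <image>User:What animals are in the picture? GPT:<answer>
```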
@@ -144,6 +144,7 @@ Multi-Modality Algorithms

   Flamingo
   OFA
   MiniGPT4
   Otter

.. module:: mmpretrain.models.backbones
@@ -8,6 +8,7 @@ if WITH_MULTIMODAL:
    from .flamingo import *  # noqa: F401, F403
    from .minigpt4 import *  # noqa: F401, F403
    from .ofa import *  # noqa: F401, F403
    from .otter import *  # noqa: F401, F403
else:
    from mmpretrain.registry import MODELS
    from mmpretrain.utils.dependency import register_multimodal_placeholder
@@ -15,5 +16,5 @@ else:
    register_multimodal_placeholder([
        'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
        'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
        'OFA', 'ChineseCLIP', 'MiniGPT4'
        'OFA', 'ChineseCLIP', 'MiniGPT4', 'Otter'
    ], MODELS)
@@ -19,6 +19,7 @@ class FlamingoLMAdapter:
        vis_hidden_size: int,
        cross_attn_every_n_layers: int,
        use_media_placement_augmentation: bool,
        only_attend_previous: bool = False,
    ):
        """Initialize Flamingo by adding a new gated cross attn to the decoder.

@@ -48,6 +49,7 @@ class FlamingoLMAdapter:
            ]))
        base.use_media_placement_augmentation = use_media_placement_augmentation  # noqa
        base.initialized_flamingo = True
        base.only_attend_previous = only_attend_previous
        return base

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
@@ -67,8 +69,12 @@ class FlamingoLMAdapter:
        function."""
        input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else input[0]
        media_locations = input_ids == self.media_token_id
        attend_previous = ((random.random() < 0.5)
                           if self.use_media_placement_augmentation else False)
        if self.only_attend_previous:
            attend_previous = True
        elif self.use_media_placement_augmentation:
            attend_previous = (random.random() < 0.5)
        else:
            attend_previous = False

        for layer in self.get_decoder().layers:
            layer.condition_media_locations(media_locations)
4  mmpretrain/models/multimodal/otter/__init__.py  Normal file
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .otter import Otter

__all__ = ['Otter']
140  mmpretrain/models/multimodal/otter/otter.py  Normal file
@@ -0,0 +1,140 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler


@MODELS.register_module()
class Otter(Flamingo):
    """The Otter model for multiple tasks.

    Args:
        vision_encoder (dict): The config of the vision encoder.
        lang_encoder (dict): The config of the language encoder.
        tokenizer (dict): The tokenizer to encode the text.
        task (str): The task to perform prediction.
        shot_prompt_tmpl (str): Prompt used for few-shot inference.
            Defaults to '<image>User:Please describe the image.
            GPT:<answer>{caption}<|endofchunk|>'.
        final_prompt_tmpl (str): Final part of prompt used for inference.
            Defaults to '<image>User:Please describe the image. GPT:<answer>'.
        generation_cfg (dict): The extra generation config, accept the keyword
            arguments of [~`transformers.GenerationConfig`].
            Defaults to an empty dict.
        data_preprocessor (Optional[dict]): The config for preprocessing input
            data. If None or no specified type, it will use
            "MultiModalDataPreprocessor" as type.
            See :class:`MultiModalDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (dict, optional): The initialization config. Defaults to None.
    """

    support_tasks = {'caption', 'vqa'}
    _no_split_modules = [
        'TransformerEncoderLayer', 'PerceiverAttention',
        'GatedCrossAttentionBlock', 'FlamingoLayer'
    ]

    def __init__(
            self,
            vision_encoder: dict,
            lang_encoder: dict,
            tokenizer: dict,
            task: str = 'caption',
            zeroshot_prompt: str = '',
            shot_prompt_tmpl: str = ('<image>User:Please describe the image. '
                                     'GPT:<answer>{caption}<|endofchunk|>'),
            final_prompt_tmpl: str = ('<image>User:Please describe the image. '
                                      'GPT:<answer>'),
            generation_cfg: dict = dict(),
            data_preprocessor: Optional[dict] = None,
            init_cfg: Optional[dict] = None):
        if data_preprocessor is None:
            data_preprocessor = {}
        if isinstance(data_preprocessor, dict):
            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
            data_preprocessor = MODELS.build(data_preprocessor)

        super(Flamingo, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        if task not in self.support_tasks:
            raise ValueError(f'Unsupported task {task}, please select '
                             f'the task from {self.support_tasks}.')
        self.task = task

        # init tokenizer
        self.tokenizer = TOKENIZER.build(tokenizer)
        # add Flamingo special tokens to the tokenizer
        self.tokenizer.add_special_tokens({
            'additional_special_tokens':
            ['<|endofchunk|>', '<image>', '<answer>']
        })
        self.tokenizer.bos_token_id = 1
        if self.tokenizer.pad_token is None:
            # Issue: GPT models don't have a pad token, which we use to
            # modify labels for the loss.
            self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})

        # Template to format the prompt input
        self.zeroshot_prompt = zeroshot_prompt
        self.shot_prompt_tmpl = shot_prompt_tmpl
        self.final_prompt_tmpl = final_prompt_tmpl

        # init vision encoder related modules
        vision_encoder_weight = vision_encoder.pop('pretrained', None)
        self.vision_encoder = MODELS.build(vision_encoder)
        if vision_encoder_weight is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(
                self.vision_encoder,
                vision_encoder_weight,
                map_location='cpu',
                revise_keys=[(r'^backbone\.', '')],
            )

        self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)

        # init language encoder related modules
        self.lang_encoder = ExtendModule(**lang_encoder)
        self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
        self.lang_encoder.media_token_id = self.tokenizer.encode('<image>')[-1]

        # other necessary parameters
        self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
        self.generation_cfg = generation_cfg

        if hasattr(self, 'register_load_state_dict_post_hook'):
            self.register_load_state_dict_post_hook(self._load_adapter_hook)

    def post_process(
            self, outputs: torch.Tensor,
            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
        """Perform post processing on the outputs for different tasks.

        Args:
            outputs (torch.Tensor): The generated outputs.
            data_samples (List[DataSample], optional): The annotation
                data of every sample.

        Returns:
            List[DataSample]: Return list of data samples.
        """
        outputs = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True)

        if data_samples is None:
            data_samples = [DataSample() for _ in range(len(outputs))]

        for output, data_sample in zip(outputs, data_samples):
            # remove text pattern
            if self.task == 'caption':
                data_sample.pred_caption = output
            elif self.task == 'vqa':
                data_sample.pred_answer = output

        return data_samples
@@ -79,3 +79,4 @@ Import:
  - configs/itpn/metafile.yml
  - configs/hivit/metafile.yml
  - configs/minigpt4/metafile.yml
  - configs/otter/metafile.yml
66  tools/model_converters/otter2mmpre.py  Normal file
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict
from itertools import chain
from pathlib import Path

import torch

prog_description = """\
Convert Official Otter HF models to MMPreTrain format.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'name_or_dir', type=str, help='The Otter HF model name or directory.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    if not Path(args.name_or_dir).is_dir():
        from huggingface_hub import snapshot_download
        ckpt_dir = Path(
            snapshot_download(args.name_or_dir, allow_patterns='*.bin'))
        name = args.name_or_dir.replace('/', '_')
    else:
        ckpt_dir = Path(args.name_or_dir)
        name = ckpt_dir.name

    state_dict = OrderedDict()
    for k, v in chain.from_iterable(
            torch.load(ckpt).items() for ckpt in ckpt_dir.glob('*.bin')):
        adapter_patterns = [
            r'^perceiver',
            r'lang_encoder.*embed_tokens',
            r'lang_encoder.*gated_cross_attn_layer',
            r'lang_encoder.*rotary_emb',
        ]
        if not any(re.match(pattern, k) for pattern in adapter_patterns):
            # Drop encoder parameters to decrease the size.
            continue

        # The keys are different between Open-Flamingo and Otter
        if 'gated_cross_attn_layer.feed_forward' in k:
            k = k.replace('feed_forward', 'ff')
        if 'perceiver.layers' in k:
            prefix_match = re.match(r'perceiver.layers.\d+.', k)
            prefix = k[:prefix_match.end()]
            suffix = k[prefix_match.end():]
            if 'feed_forward' in k:
                k = prefix + '1.' + suffix.replace('feed_forward.', '')
            else:
                k = prefix + '0.' + suffix
        state_dict[k] = v
    if len(state_dict) == 0:
        raise RuntimeError('No checkpoint found in the specified directory.')

    torch.save(state_dict, name + '.pth')


if __name__ == '__main__':
    main()
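A hedged sketch for sanity-checking the converter's output from Python: the output filename follows the script's `name + '.pth'` convention (`/` replaced by `_`), and `luodian/otter-9b-hf` is the source checkpoint listed in the metafile above.

```python
# Run the converter first, e.g.:
#   python tools/model_converters/otter2mmpre.py luodian/otter-9b-hf
# which writes 'luodian_otter-9b-hf.pth' in the working directory.
import torch

state_dict = torch.load('luodian_otter-9b-hf.pth', map_location='cpu')

# Only adapter weights are kept: the perceiver resampler plus the language
# encoder's token embeddings, gated cross-attention layers and rotary
# embeddings, matching the patterns filtered in the script above.
assert all(
    k.startswith('perceiver') or 'embed_tokens' in k
    or 'gated_cross_attn_layer' in k or 'rotary_emb' in k
    for k in state_dict)
print(f'{len(state_dict)} adapter tensors kept')
```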