Ma Zerun 6847d20d57
[Feature] Support multiple multi-modal algorithms and inferencers. (#1561)
* [Feat] Migrate blip caption to mmpretrain. (#50)

* Migrate blip caption to mmpretrain

* minor fix

* support train

* [Feature] Support OFA caption task. (#51)

* [Feature] Support OFA caption task.

* Remove duplicated files.

* [Feature] Support OFA vqa task. (#58)

* [Feature] Support OFA vqa task.

* Fix lint.

* [Feat] Add BLIP retrieval to mmpretrain. (#55)

* init

* minor fix for train

* fix according to comments

* refactor

* Update Blip retrieval. (#62)

* [Feature] Support OFA visual grounding task. (#59)

* [Feature] Support OFA visual grounding task.

* minor add TODO

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feat] Add flamingos coco caption and vqa. (#60)

* first init

* init flamingo coco

* add vqa

* minor fix

* remove unnecessary modules

* Update config

* Use `ApplyToList`.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature]: BLIP2 coco retrieval (#53)

* [Feature]: Add blip2 retriever

* [Feature]: Add blip2 all modules

* [Feature]: Refine model

* [Feature]: x1

* [Feature]: Runnable coco ret

* [Feature]: Runnable version

* [Feature]: Fix lint

* [Fix]: Fix lint

* [Feature]: Use 364 img size

* [Feature]: Refactor blip2

* [Fix]: Fix lint

* refactor files

* minor fix

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* Remove

* fix blip caption inputs (#68)

* [Feat] Add BLIP NLVR support. (#67)

* first init

* init flamingo coco

* add vqa

* add nlvr

* refactor nlvr

* minor fix

* minor fix

* Update dataset

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature]: BLIP2 Caption (#70)

* [Feature]: Add language model

* [Feature]: blip2 caption forward

* [Feature]: Reproduce the results

* [Feature]: Refactor caption

* refine config

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feat] Migrate BLIP VQA to mmpretrain (#69)

* reformat

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* change

* refactor code

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* Update RefCOCO dataset

* [Fix] fix lint

* [Feature] Implement inference APIs for multi-modal tasks. (#65)

* [Feature] Implement inference APIs for multi-modal tasks.

* [Project] Add gradio demo.

* [Improve] Update requirements

* Update flamingo

* Update blip

* Add NLVR inferencer

* Update flamingo

* Update hugging face model register

* Update ofa vqa

* Update BLIP-vqa (#71)

* Update blip-vqa docstring (#72)

* Refine flamingo docstring (#73)

* [Feature]: BLIP2 VQA (#61)

* [Feature]: VQA forward

* [Feature]: Reproduce accuracy

* [Fix]: Fix lint

* [Fix]: Add blank line

* minor fix

---------

Co-authored-by: yingfhu <yingfhu@gmail.com>

* [Feature]: BLIP2 docstring (#74)

* [Feature]: Add caption docstring

* [Feature]: Add docstring to blip2 vqa

* [Feature]: Add docstring to retrieval

* Update BLIP-2 metafile and README (#75)

* [Feature]: Add readme and docstring

* Update blip2 results

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature] BLIP Visual Grounding on MMPretrain Branch (#66)

* blip grounding merge with mmpretrain

* remove commit

* blip grounding test and inference api

* refcoco dataset

* refcoco dataset refine config

* rebasing

* gitignore

* rebasing

* minor edit

* minor edit

* Update blip-vqa docstring (#72)

* rebasing

* Revert "minor edit"

This reverts commit 639cec757c215e654625ed0979319e60f0be9044.

* blip grounding final

* precommit

* refine config

* refine config

* Update blip visual grounding

---------

Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com>
Co-authored-by: mzr1996 <mzr1996@163.com>

* Update visual grounding metric

* Update OFA docstring, README and metafiles. (#76)

* [Docs] Update installation docs and gradio demo docs. (#77)

* Update OFA name

* Update Visual Grounding Visualizer

* Integrate accelerate support

* Fix imports.

* Fix timm backbone

* Update imports

* Update README

* Update circle ci

* Update flamingo config

* Add gradio demo README

* [Feature]: Add scienceqa (#1571)

* [Feature]: Add scienceqa

* [Feature]: Change param name

* Update docs

* Update video

---------

Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: yingfhu <yingfhu@gmail.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com>
Co-authored-by: Rongjie Li <limo97@163.com>
2023-05-19 16:50:04 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import re
from collections import OrderedDict, namedtuple

import torch

prog_description = """\
Convert OFA official models to MMPretrain format.
"""

# One conversion rule: `pattern` is searched in every checkpoint key;
# `repl` (or `key_action`) rewrites the matched span, and `value_action`
# may transform the value (returning None drops the entry).
MapItem = namedtuple(
    'MapItem', 'pattern repl key_action value_action', defaults=[None] * 4)


def convert_by_mapdict(src_dict: dict, map_dict: list):
    """Rewrite the keys of ``src_dict`` by applying the rules in ``map_dict``
    in order, and return the converted dict together with a mapping from the
    new keys back to the original keys."""
    dst_dict = OrderedDict()
    convert_map_dict = dict()
    for k, v in src_dict.items():
        ori_k = k
        for item in map_dict:
            pattern = item.pattern
            assert pattern is not None
            match = next(re.finditer(pattern, k), None)
            if match is None:
                continue
            match_group = match.groups()
            repl = item.repl
            key_action = item.key_action
            if key_action is not None:
                assert callable(key_action)
                match_group = key_action(*match_group)
                if isinstance(match_group, str):
                    match_group = (match_group, )
            start, end = match.span(0)
            if repl is not None:
                # Replace the whole match with the formatted template.
                k = k[:start] + repl.format(*match_group) + k[end:]
            else:
                # No template: substitute each capture group in place.
                for i, sub in enumerate(match_group):
                    start, end = match.span(i + 1)
                    k = k[:start] + str(sub) + k[end:]
            value_action = item.value_action
            if value_action is not None:
                assert callable(value_action)
                v = value_action(v)
        if v is not None:
            dst_dict[k] = v
            convert_map_dict[k] = ori_k
    return dst_dict, convert_map_dict


map_dict = [
    # Encoder modules
    MapItem(r'\.type_embedding\.', '.embed_type.'),
    MapItem(r'\.layernorm_embedding\.', '.embedding_ln.'),
    MapItem(r'\.patch_layernorm_embedding\.', '.image_embedding_ln.'),
    MapItem(r'encoder.layer_norm\.', 'encoder.final_ln.'),
    # Encoder layers
    MapItem(r'\.attn_ln\.', '.attn_mid_ln.'),
    MapItem(r'\.ffn_layernorm\.', '.ffn_mid_ln.'),
    MapItem(r'\.final_layer_norm', '.ffn_ln'),
    MapItem(r'encoder.*(\.self_attn\.)', key_action=lambda _: '.attn.'),
    MapItem(
        r'encoder.*(\.self_attn_layer_norm\.)',
        key_action=lambda _: '.attn_ln.'),
    # Decoder modules
    MapItem(r'\.code_layernorm_embedding\.', '.code_embedding_ln.'),
    MapItem(r'decoder.layer_norm\.', 'decoder.final_ln.'),
    # Decoder layers
    MapItem(r'\.self_attn_ln', '.self_attn_mid_ln'),
    MapItem(r'\.cross_attn_ln', '.cross_attn_mid_ln'),
    MapItem(r'\.encoder_attn_layer_norm', '.cross_attn_ln'),
    MapItem(r'\.encoder_attn', '.cross_attn'),
    MapItem(
        r'decoder.*(\.self_attn_layer_norm\.)',
        key_action=lambda _: '.self_attn_ln.'),
    # Remove version key
    MapItem(r'version', '', value_action=lambda _: None),
    # Add model prefix
    MapItem(r'^', 'model.'),
]


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('src', type=str, help='The official checkpoint path.')
    parser.add_argument('dst', type=str, help='The save path.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    src = torch.load(args.src)
    # OFA checkpoints may store EMA weights under `extra_state`; prefer them.
    if 'extra_state' in src and 'ema' in src['extra_state']:
        print('Use EMA weights.')
        src = src['extra_state']['ema']
    else:
        src = src['model']
    dst, _ = convert_by_mapdict(src, map_dict)
    torch.save(dst, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()
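
The script is run with two positional arguments, the official checkpoint path and the output path, as defined in parse_args. For reference, the sketch below (not part of the original file) exercises the key-mapping rules on a toy state dict; the module name `ofa2mmpretrain` is a hypothetical import target for the script, and the example keys are made up but follow the OFA naming scheme that `map_dict` targets.

# Minimal usage sketch, assuming the script above is importable as
# `ofa2mmpretrain` (hypothetical module name) and using hypothetical keys
# in the OFA naming convention.
from collections import OrderedDict

from ofa2mmpretrain import convert_by_mapdict, map_dict

toy_src = OrderedDict([
    ('encoder.layer_norm.weight', 'w0'),  # final encoder LayerNorm
    ('decoder.layers.0.encoder_attn.k_proj.weight', 'w1'),  # cross-attention
    ('version', 2),  # dropped: its value_action returns None
])

toy_dst, key_map = convert_by_mapdict(toy_src, map_dict)
print(list(toy_dst))
# ['model.encoder.final_ln.weight',
#  'model.decoder.layers.0.cross_attn.k_proj.weight']
print(key_map['model.encoder.final_ln.weight'])
# 'encoder.layer_norm.weight'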