[Feature] Support GLIP (#1308)
* rebase
* add glip
* update glip
* add links
* rename
* fix doc

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
parent 2c913020b9
commit fec3da781f
@@ -200,6 +200,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
        <li><a href="configs/xcit">XCiT</a></li>
        <li><a href="configs/levit">LeViT</a></li>
        <li><a href="configs/riformer">RIFormer</a></li>
        <li><a href="configs/glip">GLIP</a></li>
      </ul>
      </td>
      <td>
@@ -196,6 +196,7 @@ mim install -e .
        <li><a href="configs/xcit">XCiT</a></li>
        <li><a href="configs/levit">LeViT</a></li>
        <li><a href="configs/riformer">RIFormer</a></li>
        <li><a href="configs/glip">GLIP</a></li>
      </ul>
      </td>
      <td>
configs/glip/README.md (new file)
@@ -0,0 +1,57 @@
# GLIP

> [Grounded Language-Image Pre-training](https://arxiv.org/abs/2112.03857)

<!-- [ALGORITHM] -->

## Abstract

This paper presents a grounded language-image pre-training (GLIP) model for learning object-level, language-aware, and semantic-rich visual representations. GLIP unifies object detection and phrase grounding for pre-training. The unification brings two benefits: 1) it allows GLIP to learn from both detection and grounding data to improve both tasks and bootstrap a good grounding model; 2) GLIP can leverage massive image-text pairs by generating grounding boxes in a self-training fashion, making the learned representation semantic-rich. In our experiments, we pre-train GLIP on 27M grounding data, including 3M human-annotated and 24M web-crawled image-text pairs. The learned representations demonstrate strong zero-shot and few-shot transferability to various object-level recognition tasks. 1) When directly evaluated on COCO and LVIS (without seeing any images in COCO during pre-training), GLIP achieves 49.8 AP and 26.9 AP, respectively, surpassing many supervised baselines. 2) After fine-tuned on COCO, GLIP achieves 60.8 AP on val and 61.5 AP on test-dev, surpassing prior SoTA. 3) When transferred to 13 downstream object detection tasks, a 1-shot GLIP rivals with a fully-supervised Dynamic Head.

<div align="center">
<img src="https://github.com/microsoft/GLIP/blob/main/docs/lead.png" width="70%"/>
</div>

## How to use it?

<!-- [TABS-BEGIN] -->

**Use the model**

```python
import torch
from mmpretrain import get_model

model = get_model('swin-t_glip-pre_3rdparty', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))

# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```

<!-- [TABS-END] -->

## Results and models

### Pre-trained models

The pre-trained models are intended only for fine-tuning on downstream tasks, so no evaluation results are reported for them.

| Model                                       |          Pretrain          | Resolution |                                                        Download                                                        |
| :------------------------------------------ | :------------------------: | :--------: | :---------------------------------------------------------------------------------------------------------------------: |
| GLIP-T (`swin-t_glip-pre_3rdparty`)\*       |    O365,GoldG,CC3M,SBU     |  224x224   |    [model](https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth)     |
| GLIP-L (`swin-l_glip-pre_3rdparty_384px`)\* | FourODs,GoldG,CC3M+12M,SBU |  384x384   |  [model](https://download.openmmlab.com/mmclassification/v1/glip/swin-l_glip-pre_3rdparty_384px_20230413-04b198e8.pth)  |

*Models with * are converted from the [official repo](https://github.com/microsoft/GLIP).*
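
Since these checkpoints ship without a classification head, the usual next step is to plug the GLIP backbone into a regular classifier and fine-tune it. The config below is only a minimal sketch of that idea, assuming mmpretrain's standard `SwinTransformer`, `GlobalAveragePooling` and `LinearClsHead` components; the dataset-specific fields are placeholders.

```python
# Hypothetical fine-tuning config (sketch, not shipped with GLIP): reuse the
# converted GLIP Swin-T backbone weights and attach a new linear head.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='tiny',
        img_size=224,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth',
            prefix='backbone',  # load only the backbone weights from the checkpoint
        )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,  # placeholder: set to your dataset's number of classes
        in_channels=768,   # Swin-T last-stage channel dimension
        loss=dict(type='CrossEntropyLoss')))
```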

## Citation

```bibtex
@inproceedings{li2021grounded,
  title={Grounded Language-Image Pre-training},
  author={Liunian Harold Li* and Pengchuan Zhang* and Haotian Zhang* and Jianwei Yang and Chunyuan Li and Yiwu Zhong and Lijuan Wang and Lu Yuan and Lei Zhang and Jenq-Neng Hwang and Kai-Wei Chang and Jianfeng Gao},
  year={2022},
  booktitle={CVPR},
}
```
configs/glip/glip-l_headless.py (new file)
@@ -0,0 +1,18 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='large',
        img_size=384,
        # the original weights were trained for detection, so the backbone
        # outputs multi-scale feature maps
        out_indices=(1, 2, 3),
        stage_cfgs=dict(block_cfgs=dict(window_size=12))),
    neck=None,
    head=None)

data_preprocessor = dict(
    # normalization parameters, in BGR channel order
    mean=[103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395],
    # do not convert images from BGR to RGB
    to_rgb=False,
)
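
For reference, here is a minimal sketch (not part of the config above) of how this headless model can be built through the name registered in `configs/glip/metafile.yml`, and what its multi-scale outputs look like; the shapes in the comment are expectations for a 384x384 input.

```python
# Sketch: build the headless GLIP Swin-L model and inspect its backbone features.
import torch
from mmpretrain import get_model

model = get_model('swin-l_glip-pre_3rdparty_384px', pretrained=False)
feats = model.extract_feat(torch.rand(1, 3, 384, 384))
# With out_indices=(1, 2, 3), this should give three feature maps, roughly
# (1, 384, 48, 48), (1, 768, 24, 24) and (1, 1536, 12, 12) for Swin-L at 384px.
for feat in feats:
    print(feat.shape)
```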
configs/glip/glip-t_headless.py (new file)
@@ -0,0 +1,18 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='tiny',
        img_size=224,
        # the original weights were trained for detection, so the backbone
        # outputs multi-scale feature maps
        out_indices=(1, 2, 3),
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # normalization parameters, in BGR channel order
    mean=[103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395],
    # do not convert images from BGR to RGB
    to_rgb=False,
)
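
As a side note, the `data_preprocessor` above keeps images in BGR channel order (`to_rgb=False`) and normalizes them channel-wise with the listed mean and std. The snippet below is only an illustrative sketch of that arithmetic, not part of the config.

```python
# Sketch of what the data_preprocessor amounts to: (x - mean) / std per channel,
# applied to images kept in BGR order.
import torch

mean = torch.tensor([103.53, 116.28, 123.675]).view(1, 3, 1, 1)
std = torch.tensor([57.375, 57.12, 58.395]).view(1, 3, 1, 1)

bgr_batch = torch.randint(0, 256, (1, 3, 224, 224)).float()  # stand-in for loaded BGR images
normalized = (bgr_batch - mean) / std
print(normalized.shape, float(normalized.mean()), float(normalized.std()))
```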
configs/glip/metafile.yml (new file)
@@ -0,0 +1,49 @@
Collections:
  - Name: GLIP
    Metadata:
      Training Techniques:
        - AdamW
        - Weight Decay
      Architecture:
        - Shift Window Multihead Self Attention
    Paper:
      URL: https://arxiv.org/abs/2112.03857
      Title: "Grounded Language-Image Pre-training"
    README: configs/glip/README.md
    Code:
      URL: https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/backbones/swin_transformer.py
      Version: v1.0.0rc8

Models:
  - Name: swin-t_glip-pre_3rdparty
    In Collection: GLIP
    Metadata:
      FLOPs: 4508464128
      Parameters: 29056354
      Training Data:
        - O365
        - GoldG
        - CC3M
        - SBU
    Results: null
    Weights: https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth
    Converted From:
      Weights: https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_tiny_model_o365_goldg_cc_sbu.pth
      Code: https://github.com/microsoft/GLIP
    Config: configs/glip/glip-t_headless.py
  - Name: swin-l_glip-pre_3rdparty_384px
    In Collection: GLIP
    Metadata:
      FLOPs: 104080343040
      Parameters: 196735516
      Training Data:
        - FourODs
        - GoldG
        - CC3M+12M
        - SBU
    Results: null
    Weights: https://download.openmmlab.com/mmclassification/v1/glip/swin-l_glip-pre_3rdparty_384px_20230413-04b198e8.pth
    Converted From:
      Weights: https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_large_model.pth
      Code: https://github.com/microsoft/GLIP
    Config: configs/glip/glip-l_headless.py
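
Once this metafile is imported through `model-index.yml` (see the next hunk), the two entries become discoverable from the Python API. Below is a small sketch, assuming mmpretrain's `list_models`/`get_model` helpers behave as documented:

```python
# Sketch: the metafile entries above should show up in the model index.
from mmpretrain import get_model, list_models

print(list_models('*glip*'))  # expected to include the two names defined above
model = get_model('swin-t_glip-pre_3rdparty', pretrained=False)
print(type(model).__name__)   # 'ImageClassifier', per configs/glip/glip-t_headless.py
```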
model-index.yml
@@ -68,3 +68,4 @@ Import:
  - configs/milan/metafile.yml
  - configs/riformer/metafile.yml
  - configs/sam/metafile.yml
  - configs/glip/metafile.yml
@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_glip(ckpt):
    # The official Swin implementation unfolds the four neighbouring patches in
    # a different order than mmpretrain's PatchMerging, so the downsample
    # reduction/norm weights need their input-channel groups permuted.

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        # Keep only the visual backbone; skip the language backbone and FPN.
        if 'language_backbone' in k or 'backbone' not in k or 'fpn' in k:
            continue
        new_v = v
        new_k = k.replace('body.', '')
        new_k = new_k.replace('module.', '')
        if new_k.startswith('backbone.layers'):
            new_k = new_k.replace('backbone.layers', 'backbone.stages')
        if 'mlp' in new_k:
            new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
            new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
        elif 'attn' in new_k:
            new_k = new_k.replace('attn', 'attn.w_msa')
        elif 'patch_embed' in k:
            new_k = new_k.replace('proj', 'projection')
        elif 'downsample' in new_k:
            if 'reduction.' in k:
                new_v = correct_unfold_reduction_order(new_v)
            elif 'norm.' in k:
                new_v = correct_unfold_norm_order(new_v)

        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained GLIP models to mmpretrain style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_glip(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
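
A quick way to check a conversion is to load the resulting file and inspect the key names; the snippet below is an editorial sketch with a hypothetical output path, not part of the script.

```python
# Sketch: sanity-check a converted checkpoint (the path here is hypothetical).
import torch

state_dict = torch.load('swin-t_glip-pre_converted.pth', map_location='cpu')
# Keys should now follow mmpretrain's Swin naming as produced by convert_glip,
# e.g. 'backbone.stages.0.blocks.0.attn.w_msa.qkv.weight' instead of the
# original 'module.backbone.body.layers.0.blocks.0.attn.qkv.weight'.
print(list(state_dict)[:5])
```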