[Feature] Support GLIP (#1308)

* rebase

* add glip

* update glip

* add links

* rename

* fix doc

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
takuoko 2023-04-17 20:19:23 +09:00 committed by GitHub
parent 2c913020b9
commit fec3da781f
8 changed files with 221 additions and 0 deletions

@@ -200,6 +200,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/xcit">XCiT</a></li>
<li><a href="configs/levit">LeViT</a></li>
<li><a href="configs/riformer">RIFormer</a></li>
<li><a href="configs/glip">GLIP</a></li>
</ul>
</td>
<td>

@@ -196,6 +196,7 @@ mim install -e .
<li><a href="configs/xcit">XCiT</a></li>
<li><a href="configs/levit">LeViT</a></li>
<li><a href="configs/riformer">RIFormer</a></li>
<li><a href="configs/glip">GLIP</a></li>
</ul>
</td>
<td>

@@ -0,0 +1,57 @@
# GLIP
> [Grounded Language-Image Pre-training](https://arxiv.org/abs/2112.03857)
<!-- [ALGORITHM] -->
## Abstract
This paper presents a grounded language-image pre-training (GLIP) model for learning object-level, language-aware, and semantic-rich visual representations. GLIP unifies object detection and phrase grounding for pre-training. The unification brings two benefits: 1) it allows GLIP to learn from both detection and grounding data to improve both tasks and bootstrap a good grounding model; 2) GLIP can leverage massive image-text pairs by generating grounding boxes in a self-training fashion, making the learned representation semantic-rich. In our experiments, we pre-train GLIP on 27M grounding data, including 3M human-annotated and 24M web-crawled image-text pairs. The learned representations demonstrate strong zero-shot and few-shot transferability to various object-level recognition tasks. 1) When directly evaluated on COCO and LVIS (without seeing any images in COCO during pre-training), GLIP achieves 49.8 AP and 26.9 AP, respectively, surpassing many supervised baselines. 2) After fine-tuned on COCO, GLIP achieves 60.8 AP on val and 61.5 AP on test-dev, surpassing prior SoTA. 3) When transferred to 13 downstream object detection tasks, a 1-shot GLIP rivals with a fully-supervised Dynamic Head.
<div align="center">
<img src="https://github.com/microsoft/GLIP/blob/main/docs/lead.png" width="70%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
**Use the model**
```python
import torch
from mmpretrain import get_model
model = get_model('swin-t_glip-pre_3rdparty', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
<!-- [TABS-END] -->
## Results and models

### Pre-trained models

The pre-trained models are only used for fine-tuning on downstream tasks, so no evaluation results are reported for them (see the sketch after the table for one way to attach a classification head).

| Model                                       |          Pretrain          | Resolution |                                                    Download                                                     |
| :------------------------------------------ | :------------------------: | :--------: | :-------------------------------------------------------------------------------------------------------------: |
| GLIP-T (`swin-t_glip-pre_3rdparty`)\*       |    O365,GoldG,CC3M,SBU     |  224x224   |  [model](https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth)  |
| GLIP-L (`swin-l_glip-pre_3rdparty_384px`)\* | FourODs,GoldG,CC3M+12M,SBU |  384x384   | [model](https://download.openmmlab.com/mmclassification/v1/glip/swin-l_glip-pre_3rdparty_384px_20230413-04b198e8.pth) |

*Models with \* are converted from the [official repo](https://github.com/microsoft/GLIP).*
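
Since both checkpoints ship without a neck or head, fine-tuning requires attaching them yourself. Below is a minimal, hedged sketch, assuming `get_model` accepts model-config overrides as keyword arguments; the `GlobalAveragePooling` neck, `LinearClsHead` and `num_classes=1000` are illustrative choices, not part of the released configs.

```python
from mmpretrain import get_model

# Attach a pooling neck and a linear classification head on top of the
# GLIP-pretrained Swin-T backbone. The neck/head settings are examples only.
model = get_model(
    'swin-t_glip-pre_3rdparty',
    pretrained=True,
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,  # channels of the last Swin-T stage
        loss=dict(type='CrossEntropyLoss')))
print(model.head)
```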
## Citation
```bibtex
@inproceedings{li2021grounded,
      title={Grounded Language-Image Pre-training},
      author={Liunian Harold Li* and Pengchuan Zhang* and Haotian Zhang* and Jianwei Yang and Chunyuan Li and Yiwu Zhong and Lijuan Wang and Lu Yuan and Lei Zhang and Jenq-Neng Hwang and Kai-Wei Chang and Jianfeng Gao},
      year={2022},
      booktitle={CVPR},
}
```

@@ -0,0 +1,18 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='large',
        img_size=384,
        out_indices=(1, 2, 3),  # original weight is for detection
        stage_cfgs=dict(block_cfgs=dict(window_size=12))),
    neck=None,
    head=None)

data_preprocessor = dict(
    # BGR format normalization parameters (ImageNet statistics in BGR order)
    mean=[103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395],
    # keep the loaded BGR images as-is; do not convert to RGB
    to_rgb=False,
)
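
The `data_preprocessor` above uses the ImageNet statistics in BGR channel order and sets `to_rgb=False`, so images keep the BGR order they are loaded in. A small NumPy sketch of the equivalent arithmetic (the random image here is purely illustrative):

```python
import numpy as np

# Per-channel statistics copied from the data_preprocessor above (BGR order).
mean = np.array([103.53, 116.28, 123.675], dtype=np.float32)
std = np.array([57.375, 57.12, 58.395], dtype=np.float32)

# Illustrative BGR image in [0, 255]; real inputs come from the data pipeline.
bgr = np.random.randint(0, 256, size=(384, 384, 3)).astype(np.float32)

# to_rgb=False: no channel swap, just per-channel standardization.
normalized = (bgr - mean) / std
print(normalized.shape, normalized.mean(axis=(0, 1)))
```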

@@ -0,0 +1,18 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='tiny',
        img_size=224,
        out_indices=(1, 2, 3),  # original weight is for detection
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # BGR format normalization parameters (ImageNet statistics in BGR order)
    mean=[103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395],
    # keep the loaded BGR images as-is; do not convert to RGB
    to_rgb=False,
)
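
A minimal sketch (weights randomly initialized, for shape checking only) of building just the Swin-T backbone described above through the `MODELS` registry and inspecting the three feature maps selected by `out_indices=(1, 2, 3)`:

```python
import torch
from mmpretrain.registry import MODELS

# Build only the backbone from the dict above; weights are random here.
backbone = MODELS.build(
    dict(type='SwinTransformer', arch='tiny', img_size=224, out_indices=(1, 2, 3)))
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))

for feat in feats:
    # Expect 192, 384 and 768 channels at strides 8, 16 and 32 for Swin-T.
    print(feat.shape)
```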

@@ -0,0 +1,49 @@
Collections:
  - Name: GLIP
    Metadata:
      Training Techniques:
        - AdamW
        - Weight Decay
      Architecture:
        - Shift Window Multihead Self Attention
    Paper:
      URL: https://arxiv.org/abs/2112.03857
      Title: "Grounded Language-Image Pre-training"
    README: configs/glip/README.md
    Code:
      URL: https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/backbones/swin_transformer.py
      Version: v1.0.0rc8

Models:
  - Name: swin-t_glip-pre_3rdparty
    In Collection: GLIP
    Metadata:
      FLOPs: 4508464128
      Parameters: 29056354
      Training Data:
        - O365
        - GoldG
        - CC3M
        - SBU
    Results: null
    Weights: https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth
    Converted From:
      Weights: https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_tiny_model_o365_goldg_cc_sbu.pth
      Code: https://github.com/microsoft/GLIP
    Config: configs/glip/glip-t_headless.py
  - Name: swin-l_glip-pre_3rdparty_384px
    In Collection: GLIP
    Metadata:
      FLOPs: 104080343040
      Parameters: 196735516
      Training Data:
        - FourODs
        - GoldG
        - CC3M+12M
        - SBU
    Results: null
    Weights: https://download.openmmlab.com/mmclassification/v1/glip/swin-l_glip-pre_3rdparty_384px_20230413-04b198e8.pth
    Converted From:
      Weights: https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_large_model.pth
      Code: https://github.com/microsoft/GLIP
    Config: configs/glip/glip-l_headless.py
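
Once this metafile is imported into the model index (the diff below adds that entry), the two model names registered above should become discoverable through mmpretrain's model-zoo API. A quick, hedged check:

```python
from mmpretrain import get_model, list_models

# The two names registered in this metafile should show up in the list.
print(list_models('*glip*'))

# Build GLIP-L without downloading weights; pass pretrained=True to load them.
model = get_model('swin-l_glip-pre_3rdparty_384px', pretrained=False)
print(type(model))
```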

@@ -68,3 +68,4 @@ Import:
- configs/milan/metafile.yml
- configs/riformer/metafile.yml
- configs/sam/metafile.yml
- configs/glip/metafile.yml

@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_glip(ckpt):

    def correct_unfold_reduction_order(x):
        # Reorder the PatchMerging reduction weight from the original Swin
        # slicing layout to the nn.Unfold layout used by mmpretrain.
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1,
                                            2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        # Apply the same reordering to the PatchMerging norm parameters.
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        # Keep only the image backbone; skip the language tower, FPN and heads.
        if 'language_backbone' in k or 'backbone' not in k or 'fpn' in k:
            continue
        new_v = v
        new_k = k.replace('body.', '')
        new_k = new_k.replace('module.', '')
        if new_k.startswith('backbone.layers'):
            new_k = new_k.replace('backbone.layers', 'backbone.stages')
        if 'mlp' in new_k:
            new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
            new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
        elif 'attn' in new_k:
            new_k = new_k.replace('attn', 'attn.w_msa')
        elif 'patch_embed' in k:
            new_k = new_k.replace('proj', 'projection')
        elif 'downsample' in new_k:
            if 'reduction.' in k:
                new_v = correct_unfold_reduction_order(new_v)
            elif 'norm.' in k:
                new_v = correct_unfold_norm_order(new_v)
        new_ckpt[new_k] = new_v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained glip models to mmcls style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_glip(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
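
Because the converter keeps only `backbone.*` keys renamed to mmpretrain's Swin layout, its output can be loaded straight into the headless config. A hedged sketch, assuming the script was run beforehand; the checkpoint filename `swin-t_glip-pre.pth` is illustrative:

```python
import torch
from mmengine.config import Config
from mmpretrain.registry import MODELS

# Build the headless GLIP-T model and load a checkpoint written by the
# converter above ('swin-t_glip-pre.pth' is an assumed output path).
cfg = Config.fromfile('configs/glip/glip-t_headless.py')
model = MODELS.build(cfg.model)

state_dict = torch.load('swin-t_glip-pre.pth', map_location='cpu')
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
```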