[Feature] Add BEiT backbone (#1404)

* [Feature] Add BEiT backbone

* fix

* add readme

* fix

* add link

* fix memory

* fix

* fix test_beit.py

* fix
FangjianLin 2022-03-30 15:25:10 +08:00 committed by GitHub
parent 30864ea23d
commit 24f1563571
20 changed files with 1345 additions and 2 deletions


@ -85,6 +85,7 @@ Supported backbones:
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [BEiT (ICLR'2022)](configs/beit)
Supported methods:


@ -84,6 +84,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [BEiT (ICLR'2022)](configs/beit)
Supported methods:


@ -0,0 +1,50 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='BEiT',
img_size=(640, 640),
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=(3, 5, 7, 11),
qv_bias=True,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
norm_eval=False,
init_values=0.1),
neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
decode_head=dict(
type='UPerHead',
in_channels=[768, 768, 768, 768],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=768,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=768,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
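# Note: with rescales=[4, 2, 1, 0.5], the Feature2Pyramid neck turns the four
# 1/16-resolution BEiT feature maps into the 1/4, 1/8, 1/16 and 1/32 pyramid
# expected by UPerHead.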

configs/beit/README.md Normal file

@ -0,0 +1,84 @@
# BEiT
[BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
## Introduction
<!-- [BACKBONE] -->
<a href="https://github.com/microsoft/unilm/tree/master/beit">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/backbones/beit.py#1404">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e., image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and feed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%). The code and pretrained models are available at [this https URL](https://github.com/microsoft/unilm/tree/master/beit).
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/93248678/160155758-781c9a45-b1d7-4530-9015-88eca6645006.png" width="70%"/>
</div>
## Citation
```bibtex
@inproceedings{beit,
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=p-BhZSz59o4}
}
```
## Usage
To use other repositories' pre-trained models, it is necessary to convert keys.
We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/microsoft/unilm/tree/master/beit/semantic_segmentation) to MMSegmentation style.
```shell
python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```
E.g.
```shell
python tools/model_converters/beit2mmseg.py https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth pretrain/beit_base_patch16_224_pt22k_ft22k.pth
```
This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
In our default setting, the pretrained models and their corresponding original models are listed below (a config sketch that references a converted checkpoint follows the table):
| pretrained models | original models |
| ----------------- | --------------- |
| BEiT_base.pth | [BEiT_base](https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth) |
| BEiT_large.pth | [BEiT_large](https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth) |
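A converted checkpoint can then be referenced from a config through the `pretrained` field. A minimal sketch, assuming the converted file is kept under `pretrain/` (this mirrors `upernet_beit-base_8x2_640x640_160k_ade20k.py`):
```python
_base_ = './upernet_beit-base_8x2_640x640_160k_ade20k.py'
# Point `pretrained` at the converted checkpoint (the path is an assumption).
model = dict(pretrained='pretrain/beit_base_patch16_224_pt22k_ft22k.pth')
```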
Verify the single-scale results of the model:
```shell
sh tools/dist_test.sh \
configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py \
upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
```
Since the relative position embedding requires the input height and width to be equal, sliding-window inference is adopted for multi-scale testing. We set `min_size=640` so that the shortest edge is 640, and multi-scale inference therefore uses a separate config instead of `--aug-test`. For multi-scale inference:
```shell
sh tools/dist_test.sh \
configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py \
upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
```
## Results and models
### ADE20K
| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| UperNet | BEiT-B | 640x640 | ImageNet-22K | 224x224 | 16 | 160000 | 15.88 | 2.00 | 53.08 | 53.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k.log.json) |
| UperNet | BEiT-L | 640x640 | ImageNet-22K | 224x224 | 8 | 320000 | 22.64 | 0.96 | 56.33 | 56.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.log.json) |

configs/beit/beit.yml Normal file

@ -0,0 +1,45 @@
Models:
- Name: upernet_beit-base_8x2_640x640_160k_ade20k
In Collection: UperNet
Metadata:
backbone: BEiT-B
crop size: (640,640)
lr schd: 160000
inference time (ms/im):
- value: 500.0
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (640,640)
Training Memory (GB): 15.88
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 53.08
mIoU(ms+flip): 53.84
Config: configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth
- Name: upernet_beit-large_fp16_8x1_640x640_160k_ade20k
In Collection: UperNet
Metadata:
backbone: BEiT-L
crop size: (640,640)
lr schd: 320000
inference time (ms/im):
- value: 1041.67
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (640,640)
Training Memory (GB): 22.64
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 56.33
mIoU(ms+flip): 56.84
Config: configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth


@ -0,0 +1,24 @@
_base_ = './upernet_beit-base_8x2_640x640_160k_ade20k.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2560, 640),
img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True, min_size=640),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline),
samples_per_gpu=2)


@ -0,0 +1,30 @@
_base_ = [
'../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained='pretrain/beit_base_patch16_224_pt22k_ft22k.pth',
test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))
optimizer = dict(
_delete_=True,
type='AdamW',
lr=3e-5,
betas=(0.9, 0.999),
weight_decay=0.05,
constructor='LayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)


@ -0,0 +1,22 @@
_base_ = './upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2560, 640),
img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True, min_size=640),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline))


@ -0,0 +1,47 @@
_base_ = [
'../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_320k.py'
]
model = dict(
pretrained='pretrain/beit_large_patch16_224_pt22k_ft22k.pth',
backbone=dict(
type='BEiT',
embed_dims=1024,
num_layers=24,
num_heads=16,
mlp_ratio=4,
qv_bias=True,
init_values=1e-6,
drop_path_rate=0.2,
out_indices=[7, 11, 15, 23]),
neck=dict(embed_dim=1024, rescales=[4, 2, 1, 0.5]),
decode_head=dict(
in_channels=[1024, 1024, 1024, 1024], num_classes=150, channels=1024),
auxiliary_head=dict(in_channels=1024, num_classes=150),
test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))
optimizer = dict(
_delete_=True,
type='AdamW',
lr=2e-5,
betas=(0.9, 0.999),
weight_decay=0.05,
constructor='LayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=3000,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
data = dict(samples_per_gpu=1)
optimizer_config = dict(
type='GradientCumulativeFp16OptimizerHook', cumulative_iters=2)
fp16 = dict()


@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import * # noqa: F401, F403
from .layer_decay_optimizer_constructor import \
LayerDecayOptimizerConstructor # noqa: F401
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403


@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
get_dist_info)
from mmseg.utils import get_root_logger
def get_num_layer_for_vit(var_name, num_max_layer):
"""Get the layer id to set the different learning rates.
Args:
var_name (str): The key of the model.
num_max_layer (int): Maximum number of backbone layers.
Returns:
layer id (int): Returns the layer id of the key.
"""
if var_name in ('backbone.cls_token', 'backbone.mask_token',
'backbone.pos_embed'):
return 0
elif var_name.startswith('backbone.patch_embed'):
return 0
elif var_name.startswith('backbone.layers'):
layer_id = int(var_name.split('.')[2])
return layer_id + 1
else:
return num_max_layer - 1
@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
"""Different learning rates are set for different layers of backbone."""
def add_params(self, params, module):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
"""
parameter_groups = {}
logger = get_root_logger()
logger.info(self.paramwise_cfg)
num_layers = self.paramwise_cfg.get('num_layers') + 2
layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
logger.info(f'Build LayerDecayOptimizerConstructor '
f'{layer_decay_rate} - {num_layers}')
weight_decay = self.base_wd
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith('.bias') or name in (
'pos_embed', 'cls_token'):
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = weight_decay
layer_id = get_num_layer_for_vit(name, num_layers)
group_name = f'layer_{layer_id}_{group_name}'
if group_name not in parameter_groups:
scale = layer_decay_rate**(num_layers - layer_id - 1)
parameter_groups[group_name] = {
'weight_decay': this_weight_decay,
'params': [],
'param_names': [],
'lr_scale': scale,
'group_name': group_name,
'lr': scale * self.base_lr
}
parameter_groups[group_name]['params'].append(param)
parameter_groups[group_name]['param_names'].append(name)
rank, _ = get_dist_info()
if rank == 0:
to_display = {}
for key in parameter_groups:
to_display[key] = {
'param_names': parameter_groups[key]['param_names'],
'lr_scale': parameter_groups[key]['lr_scale'],
'lr': parameter_groups[key]['lr'],
'weight_decay': parameter_groups[key]['weight_decay']
}
logger.info(f'Param groups ={to_display}')
params.extend(parameter_groups.values())
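# Usage sketch: select this constructor through the optimizer config, as in
# configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py:
#   optimizer = dict(
#       type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05,
#       constructor='LayerDecayOptimizerConstructor',
#       paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))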


@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beit import BEiT
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
@ -24,5 +25,5 @@ __all__ = [
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet'
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT'
]


@ -0,0 +1,532 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed
try:
from scipy import interpolate
except ImportError:
interpolate = None
class BEiTAttention(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
qv_bias (bool): If True, add a learnable bias to q, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float): Dropout ratio of output. Default: 0.
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
qv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
if qv_bias:
self.q_bias = nn.Parameter(torch.zeros(embed_dims))
self.v_bias = nn.Parameter(torch.zeros(embed_dims))
else:
self.q_bias = None
self.v_bias = None
self.window_size = window_size
# cls to token, token to cls, and cls to cls
self.num_relative_distance = (2 * window_size[0] -
1) * (2 * window_size[1] - 1) + 3
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads))
# get pair-wise relative position index for
# each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
# coords shape is (2, Wh, Ww)
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
# coords_flatten shape is (2, Wh*Ww)
coords_flatten = torch.flatten(coords, 1)
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :])
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# shift to start from 0
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1, ) * 2,
dtype=relative_coords.dtype)
# relative_position_index shape is (Wh*Ww, Wh*Ww)
relative_position_index[1:, 1:] = relative_coords.sum(-1)
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer('relative_position_index',
relative_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C).
"""
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
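# BEiT learns biases only for q and v; the bias for k is fixed to zero.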
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
Wh = self.window_size[0]
Ww = self.window_size[1]
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
Wh * Ww + 1, Wh * Ww + 1, -1)
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qv_bias (bool): Enable bias for qv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
window_size (tuple[int], optional): The height and width of the window.
Default: None.
init_values (float, optional): Initialize the values of BEiTAttention
and FFN with learnable scaling. Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
window_size=None,
init_values=None):
super(TransformerEncoderLayer, self).__init__()
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = BEiTAttention(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=window_size,
qv_bias=qv_bias,
qk_scale=None,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=0.,
init_cfg=None)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=0.,
dropout_layer=None,
act_cfg=act_cfg,
add_identity=False)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
# NOTE: drop path for stochastic depth, we shall see if
# this is better than dropout here
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
self.drop_path = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
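# Layer scale: learnable per-channel weights (initialized to init_values)
# that scale the attention and FFN branches before the residual addition.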
self.gamma_1 = nn.Parameter(
init_values * torch.ones((embed_dims)), requires_grad=True)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((embed_dims)), requires_grad=True)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
return x
@BACKBONES.register_module()
class BEiT(BaseModule):
"""BERT Pre-Training of Image Transformers.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): Embedding dimension. Default: 768.
num_layers (int): Depth of transformer. Default: 12.
num_heads (int): Number of attention heads. Default: 12.
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qv_bias (bool): Enable bias for qv if True. Default: True.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Default: False.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
pretrained (str, optional): Model pretrained path. Default: None.
init_values (float): Initialize the values of BEiTAttention and FFN
with learnable scaling.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=-1,
qv_bias=True,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
num_fcs=2,
norm_eval=False,
pretrained=None,
init_values=0.1,
init_cfg=None):
super(BEiT, self).__init__(init_cfg=init_cfg)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.img_size = img_size
self.patch_size = patch_size
self.norm_eval = norm_eval
self.pretrained = pretrained
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding=0,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
self.patch_shape = window_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
else:
raise TypeError('out_indices must be type of int, list or tuple')
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
TransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qv_bias=qv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
window_size=window_size,
init_values=init_values))
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
num):
"""Get new sequence via geometric sequence interpolation.
Args:
src_size (int): Pos_embedding size in pre-trained model.
dst_size (int): Pos_embedding size in the current model.
sequence (tensor): The relative position bias of the pretrain
model after removing the extra tokens.
num (int): Number of attention heads.
Returns:
new_sequence (tensor): The pre-trained relative position bias
interpolated, via the geometric sequence, to the size of
the current model.
"""
def geometric_progression(a, r, n):
return a * (1.0 - r**n) / (1.0 - r)
# Binary search for the common ratio q of the geometric progression.
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# The position of each interpolated point is determined
# by the ratio obtained from the binary search.
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q**(i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
# Build and evaluate a cubic interpolation function for each attention head.
new_sequence = []
for i in range(num):
z = sequence[:, i].view(src_size, src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
new_sequence.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
new_sequence = torch.cat(new_sequence, dim=-1)
return new_sequence
def resize_rel_pos_embed(self, checkpoint):
"""Resize relative pos_embed weights.
This function is modified from
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
Copyright (c) Microsoft Corporation
Licensed under the MIT License
Args:
checkpoint (dict): Key and value of the pretrain model.
Returns:
state_dict (dict): Interpolate the relative pos_embed weights
in the pre-train model to the current model size.
"""
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
all_keys = list(state_dict.keys())
for key in all_keys:
if 'relative_position_index' in key:
state_dict.pop(key)
# Geometric sequence interpolation is adopted so that the center of
# the position bias stays as consistent as possible after
# interpolation, at the cost of coarser sampling towards the edges.
if 'relative_position_bias_table' in key:
rel_pos_bias = state_dict[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = self.state_dict()[key].size()
dst_patch_shape = self.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
# Count the number of extra tokens.
num_extra_tokens = dst_num_pos - (
dst_patch_shape[0] * 2 - 1) * (
dst_patch_shape[1] * 2 - 1)
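# For BEiT these are the 3 cls-related entries of the bias table
# (cls-to-token, token-to-cls and cls-to-cls).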
src_size = int((src_num_pos - num_extra_tokens)**0.5)
dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
if src_size != dst_size:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
new_rel_pos_bias = self._geometric_sequence_interpolation(
src_size, dst_size, rel_pos_bias, num_attn_heads)
new_rel_pos_bias = torch.cat(
(new_rel_pos_bias, extra_tokens), dim=0)
state_dict[key] = new_rel_pos_bias
return state_dict
def init_weights(self):
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
self.apply(_init_weights)
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
state_dict = self.resize_rel_pos_embed(checkpoint)
self.load_state_dict(state_dict, False)
elif self.init_cfg is not None:
super(BEiT, self).init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
# Copyright 2019 Ross Wightman
# Licensed under the Apache License, Version 2.0 (the "License")
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def forward(self, inputs):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
# Remove class token and reshape token for decoder head
out = x[:, 1:]
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
def train(self, mode=True):
super(BEiT, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.LayerNorm):
m.eval()
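# Usage sketch (shapes match the unit test test_beit.py added in this PR):
#   model = BEiT()
#   feats = model(torch.randn(1, 3, 224, 224))
#   assert feats[-1].shape == (1, 768, 14, 14)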


@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .featurepyramid import Feature2Pyramid
from .fpn import FPN
from .ic_neck import ICNeck
from .jpu import JPU
from .mla_neck import MLANeck
from .multilevel_neck import MultiLevelNeck
__all__ = ['FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU']
__all__ = [
'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid'
]


@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from ..builder import NECKS
@NECKS.register_module()
class Feature2Pyramid(nn.Module):
"""Feature2Pyramid.
A neck structure that connects a ViT backbone to decoder heads.
Args:
embed_dim (int): Embedding dimension.
rescales (list[float]): Different sampling multiples used to
obtain pyramid features. Default: [4, 2, 1, 0.5].
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='SyncBN', requires_grad=True).
"""
def __init__(self,
embed_dim,
rescales=[4, 2, 1, 0.5],
norm_cfg=dict(type='SyncBN', requires_grad=True)):
super(Feature2Pyramid, self).__init__()
self.rescales = rescales
self.upsample_4x = None
for k in self.rescales:
if k == 4:
self.upsample_4x = nn.Sequential(
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2),
build_norm_layer(norm_cfg, embed_dim)[1],
nn.GELU(),
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2),
)
elif k == 2:
self.upsample_2x = nn.Sequential(
nn.ConvTranspose2d(
embed_dim, embed_dim, kernel_size=2, stride=2))
elif k == 1:
self.identity = nn.Identity()
elif k == 0.5:
self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
elif k == 0.25:
self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
else:
raise KeyError(f'invalid {k} for feature2pyramid')
def forward(self, inputs):
assert len(inputs) == len(self.rescales)
outputs = []
if self.upsample_4x is not None:
ops = [
self.upsample_4x, self.upsample_2x, self.identity,
self.downsample_2x
]
else:
ops = [
self.upsample_2x, self.identity, self.downsample_2x,
self.downsample_4x
]
for i in range(len(inputs)):
outputs.append(ops[i](inputs[i]))
return tuple(outputs)
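# Usage sketch (mirrors the Feature2Pyramid unit test in this PR): with
# rescales=[4, 2, 1, 0.5] and four (1, 64, 32, 32) inputs, the outputs have
# spatial sizes 128x128, 64x64, 32x32 and 16x16 respectively.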


@ -1,6 +1,7 @@
Import:
- configs/ann/ann.yml
- configs/apcnet/apcnet.yml
- configs/beit/beit.yml
- configs/bisenetv1/bisenetv1.yml
- configs/bisenetv2/bisenetv2.yml
- configs/ccnet/ccnet.yml


@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmseg.core.layer_decay_optimizer_constructor import \
LayerDecayOptimizerConstructor
layer_wise_gt_lst = [{
'weight_decay': 0.0,
'lr_scale': 16
}, {
'weight_decay': 0.05,
'lr_scale': 8
}, {
'weight_decay': 0.0,
'lr_scale': 8
}, {
'weight_decay': 0.05,
'lr_scale': 4
}, {
'weight_decay': 0.0,
'lr_scale': 4
}, {
'weight_decay': 0.05,
'lr_scale': 2
}, {
'weight_decay': 0.0,
'lr_scale': 2
}]
class BEiTExampleModel(nn.Module):
def __init__(self, depth):
super().__init__()
self.backbone = nn.ModuleList()
# add some variables to meet unit test coverage rate
self.backbone.cls_token = nn.Parameter(torch.ones(1))
self.backbone.patch_embed = nn.Parameter(torch.ones(1))
self.backbone.layers = nn.ModuleList()
for _ in range(depth):
layer = nn.Conv2d(3, 3, 1)
self.backbone.layers.append(layer)
def check_beit_adamw_optimizer(optimizer, gt_lst):
assert isinstance(optimizer, torch.optim.AdamW)
assert optimizer.defaults['lr'] == 1
assert optimizer.defaults['weight_decay'] == 0.05
param_groups = optimizer.param_groups
# 1 group (cls_token and patch_embed) + 3 layers * 2 (weight, bias) = 7 groups
assert len(param_groups) == 7
for i, param_dict in enumerate(param_groups):
assert param_dict['weight_decay'] == gt_lst[i]['weight_decay']
assert param_dict['lr_scale'] == gt_lst[i]['lr_scale']
assert param_dict['lr_scale'] == param_dict['lr']
def test_beit_layer_decay_optimizer_constructor():
# paramwise_cfg with BEiTExampleModel
model = BEiTExampleModel(depth=3)
optimizer_cfg = dict(
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05)
paramwise_cfg = dict(num_layers=3, layer_decay_rate=2)
optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg,
paramwise_cfg)
optimizer = optim_constructor(model)
check_beit_adamw_optimizer(optimizer, layer_wise_gt_lst)


@ -0,0 +1,182 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones.beit import BEiT
from .utils import check_norm_state
def test_beit_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = BEiT()
model.init_weights(pretrained=0)
with pytest.raises(TypeError):
# img_size must be int or tuple
model = BEiT(img_size=512.0)
with pytest.raises(TypeError):
# out_indices must be int, list or tuple
model = BEiT(out_indices=1.)
with pytest.raises(AssertionError):
# The length of the img_size tuple must be 1 or 2.
BEiT(img_size=(224, 224, 224))
with pytest.raises(TypeError):
# pretrained must be None or a str.
BEiT(pretrained=123)
# Test img_size as a length-1 tuple
imgs = torch.randn(1, 3, 224, 224)
model = BEiT(img_size=(224, ))
model.init_weights()
model(imgs)
# Test img_size as a length-2 tuple
imgs = torch.randn(1, 3, 224, 224)
model = BEiT(img_size=(224, 224))
model(imgs)
# Test norm_eval = True
model = BEiT(norm_eval=True)
model.train()
# Test BEiT backbone with input size of 224 and patch size of 16
model = BEiT()
model.init_weights()
model.train()
# Test qv_bias
model = BEiT(qv_bias=False)
model.train()
# Test out_indices = list
model = BEiT(out_indices=[2, 4, 8, 12])
model.train()
assert check_norm_state(model.modules(), True)
# Test image size = (224, 224)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test BEiT backbone with input size of 256 and patch size of 16
model = BEiT(img_size=(256, 256))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 16, 16)
# Test BEiT backbone with input size of 32 and patch size of 16
model = BEiT(img_size=(32, 32))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 2, 2)
# Test unbalanced size input image
model = BEiT(img_size=(112, 224))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 112, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 7, 14)
# Test irregular input image
model = BEiT(img_size=(234, 345))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 234, 345)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 21)
# Test init_values=0
model = BEiT(init_values=0)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test final norm
model = BEiT(final_norm=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test patch norm
model = BEiT(patch_norm=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
def test_beit_init():
path = 'PATH_THAT_DO_NOT_EXIST'
# Test all combinations of pretrained and init_cfg
# pretrained=None, init_cfg=None
model = BEiT(pretrained=None, init_cfg=None)
assert model.init_cfg is None
model.init_weights()
# pretrained=None
# init_cfg loads pretrain from a non-existent file
model = BEiT(
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
# Test loading a checkpoint from a non-existent file
with pytest.raises(OSError):
model.init_weights()
# test resize_rel_pos_embed
value = torch.randn(732, 16)
ckpt = {
'state_dict': {
'layers.0.attn.relative_position_index': 0,
'layers.0.attn.relative_position_bias_table': value
}
}
model = BEiT(img_size=(512, 512))
with pytest.raises(AttributeError):
model.resize_rel_pos_embed(ckpt)
# pretrained=None
# init_cfg=123, whose type is unsupported
model = BEiT(pretrained=None, init_cfg=123)
with pytest.raises(TypeError):
model.init_weights()
# pretrained loads pretrain from a non-existent file
# init_cfg=None
model = BEiT(pretrained=path, init_cfg=None)
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
# Test loading a checkpoint from a non-existent file
with pytest.raises(OSError):
model.init_weights()
# pretrained loads pretrain from a non-existent file
# init_cfg loads pretrain from a non-existent file
with pytest.raises(AssertionError):
model = BEiT(
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
with pytest.raises(AssertionError):
model = BEiT(pretrained=path, init_cfg=123)
# pretrain=123, whose type is unsupported
# init_cfg=None
with pytest.raises(TypeError):
model = BEiT(pretrained=123, init_cfg=None)
# pretrain=123, whose type is unsupported
# init_cfg loads pretrain from a non-existent file
with pytest.raises(AssertionError):
model = BEiT(
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
# pretrain=123, whose type is unsupported
# init_cfg=123, whose type is unsupported
with pytest.raises(AssertionError):
model = BEiT(pretrained=123, init_cfg=123)


@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models import Feature2Pyramid
def test_feature2pyramid():
# test default rescales = [4, 2, 1, 0.5]
rescales = [4, 2, 1, 0.5]
embed_dim = 64
inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))]
fpn = Feature2Pyramid(
embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
outputs = fpn(inputs)
assert outputs[0].shape == torch.Size([1, 64, 128, 128])
assert outputs[1].shape == torch.Size([1, 64, 64, 64])
assert outputs[2].shape == torch.Size([1, 64, 32, 32])
assert outputs[3].shape == torch.Size([1, 64, 16, 16])
# test rescales = [2, 1, 0.5, 0.25]
rescales = [2, 1, 0.5, 0.25]
inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))]
fpn = Feature2Pyramid(
embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
outputs = fpn(inputs)
assert outputs[0].shape == torch.Size([1, 64, 64, 64])
assert outputs[1].shape == torch.Size([1, 64, 32, 32])
assert outputs[2].shape == torch.Size([1, 64, 16, 16])
assert outputs[3].shape == torch.Size([1, 64, 8, 8])
# test rescales = [4, 2, 0.25, 0]
rescales = [4, 2, 0.25, 0]
with pytest.raises(KeyError):
fpn = Feature2Pyramid(
embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))


@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmcv
import torch
from mmcv.runner import CheckpointLoader
def convert_beit(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('patch_embed'):
new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
new_ckpt[new_key] = v
if k.startswith('blocks'):
new_key = k.replace('blocks', 'layers')
if 'norm' in new_key:
new_key = new_key.replace('norm', 'ln')
elif 'mlp.fc1' in new_key:
new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in new_key:
new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
new_ckpt[new_key] = v
else:
new_key = k
new_ckpt[new_key] = v
return new_ckpt
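# Example mappings produced by convert_beit (derived from the rules above):
#   'patch_embed.proj.weight' -> 'patch_embed.projection.weight'
#   'blocks.0.norm1.weight'   -> 'layers.0.ln1.weight'
#   'blocks.0.mlp.fc1.weight' -> 'layers.0.ffn.layers.0.0.weight'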
def main():
parser = argparse.ArgumentParser(
description='Convert keys in official pretrained beit models to '
'MMSegmentation style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
weight = convert_beit(state_dict)
mmcv.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
if __name__ == '__main__':
main()