[Refactor] Add build_pos_embed and build_layers for BEiT (#1517)

* [Refactor] Add build_pos_embed and build_layers for BEiT

* Update mmseg/models/backbones/beit.py
This commit is contained in:
Miao Zheng 2022-04-27 11:37:03 +08:00 committed by GitHub
parent 7c3bf22885
commit d90e700284

View File

@ -306,25 +306,31 @@ class BEiT(BaseModule):
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.in_channels = in_channels
self.img_size = img_size
self.patch_size = patch_size
self.norm_eval = norm_eval
self.pretrained = pretrained
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
padding=0,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
self.patch_shape = window_size
self.num_layers = num_layers
self.embed_dims = embed_dims
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.num_fcs = num_fcs
self.qv_bias = qv_bias
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
self.patch_norm = patch_norm
self.init_values = init_values
self.window_size = (img_size[0] // patch_size,
img_size[1] // patch_size)
self.patch_shape = self.window_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self._build_patch_embedding()
self._build_layers()
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = num_layers - 1
@ -334,29 +340,47 @@ class BEiT(BaseModule):
else:
raise TypeError('out_indices must be type of int, list or tuple')
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(
BEiTTransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
bias='qv_bias' if qv_bias else False,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
window_size=window_size,
init_values=init_values))
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
def _build_patch_embedding(self):
"""Build patch embedding layer."""
self.patch_embed = PatchEmbed(
in_channels=self.in_channels,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=self.patch_size,
stride=self.patch_size,
padding=0,
norm_cfg=self.norm_cfg if self.patch_norm else None,
init_cfg=None)
def _build_layers(self):
"""Build transformer encoding layers."""
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
]
self.layers = ModuleList()
for i in range(self.num_layers):
self.layers.append(
BEiTTransformerEncoderLayer(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=self.mlp_ratio * self.embed_dims,
attn_drop_rate=self.attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=self.num_fcs,
bias='qv_bias' if self.qv_bias else False,
act_cfg=self.act_cfg,
norm_cfg=self.norm_cfg,
window_size=self.window_size,
init_values=self.init_values))
@property
def norm1(self):
return getattr(self, self.norm1_name)
@ -419,7 +443,6 @@ class BEiT(BaseModule):
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
Copyright (c) Microsoft Corporation
Licensed under the MIT License
Args:
checkpoint (dict): Key and value of the pretrain model.
Returns: