mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactor] Add build_pos_embed and build_layers for BEiT (#1517)
* [Refactor] Add build_pos_embed and build_layers for BEiT
* Update mmseg/models/backbones/beit.py
parent 7c3bf22885
commit d90e700284
@@ -306,25 +306,31 @@ class BEiT(BaseModule):
         elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
+        self.in_channels = in_channels
         self.img_size = img_size
         self.patch_size = patch_size
         self.norm_eval = norm_eval
         self.pretrained = pretrained
-
-        self.patch_embed = PatchEmbed(
-            in_channels=in_channels,
-            embed_dims=embed_dims,
-            conv_type='Conv2d',
-            kernel_size=patch_size,
-            stride=patch_size,
-            padding=0,
-            norm_cfg=norm_cfg if patch_norm else None,
-            init_cfg=None)
-
-        window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
-        self.patch_shape = window_size
+        self.num_layers = num_layers
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.num_fcs = num_fcs
+        self.qv_bias = qv_bias
+        self.act_cfg = act_cfg
+        self.norm_cfg = norm_cfg
+        self.patch_norm = patch_norm
+        self.init_values = init_values
+        self.window_size = (img_size[0] // patch_size,
+                            img_size[1] // patch_size)
+        self.patch_shape = self.window_size
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
 
+        self._build_patch_embedding()
+        self._build_layers()
+
         if isinstance(out_indices, int):
             if out_indices == -1:
                 out_indices = num_layers - 1
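The hunk above caches every constructor argument on `self` so that the actual construction work can be deferred to the new `_build_patch_embedding()` and `_build_layers()` helpers. A minimal sketch of what this enables (not part of the patch; `FrozenBEiT` is a hypothetical name, and mmsegmentation is assumed to be installed):

from mmseg.models.backbones import BEiT


class FrozenBEiT(BEiT):
    """Hypothetical subclass: customize layer construction without
    copying the whole __init__, since all config now lives on self."""

    def _build_layers(self):
        # Build the stock transformer layers, then freeze them.
        super()._build_layers()
        for param in self.layers.parameters():
            param.requires_grad = False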
@@ -334,29 +340,47 @@ class BEiT(BaseModule):
         else:
             raise TypeError('out_indices must be type of int, list or tuple')
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
-        self.layers = ModuleList()
-        for i in range(num_layers):
-            self.layers.append(
-                BEiTTransformerEncoderLayer(
-                    embed_dims=embed_dims,
-                    num_heads=num_heads,
-                    feedforward_channels=mlp_ratio * embed_dims,
-                    attn_drop_rate=attn_drop_rate,
-                    drop_path_rate=dpr[i],
-                    num_fcs=num_fcs,
-                    bias='qv_bias' if qv_bias else False,
-                    act_cfg=act_cfg,
-                    norm_cfg=norm_cfg,
-                    window_size=window_size,
-                    init_values=init_values))
-
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
+    def _build_patch_embedding(self):
+        """Build patch embedding layer."""
+        self.patch_embed = PatchEmbed(
+            in_channels=self.in_channels,
+            embed_dims=self.embed_dims,
+            conv_type='Conv2d',
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding=0,
+            norm_cfg=self.norm_cfg if self.patch_norm else None,
+            init_cfg=None)
+
+    def _build_layers(self):
+        """Build transformer encoding layers."""
+
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
+        ]
+        self.layers = ModuleList()
+        for i in range(self.num_layers):
+            self.layers.append(
+                BEiTTransformerEncoderLayer(
+                    embed_dims=self.embed_dims,
+                    num_heads=self.num_heads,
+                    feedforward_channels=self.mlp_ratio * self.embed_dims,
+                    attn_drop_rate=self.attn_drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=self.num_fcs,
+                    bias='qv_bias' if self.qv_bias else False,
+                    act_cfg=self.act_cfg,
+                    norm_cfg=self.norm_cfg,
+                    window_size=self.window_size,
+                    init_values=self.init_values))
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
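The `dpr` list in `_build_layers()` above is a stochastic depth schedule: drop-path rates grow linearly from 0 at the first layer to `drop_path_rate` at the last. A stand-alone illustration (the values 0.1 and 12 are chosen for the example, not taken from the patch):

import torch

drop_path_rate, num_layers = 0.1, 12
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
# Step size is drop_path_rate / (num_layers - 1), roughly 0.0091 here,
# so deeper layers are dropped more often during training.
print([round(r, 4) for r in dpr[:3]], '...', round(dpr[-1], 4))
# [0.0, 0.0091, 0.0182] ... 0.1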
@@ -419,7 +443,6 @@ class BEiT(BaseModule):
         https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
         Copyright (c) Microsoft Corporation
         Licensed under the MIT License
 
         Args:
             checkpoint (dict): Key and value of the pretrain model.
         Returns:
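Since the refactor only moves construction into private helpers, the public interface is unchanged. A hedged usage sketch (assuming mmsegmentation and PyTorch are installed; shapes and arguments are illustrative, matching the library defaults):

import torch

from mmseg.models.backbones import BEiT

model = BEiT(img_size=(224, 224), patch_size=16, embed_dims=768,
             num_layers=12, num_heads=12, out_indices=(11, ))
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
# `feats` is a tuple with one feature map per index in `out_indices`.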