From d90e7002843af14ab77b79f40ef171f955972c4e Mon Sep 17 00:00:00 2001 From: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Date: Wed, 27 Apr 2022 11:37:03 +0800 Subject: [PATCH] [Refactor] Add build_pos_embed and build_layers for BEiT (#1517) * [Refactor] Add build_pos_embed and build_layers for BEiT * Update mmseg/models/backbones/beit.py --- mmseg/models/backbones/beit.py | 85 +++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 31 deletions(-) diff --git a/mmseg/models/backbones/beit.py b/mmseg/models/backbones/beit.py index 28e73d12c..fade60137 100644 --- a/mmseg/models/backbones/beit.py +++ b/mmseg/models/backbones/beit.py @@ -306,25 +306,31 @@ class BEiT(BaseModule): elif pretrained is not None: raise TypeError('pretrained must be a str or None') + self.in_channels = in_channels self.img_size = img_size self.patch_size = patch_size self.norm_eval = norm_eval self.pretrained = pretrained - - self.patch_embed = PatchEmbed( - in_channels=in_channels, - embed_dims=embed_dims, - conv_type='Conv2d', - kernel_size=patch_size, - stride=patch_size, - padding=0, - norm_cfg=norm_cfg if patch_norm else None, - init_cfg=None) - - window_size = (img_size[0] // patch_size, img_size[1] // patch_size) - self.patch_shape = window_size + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.num_fcs = num_fcs + self.qv_bias = qv_bias + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.patch_norm = patch_norm + self.init_values = init_values + self.window_size = (img_size[0] // patch_size, + img_size[1] // patch_size) + self.patch_shape = self.window_size self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self._build_patch_embedding() + self._build_layers() + if isinstance(out_indices, int): if out_indices == -1: out_indices = num_layers - 1 @@ -334,29 +340,47 @@ class BEiT(BaseModule): else: raise TypeError('out_indices must be type of int, list or tuple') - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] - self.layers = ModuleList() - for i in range(num_layers): - self.layers.append( - BEiTTransformerEncoderLayer( - embed_dims=embed_dims, - num_heads=num_heads, - feedforward_channels=mlp_ratio * embed_dims, - attn_drop_rate=attn_drop_rate, - drop_path_rate=dpr[i], - num_fcs=num_fcs, - bias='qv_bias' if qv_bias else False, - act_cfg=act_cfg, - norm_cfg=norm_cfg, - window_size=window_size, - init_values=init_values)) - self.final_norm = final_norm if final_norm: self.norm1_name, norm1 = build_norm_layer( norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) + def _build_patch_embedding(self): + """Build patch embedding layer.""" + self.patch_embed = PatchEmbed( + in_channels=self.in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + padding=0, + norm_cfg=self.norm_cfg if self.patch_norm else None, + init_cfg=None) + + def _build_layers(self): + """Build transformer encoding layers.""" + + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + BEiTTransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias='qv_bias' if self.qv_bias else False, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.window_size, + init_values=self.init_values)) + @property def norm1(self): return getattr(self, self.norm1_name) @@ -419,7 +443,6 @@ class BEiT(BaseModule): https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501 Copyright (c) Microsoft Corporation Licensed under the MIT License - Args: checkpoint (dict): Key and value of the pretrain model. Returns: