diff --git a/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py b/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py index d2f89446a..bfb4a17ba 100644 --- a/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py +++ b/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py @@ -422,8 +422,36 @@ class SwinTransformerBlock(nn.Layer): self.register_buffer("attn_mask", attn_mask) - def forward(self, x): - H, W = self.input_resolution + def get_attn_mask(self, height, width, dtype): + if self.shift_size > 0: + # calculate attention mask for shifted window multihead self attention + img_mask = paddle.zeros((1, height, width, 1), dtype=dtype) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.reshape( + (-1, self.window_size * self.window_size)) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0)) + attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def forward(self, x, input_dimensions): + H, W = input_dimensions B, L, C = x.shape assert L == H * W, "input feature has wrong size" @@ -431,6 +459,12 @@ class SwinTransformerBlock(nn.Layer): x = x.reshape([B, H, W, C]) + x, pad_values = pading_for_not_divisible(x, H, W, self.window_size, + "BHWC") + _, height_pad, width_pad, _ = x.shape + + padding_state = pad_values[3] > 0 or pad_values[ + 5] > 0 # change variable name # cyclic shift if self.shift_size > 0: shifted_x = paddle.roll( @@ -511,16 +545,23 @@ class PatchMerging(nn.Layer): self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False) self.norm = norm_layer(2 * dim) - def forward(self, x): + def forward(self, x, input_dimensions): """ x: B, H*W, C """ H, W = input_dimensions B, L, C = x.shape + x = x.reshape([B, H, W, C]) + x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge") - x = x.reshape([B, H // 2, 2, W // 2, 2, C]) - x = x.transpose((0, 1, 3, 4, 2, 5)) - x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = paddle.reshape(x, x.shape) + x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + + x = x.reshape([B, -1, 4 * C]) # B H/2*W/2 4*C x = self.reduction(x) x = self.norm(x) return x @@ -683,7 +724,11 @@ class PatchEmbed(nn.Layer): # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose([0, 2, 1]) # B Ph*Pw C + x = self.proj(x) + _, _, height, width = x.shape + output_dimensions = (height, width) + x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C + # x = self.proj(x).flatten(2).transpose([0, 2, 1]) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x, output_dimensions