fix swin_transformerv2

pull/3272/head
zhangyubo0722 2024-10-10 11:06:40 +00:00
parent 104b378828
commit ad46ac0765
1 changed file with 52 additions and 7 deletions


@@ -422,8 +422,36 @@ class SwinTransformerBlock(nn.Layer):
        self.register_buffer("attn_mask", attn_mask)
    def forward(self, x):
        H, W = self.input_resolution
    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate attention mask for shifted window multihead self attention
            img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape(
                (-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask
    def forward(self, x, input_dimensions):
        H, W = input_dimensions
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
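Note for reviewers of the hunk above: moving the mask construction into get_attn_mask means the mask is now built from the padded feature size at runtime instead of the fixed self.input_resolution. A minimal standalone sketch of the same shifted-window mask logic, NumPy only, with illustrative window_size/shift_size values that are not taken from the model config:

import numpy as np

def toy_attn_mask(height, width, window_size=7, shift_size=3):
    # label each pixel with the region it falls into after the cyclic shift
    img_mask = np.zeros((1, height, width, 1))
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size),
                slice(-shift_size, None))
    count = 0
    for hs in h_slices:
        for ws in w_slices:
            img_mask[:, hs, ws, :] = count
            count += 1
    # partition into non-overlapping windows -> (num_windows, window_size**2)
    windows = img_mask.reshape(1, height // window_size, window_size,
                               width // window_size, window_size, 1)
    windows = windows.transpose(0, 1, 3, 2, 4, 5).reshape(
        -1, window_size * window_size)
    # tokens coming from different regions must not attend to each other
    mask = windows[:, None, :] - windows[:, :, None]
    return np.where(mask != 0, -100.0, 0.0)

print(toy_attn_mask(14, 14).shape)  # (4, 49, 49): 4 windows of 7x7 tokens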
@@ -431,6 +459,12 @@ class SwinTransformerBlock(nn.Layer):
        x = x.reshape([B, H, W, C])
        x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
                                                 "BHWC")
        _, height_pad, width_pad, _ = x.shape
        padding_state = pad_values[3] > 0 or pad_values[
            5] > 0  # change variable name
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = paddle.roll(
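pading_for_not_divisible is a helper defined elsewhere in this file; judging from how it is called here, it zero-pads a BHWC tensor on the bottom/right so H and W become multiples of window_size and returns the padded tensor plus the pad amounts. The function below is a hedged, self-contained stand-in for illustration only; its name, signature, and return format are assumptions, not the repo's exact helper:

import paddle

def pad_to_multiple(x, window_size):
    # x: [B, H, W, C]; zero-pad bottom/right so H and W divide window_size
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h:
        x = paddle.concat(
            [x, paddle.zeros([B, pad_h, W, C], dtype=x.dtype)], axis=1)
    if pad_w:
        x = paddle.concat(
            [x, paddle.zeros([B, H + pad_h, pad_w, C], dtype=x.dtype)], axis=2)
    return x, (pad_h, pad_w)

x = paddle.randn([1, 13, 18, 96])
x_padded, pads = pad_to_multiple(x, window_size=7)
print(x_padded.shape, pads)  # [1, 14, 21, 96] (1, 3)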
@@ -511,16 +545,23 @@ class PatchMerging(nn.Layer):
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
        self.norm = norm_layer(2 * dim)
    def forward(self, x):
    def forward(self, x, input_dimensions):
        """
        x: B, H*W, C
        """
        H, W = input_dimensions
        B, L, C = x.shape
        x = x.reshape([B, H, W, C])
        x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")
        x = x.reshape([B, H // 2, 2, W // 2, 2, C])
        x = x.transpose((0, 1, 3, 4, 2, 5))
        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = paddle.reshape(x, x.shape)
        x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.reshape([B, -1, 4 * C])  # B H/2*W/2 4*C
        x = self.reduction(x)
        x = self.norm(x)
        return x
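After the padding step, the slicing-based merge concatenates each 2x2 neighborhood along channels, so the token count is quartered and channels quadrupled before self.reduction maps 4*C back to 2*C. A small shape check with random data; the sizes below are illustrative, not the model's:

import paddle

B, H, W, C = 1, 8, 10, 96            # H, W already padded to even values
x = paddle.randn([B, H, W, C])
x0 = x[:, 0::2, 0::2, :]             # top-left pixel of every 2x2 block
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = paddle.concat([x0, x1, x2, x3], -1).reshape([B, -1, 4 * C])
print(merged.shape)                   # [1, 20, 384]; nn.Linear(384, 192) follows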
@@ -683,7 +724,11 @@ class PatchEmbed(nn.Layer):
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        x = self.proj(x)
        _, _, height, width = x.shape
        output_dimensions = (height, width)
        x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        # x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x, output_dimensions
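Why PatchEmbed now returns output_dimensions: the downstream blocks need the actual post-convolution grid size rather than a resolution fixed at construction time. A minimal sketch of that flow with an illustrative 4x4 patch projection; the layer and sizes below are assumptions for demonstration, not the model's config:

import paddle
import paddle.nn as nn

proj = nn.Conv2D(3, 96, kernel_size=4, stride=4)    # plays the role of self.proj
img = paddle.randn([1, 3, 224, 224])
feat = proj(img)                                     # [1, 96, 56, 56]
_, _, height, width = feat.shape
output_dimensions = (height, width)                  # handed to the next stage
tokens = feat.flatten(2).transpose([0, 2, 1])        # [1, 56*56, 96]
print(output_dimensions, tokens.shape)               # (56, 56) [1, 3136, 96]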