fix swin_transformerv2
parent 104b378828
commit ad46ac0765
@@ -422,8 +422,36 @@ class SwinTransformerBlock(nn.Layer):
            self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate attention mask for shifted window multihead self attention
            img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape(
                (-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def forward(self, x, input_dimensions):
        H, W = input_dimensions
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
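For context, the mask construction above can be exercised on its own. The sketch below is a minimal, self-contained version of the same logic; window_partition and masked_fill live elsewhere in this file and are not part of the diff, so the definitions here are assumptions that match how they are called above.

# Minimal sketch of the get_attn_mask logic above (assumed helper
# definitions for window_partition / masked_fill, which are not shown
# in this diff).
import paddle


def window_partition(x, window_size):
    # [B, H, W, C] -> [num_windows * B, window_size, window_size, C]
    B, H, W, C = x.shape
    x = x.reshape(
        [B, H // window_size, window_size, W // window_size, window_size, C])
    return x.transpose([0, 1, 3, 2, 4, 5]).reshape(
        [-1, window_size, window_size, C])


def masked_fill(x, mask, value):
    # Paddle has no Tensor.masked_fill; emulate it with where().
    return paddle.where(mask, paddle.full_like(x, value), x)


def build_attn_mask(height, width, window_size, shift_size, dtype="float32"):
    img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    count = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = count
            count += 1
    mask_windows = window_partition(img_mask, window_size).reshape(
        [-1, window_size * window_size])
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
    return masked_fill(attn_mask, attn_mask == 0, float(0.0))


print(build_attn_mask(8, 8, window_size=4, shift_size=2).shape)  # [4, 16, 16]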
@@ -431,6 +459,12 @@ class SwinTransformerBlock(nn.Layer):
        x = x.reshape([B, H, W, C])

        x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
                                                 "BHWC")
        _, height_pad, width_pad, _ = x.shape

        padding_state = pad_values[3] > 0 or pad_values[
            5] > 0  # change variable name
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = paddle.roll(
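pading_for_not_divisible is defined elsewhere in the file, so its exact behavior is not visible in this diff. A hypothetical stand-in, inferred only from how the forward above uses it (a BHWC input padded up to multiples of window_size, with the right pad at pad_values[3] and the bottom pad at pad_values[5]), might look like this:

import paddle


def pad_to_multiple(x, height, width, window_size):
    # Hypothetical stand-in for pading_for_not_divisible (not the repo's
    # real helper): pad a BHWC tensor so H and W become multiples of
    # window_size, padding on the bottom/right with zeros.
    pad_right = (window_size - width % window_size) % window_size
    pad_bottom = (window_size - height % window_size) % window_size
    # Assumed layout so that pad_values[3] is the right pad and
    # pad_values[5] is the bottom pad, matching the checks above.
    pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
    if pad_right > 0:
        zeros_w = paddle.zeros(
            [x.shape[0], x.shape[1], pad_right, x.shape[3]], dtype=x.dtype)
        x = paddle.concat([x, zeros_w], axis=2)
    if pad_bottom > 0:
        zeros_h = paddle.zeros(
            [x.shape[0], pad_bottom, x.shape[2], x.shape[3]], dtype=x.dtype)
        x = paddle.concat([x, zeros_h], axis=1)
    return x, pad_values


x = paddle.rand([1, 7, 9, 96])  # H=7, W=9, window_size=4
x, pad_values = pad_to_multiple(x, 7, 9, window_size=4)
print(x.shape, pad_values[3] > 0 or pad_values[5] > 0)  # [1, 8, 12, 96] True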
@@ -511,16 +545,23 @@ class PatchMerging(nn.Layer):
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
    def forward(self, x, input_dimensions):
        """
        x: B, H*W, C
        """
        H, W = input_dimensions
        B, L, C = x.shape
        x = x.reshape([B, H, W, C])
        x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")

        x = x.reshape([B, H // 2, 2, W // 2, 2, C])
        x = x.transpose((0, 1, 3, 4, 2, 5))
        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = paddle.reshape(x, x.shape)
        x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C

        x = x.reshape([B, -1, 4 * C])  # B H/2*W/2 4*C
        x = self.reduction(x)
        x = self.norm(x)
        return x
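The hunk above contains two variants of the same 2x2 patch merge: the reshape/transpose form and the strided-slice x0..x3 form. A quick standalone check (a sketch, not repository code) that both produce the same [B, H/2*W/2, 4C] layout:

import paddle

B, H, W, C = 2, 4, 6, 3
x = paddle.rand([B, H, W, C])

# variant 1: reshape + transpose
a = (x.reshape([B, H // 2, 2, W // 2, 2, C])
     .transpose([0, 1, 3, 4, 2, 5])
     .reshape([B, H * W // 4, 4 * C]))

# variant 2: strided slices concatenated channel-wise
x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
x1 = x[:, 1::2, 0::2, :]  # bottom-left
x2 = x[:, 0::2, 1::2, :]  # top-right
x3 = x[:, 1::2, 1::2, :]  # bottom-right
b = paddle.concat([x0, x1, x2, x3], -1).reshape([B, -1, 4 * C])

print(bool(paddle.allclose(a, b)))  # True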
@@ -683,7 +724,11 @@ class PatchEmbed(nn.Layer):
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        x = self.proj(x)
        _, _, height, width = x.shape
        output_dimensions = (height, width)
        x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        # x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x, output_dimensions
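A minimal usage sketch of the projection step above, showing where output_dimensions comes from (a bare Conv2D with stride equal to the patch size, not the repo's PatchEmbed class; the sizes are illustrative):

import paddle
import paddle.nn as nn

patch_size, embed_dim = 4, 96
proj = nn.Conv2D(3, embed_dim, kernel_size=patch_size, stride=patch_size)

x = paddle.rand([1, 3, 224, 224])      # B C H W
x = proj(x)                            # B embed_dim H/4 W/4
_, _, height, width = x.shape
output_dimensions = (height, width)    # (56, 56)
x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
print(output_dimensions, x.shape)      # (56, 56) [1, 3136, 96]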