This commit is contained in:
littletomatodonkey 2022-06-09 15:08:45 +08:00 committed by GitHub
parent 794af8c06f
commit b01a79aba7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -157,6 +157,7 @@ class WindowAttention(nn.Layer):
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index",
relative_position_index)
@@ -168,6 +169,23 @@ class WindowAttention(nn.Layer):
trunc_normal_(self.relative_position_bias_table)
self.softmax = nn.Softmax(axis=-1)
def eval(self, ):
    """Switch to evaluation mode and cache the relative position bias.

    This is used to re-parameterize Swin for model export: the per-head
    bias is gathered from ``relative_position_bias_table`` once and stored
    in the ``relative_position_bias`` buffer, which ``forward`` adds to the
    attention logits directly when the layer is not training.
    """
    # Delegate to the base class first. ``forward`` only takes the cached
    # path when ``self.training`` is false, so an override that skips the
    # base implementation would leave the cache unused whenever this
    # method is called directly on the layer.
    super().eval()

    window_area = self.window_size[0] * self.window_size[1]
    index = self.relative_position_index.reshape([-1])

    # Look up one bias row per (query, key) position pair.
    relative_position_bias = paddle.index_select(
        self.relative_position_bias_table, index)
    relative_position_bias = relative_position_bias.reshape(
        [window_area, window_area, -1])  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.transpose(
        [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
    # Leading axis of size 1 so the bias broadcasts when added to ``attn``.
    relative_position_bias = relative_position_bias.unsqueeze(0)
    self.register_buffer("relative_position_bias", relative_position_bias)
def forward(self, x, mask=None):
"""
Args:
@@ -183,18 +201,21 @@ class WindowAttention(nn.Layer):
q = q * self.scale
attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
index = self.relative_position_index.reshape([-1])
if self.training or not hasattr(self, "relative_position_bias"):
index = self.relative_position_index.reshape([-1])
relative_position_bias = paddle.index_select(
self.relative_position_bias_table, index)
relative_position_bias = relative_position_bias.reshape([
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1], -1
]) # Wh*Ww,Wh*Ww,nH
relative_position_bias = paddle.index_select(
self.relative_position_bias_table, index)
relative_position_bias = relative_position_bias.reshape([
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1], -1
]) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.transpose(
[2, 0, 1]) # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
relative_position_bias = relative_position_bias.transpose(
[2, 0, 1]) # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
else:
attn = attn + self.relative_position_bias
if mask is not None:
nW = mask.shape[0]