mirror of https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00

fix swin (#2004)

parent 794af8c06f
commit b01a79aba7
@@ -157,6 +157,7 @@ class WindowAttention(nn.Layer):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+
         self.register_buffer("relative_position_index",
                              relative_position_index)
 
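For orientation: the hunk above is the tail end of Swin's standard relative-position-index construction. Below is a self-contained sketch of the whole computation, assuming the usual 7x7 window; only the shift/scale lines come from the file itself, the rest is illustrative, not copied from this commit.

```python
import paddle

# Build the flat relative-position index that the buffer above stores.
# Assumes window_size = (7, 7), as in the standard Swin configs.
window_size = (7, 7)

coords_h = paddle.arange(window_size[0])
coords_w = paddle.arange(window_size[1])
coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = paddle.flatten(coords, 1)                    # 2, Wh*Ww
# pairwise (dh, dw) offsets between every two positions in the window
relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
relative_coords = relative_coords.transpose([1, 2, 0])        # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1   # shift dh into [0, 2*Wh-2]
relative_coords[:, :, 1] += window_size[1] - 1   # shift dw into [0, 2*Ww-2]
relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # row-major flattening
relative_position_index = relative_coords.sum(-1)   # Wh*Ww, Wh*Ww
print(relative_position_index.shape)  # [49, 49], values in [0, 168]
```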
@@ -168,6 +169,23 @@ class WindowAttention(nn.Layer):
         trunc_normal_(self.relative_position_bias_table)
         self.softmax = nn.Softmax(axis=-1)
 
+    def eval(self, ):
+        # this is used to re-param swin for model export
+        relative_position_bias_table = self.relative_position_bias_table
+        window_size = self.window_size
+        index = self.relative_position_index.reshape([-1])
+
+        relative_position_bias = paddle.index_select(
+            relative_position_bias_table, index)
+        relative_position_bias = relative_position_bias.reshape([
+            window_size[0] * window_size[1], window_size[0] * window_size[1],
+            -1
+        ])  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.transpose(
+            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+        relative_position_bias = relative_position_bias.unsqueeze(0)
+        self.register_buffer("relative_position_bias", relative_position_bias)
+
     def forward(self, x, mask=None):
         """
         Args:
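Why the new eval(): the relative position bias depends only on the learned table and a fixed index, so once training stops it can be folded into a constant buffer, letting an exported static graph skip the index_select/reshape/transpose chain. A hedged sketch of the export flow this enables; the import path, constructor arguments, input shape, and save path are assumptions based on the usual Swin-tiny setup, not taken from the commit.

```python
import paddle
from ppcls.arch.backbone.model_zoo.swin_transformer import WindowAttention

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
attn.eval()            # new override: caches relative_position_bias as a buffer
attn.training = False  # the override does not flip the flag itself; set it so
                       # forward() traces the cached-buffer branch

# trace and save; input is (num_windows*B, Wh*Ww, C)
static_attn = paddle.jit.to_static(
    attn,
    input_spec=[paddle.static.InputSpec([None, 49, 96], "float32")])
paddle.jit.save(static_attn, "inference/window_attention")
```

In a full export run the training flag is flipped by the top-level model.eval() walk anyway; the explicit assignment above only keeps the sketch self-contained.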
@@ -183,18 +201,21 @@ class WindowAttention(nn.Layer):
         q = q * self.scale
         attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
 
-        index = self.relative_position_index.reshape([-1])
+        if self.training or not hasattr(self, "relative_position_bias"):
+            index = self.relative_position_index.reshape([-1])
 
-        relative_position_bias = paddle.index_select(
-            self.relative_position_bias_table, index)
-        relative_position_bias = relative_position_bias.reshape([
-            self.window_size[0] * self.window_size[1],
-            self.window_size[0] * self.window_size[1], -1
-        ])  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = paddle.index_select(
+                self.relative_position_bias_table, index)
+            relative_position_bias = relative_position_bias.reshape([
+                self.window_size[0] * self.window_size[1],
+                self.window_size[0] * self.window_size[1], -1
+            ])  # Wh*Ww,Wh*Ww,nH
 
-        relative_position_bias = relative_position_bias.transpose(
-            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
-        attn = attn + relative_position_bias.unsqueeze(0)
+            relative_position_bias = relative_position_bias.transpose(
+                [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+        else:
+            attn = attn + self.relative_position_bias
 
         if mask is not None:
             nW = mask.shape[0]
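forward() now has two paths, so a quick sanity check is to compare them on the same input; with the default zero drop rates the two outputs should match exactly. A hedged sketch, reusing the assumed import path and constructor arguments from above:

```python
import paddle
from ppcls.arch.backbone.model_zoo.swin_transformer import WindowAttention

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
x = paddle.randn([2, 49, 96])  # (num_windows*B, Wh*Ww, C)

y_recomputed = attn(x)  # training flag still True: bias rebuilt from the table
attn.eval()             # caches relative_position_bias
attn.training = False   # route forward() through the cached branch
y_cached = attn(x)

print(float((y_recomputed - y_cached).abs().max()))  # expect 0.0
```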