mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix torchscript issue in bat
This commit is contained in:
parent
d17b374f0f
commit
b7a568f065
@ -81,7 +81,7 @@ class BilinearAttnTransform(nn.Module):
|
||||
self.groups = groups
|
||||
self.in_channels = in_channels
|
||||
|
||||
def resize_mat(self, x, t):
|
||||
def resize_mat(self, x, t: int):
|
||||
B, C, block_size, block_size1 = x.shape
|
||||
assert block_size == block_size1
|
||||
if t <= 1:
|
||||
@ -100,10 +100,8 @@ class BilinearAttnTransform(nn.Module):
|
||||
out = self.conv1(x)
|
||||
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
|
||||
cp = F.adaptive_max_pool2d(out, (1, self.block_size))
|
||||
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
|
||||
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
|
||||
p = F.sigmoid(p)
|
||||
q = F.sigmoid(q)
|
||||
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
|
||||
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
|
||||
p = p / p.sum(dim=3, keepdim=True)
|
||||
q = q / q.sum(dim=2, keepdim=True)
|
||||
p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
|
||||
|
Loading…
x
Reference in New Issue
Block a user