Adding .contiguous() after transpose or permutation

pull/115/head^2
Jayden9912 2023-01-16 14:56:24 +08:00
parent 1a8ad5123a
commit 9740efab3e
2 changed files with 10 additions and 10 deletions

View File

@ -106,11 +106,11 @@ class Attention(nn.Module):
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = ((q @ k.transpose(-2, -1)) * self.scale).contiguous()
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous()
x = self.proj(x)
x = self.proj_drop(x)
@ -194,7 +194,7 @@ class OverlapPatchEmbed(nn.Module):
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = x.flatten(2).transpose(1, 2).contiguous()
x = self.norm(x)
return x, H, W
@ -362,9 +362,9 @@ class DWConv(nn.Module):
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = x.transpose(1, 2).view(B, C, H, W).contiguous()
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
x = x.flatten(2).transpose(1, 2).contiguous()
return x

View File

@ -26,7 +26,7 @@ class MLP(nn.Module):
self.proj = nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = x.flatten(2).transpose(1, 2).contiguous()
x = self.proj(x)
return x
@ -68,16 +68,16 @@ class SegFormerHead(BaseDecodeHead):
############## MLP decoder on C1-C4 ###########
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = (self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])).contiguous()
_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = (self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])).contiguous()
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = (self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])).contiguous()
_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c1 = (self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])).contiguous()
_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))