mirror of https://github.com/NVlabs/SegFormer.git
Adding .contiguous() after transpose or permutation
parent 1a8ad5123a
commit 9740efab3e
@@ -106,11 +106,11 @@ class Attention(nn.Module):
             kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = ((q @ k.transpose(-2, -1)) * self.scale).contiguous()
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous()
         x = self.proj(x)
         x = self.proj_drop(x)
 
@@ -194,7 +194,7 @@ class OverlapPatchEmbed(nn.Module):
     def forward(self, x):
         x = self.proj(x)
         _, _, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
+        x = x.flatten(2).transpose(1, 2).contiguous()
         x = self.norm(x)
 
         return x, H, W
@@ -362,9 +362,9 @@ class DWConv(nn.Module):
 
     def forward(self, x, H, W):
         B, N, C = x.shape
-        x = x.transpose(1, 2).view(B, C, H, W)
+        x = x.transpose(1, 2).view(B, C, H, W).contiguous()
         x = self.dwconv(x)
-        x = x.flatten(2).transpose(1, 2)
+        x = x.flatten(2).transpose(1, 2).contiguous()
 
         return x
 
@@ -26,7 +26,7 @@ class MLP(nn.Module):
         self.proj = nn.Linear(input_dim, embed_dim)
 
     def forward(self, x):
-        x = x.flatten(2).transpose(1, 2)
+        x = x.flatten(2).transpose(1, 2).contiguous()
         x = self.proj(x)
         return x
 
@@ -68,16 +68,16 @@ class SegFormerHead(BaseDecodeHead):
         ############## MLP decoder on C1-C4 ###########
         n, _, h, w = c4.shape
 
-        _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
+        _c4 = (self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])).contiguous()
         _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
 
-        _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
+        _c3 = (self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])).contiguous()
         _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
 
-        _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
+        _c2 = (self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])).contiguous()
         _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
 
-        _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
+        _c1 = (self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])).contiguous()
 
         _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
 
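
For context on the change itself: in PyTorch, transpose() and permute() return views whose memory layout no longer matches their logical shape, so the result is non-contiguous; .contiguous() materializes a row-major copy, which operations such as .view() require (.reshape() copies on demand by itself). The following is a minimal sketch using only stock PyTorch, not code from this commit:

    import torch

    x = torch.randn(2, 3, 4)

    # transpose/permute reuse x's storage, so the result is non-contiguous.
    y = x.transpose(1, 2)        # shape (2, 4, 3)
    z = x.permute(2, 0, 1)       # shape (4, 2, 3)
    print(y.is_contiguous())     # False
    print(z.is_contiguous())     # False

    # .view() needs a contiguous layout and raises otherwise.
    try:
        y.view(2, 12)
    except RuntimeError as e:
        print("view on non-contiguous tensor failed:", e)

    # .contiguous() copies the data into a fresh row-major tensor.
    yc = y.contiguous()
    print(yc.is_contiguous())    # True
    print(yc.view(2, 12).shape)  # torch.Size([2, 12])

The added calls therefore trade an extra copy for a guaranteed memory layout after each transpose/permute, presumably for downstream consumers that expect contiguous tensors.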