Fix CUDA crash w/ channels-last + CSP models. Remove use of chunk()

pull/450/head
Ross Wightman 2021-02-23 13:15:52 -08:00
parent da4839530c
commit 4bc103f504
1 changed files with 4 additions and 2 deletions

View File

@ -264,9 +264,11 @@ class CrossStage(nn.Module):
if self.conv_down is not None:
x = self.conv_down(x)
x = self.conv_exp(x)
xs, xb = x.chunk(2, dim=1)
split = x.shape[1] // 2
xs, xb = x[:, :split], x[:, split:]
xb = self.blocks(xb)
out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1))
xb = self.conv_transition_b(xb).contiguous()
out = self.conv_transition(torch.cat([xs, xb], dim=1))
return out