mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix CUDA crash w/ channels-last + CSP models. Remove use of chunk()
This commit is contained in:
parent
317ea3e599
commit
7ef7788ee9
@ -264,9 +264,11 @@ class CrossStage(nn.Module):
|
|||||||
if self.conv_down is not None:
|
if self.conv_down is not None:
|
||||||
x = self.conv_down(x)
|
x = self.conv_down(x)
|
||||||
x = self.conv_exp(x)
|
x = self.conv_exp(x)
|
||||||
xs, xb = x.chunk(2, dim=1)
|
split = x.shape[1] // 2
|
||||||
|
xs, xb = x[:, :split], x[:, split:]
|
||||||
xb = self.blocks(xb)
|
xb = self.blocks(xb)
|
||||||
out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1))
|
xb = self.conv_transition_b(xb).contiguous()
|
||||||
|
out = self.conv_transition(torch.cat([xs, xb], dim=1))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user