Used wrong channel var for split

This commit is contained in:
Ross Wightman 2020-01-26 11:33:31 -08:00
parent 58e28dc7e7
commit 9abe610931

View File

@ -106,7 +106,7 @@ class SelectiveKernelConv(nn.Module):
def forward(self, x):
if self.split_input:
x_split = torch.split(x, self.out_channels // self.num_paths, 1)
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
else:
x_paths = [op(x) for op in self.paths]