mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Use reshape instead of view in std_conv, causing issues in recent PyTorch in channels_last
This commit is contained in:
parent
da06cc61d4
commit
515121cca1
@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d):
|
||||
|
||||
def forward(self, x):
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
self.weight.reshape(1, self.out_channels, -1), None, None,
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d):
|
||||
if self.same_pad:
|
||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
self.weight.reshape(1, self.out_channels, -1), None, None,
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d):
|
||||
|
||||
def forward(self, x):
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
self.weight.reshape(1, self.out_channels, -1), None, None,
|
||||
weight=(self.gain * self.scale).view(-1),
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d):
|
||||
if self.same_pad:
|
||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
self.weight.reshape(1, self.out_channels, -1), None, None,
|
||||
weight=(self.gain * self.scale).view(-1),
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
Loading…
x
Reference in New Issue
Block a user