fix wrappers when using parrots (#613)

* fix wrappers when using parrots

* linting

* refactor according to review
pull/615/head
Ryan Li 2020-10-13 17:54:16 +08:00 committed by GitHub
parent 993da2bbd7
commit 54c527acd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 7 deletions

View File

@ -12,9 +12,16 @@ from torch.nn.modules.utils import _pair
from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
# torch.__version__ could be 1.3.1+cu92, we only need the first two
# for comparison
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
if torch.__version__ == 'parrots':
TORCH_VERSION = torch.__version__
else:
# torch.__version__ could be 1.3.1+cu92, we only need the first two
# for comparison
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
def obsolete_torch_version(torch_version, version_threshold):
return torch_version == 'parrots' or torch_version <= version_threshold
class NewEmptyTensorOp(torch.autograd.Function):
@ -34,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(nn.Conv2d):
def forward(self, x):
if x.numel() == 0 and TORCH_VERSION <= (1, 4):
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation):
@ -57,7 +64,7 @@ class Conv2d(nn.Conv2d):
class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x):
if x.numel() == 0 and TORCH_VERSION <= (1, 4):
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride,
@ -78,7 +85,7 @@ class MaxPool2d(nn.MaxPool2d):
def forward(self, x):
# PyTorch 1.6 does not support empty tensor inference yet
if x.numel() == 0 and TORCH_VERSION <= (1, 6):
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 6)):
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride),
@ -96,7 +103,7 @@ class Linear(torch.nn.Linear):
def forward(self, x):
# empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and TORCH_VERSION <= (1, 5):
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
out_shape = [x.shape[0], self.out_features]
empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training: