mirror of https://github.com/open-mmlab/mmcv.git
fix wrappers when using parrots (#613)
* fix wrappers when using parrots * linting * refactor according to reviewpull/615/head
parent
993da2bbd7
commit
54c527acd5
|
@ -12,9 +12,16 @@ from torch.nn.modules.utils import _pair
|
|||
|
||||
from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
|
||||
|
||||
# torch.__version__ could be 1.3.1+cu92, we only need the first two
|
||||
# for comparison
|
||||
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
|
||||
if torch.__version__ == 'parrots':
|
||||
TORCH_VERSION = torch.__version__
|
||||
else:
|
||||
# torch.__version__ could be 1.3.1+cu92, we only need the first two
|
||||
# for comparison
|
||||
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
|
||||
|
||||
|
||||
def obsolete_torch_version(torch_version, version_threshold):
|
||||
return torch_version == 'parrots' or torch_version <= version_threshold
|
||||
|
||||
|
||||
class NewEmptyTensorOp(torch.autograd.Function):
|
||||
|
@ -34,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
|
|||
class Conv2d(nn.Conv2d):
|
||||
|
||||
def forward(self, x):
|
||||
if x.numel() == 0 and TORCH_VERSION <= (1, 4):
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
|
||||
out_shape = [x.shape[0], self.out_channels]
|
||||
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
|
||||
self.padding, self.stride, self.dilation):
|
||||
|
@ -57,7 +64,7 @@ class Conv2d(nn.Conv2d):
|
|||
class ConvTranspose2d(nn.ConvTranspose2d):
|
||||
|
||||
def forward(self, x):
|
||||
if x.numel() == 0 and TORCH_VERSION <= (1, 4):
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
|
||||
out_shape = [x.shape[0], self.out_channels]
|
||||
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
|
||||
self.padding, self.stride,
|
||||
|
@ -78,7 +85,7 @@ class MaxPool2d(nn.MaxPool2d):
|
|||
|
||||
def forward(self, x):
|
||||
# PyTorch 1.6 does not support empty tensor inference yet
|
||||
if x.numel() == 0 and TORCH_VERSION <= (1, 6):
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 6)):
|
||||
out_shape = list(x.shape[:2])
|
||||
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
|
||||
_pair(self.padding), _pair(self.stride),
|
||||
|
@ -96,7 +103,7 @@ class Linear(torch.nn.Linear):
|
|||
|
||||
def forward(self, x):
|
||||
# empty tensor forward of Linear layer is supported in Pytorch 1.6
|
||||
if x.numel() == 0 and TORCH_VERSION <= (1, 5):
|
||||
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
|
||||
out_shape = [x.shape[0], self.out_features]
|
||||
empty = NewEmptyTensorOp.apply(x, out_shape)
|
||||
if self.training:
|
||||
|
|
Loading…
Reference in New Issue