RE-OWOD/detectron2/layers/wrappers.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Wrappers around on some nn functions, mainly to support empty tensors.
Ideally, add support directly in PyTorch to empty tensors in those functions.
These can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented
"""
import math
from typing import List
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _ntuple
from detectron2.utils.env import TORCH_VERSION
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
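# Illustrative usage (not part of the original file): `cat` behaves like torch.cat but
# skips the copy when the list holds a single tensor, e.g.
#   boxes = [torch.rand(3, 4), torch.rand(2, 4)]
#   merged = cat(boxes, dim=0)    # shape (5, 4), same as torch.cat
#   single = cat([boxes[0]])      # returns boxes[0] itself, no copy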
class _NewEmptyTensorOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return _NewEmptyTensorOp.apply(grad, shape), None
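# Illustrative sketch (not part of the original file): _NewEmptyTensorOp builds an empty
# tensor of an arbitrary shape while staying connected to the autograd graph, so gradients
# can still flow back to the (empty) input, e.g.
#   x = torch.zeros(0, 3, requires_grad=True)
#   y = _NewEmptyTensorOp.apply(x, [0, 5])   # empty tensor of shape (0, 5)
#   y.sum().backward()                       # x.grad is an empty tensor of shape (0, 3)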
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation
    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        #    later versions; `Conv2d` in these PyTorch versions already supports empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"
            if x.numel() == 0 and TORCH_VERSION <= (1, 4):
                assert not isinstance(
                    self.norm, torch.nn.GroupNorm
                ), "GroupNorm does not support empty inputs in PyTorch <=1.4!"
                # When the input is empty, we want to return an empty tensor with the "correct"
                # shape, so that the following operations will not panic
                # if they check the shape of the tensor.
                # This computes the height and width of the output tensor.
                output_shape = [
                    (i + 2 * p - (di * (k - 1) + 1)) // s + 1
                    for i, p, di, k, s in zip(
                        x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
                    )
                ]
                output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
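                # Illustrative worked example (not part of the original file): for a 3x3 kernel
                # with padding=1, dilation=1, stride=1 and a (0, C, 32, 32) input, each spatial
                # dim is (32 + 2*1 - (1*(3 - 1) + 1)) // 1 + 1 = 32, i.e. the usual
                # floor((i + 2*p - effective_kernel) / s) + 1 convolution output size.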
                empty = _NewEmptyTensorOp.apply(x, output_shape)
                if self.training:
                    # This is to make DDP happy.
                    # DDP expects all workers to have gradient w.r.t the same set of parameters.
                    _dummy = sum([x.view(-1)[0] for x in self.parameters()]) * 0.0
                    return empty + _dummy
                else:
                    return empty

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
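# Illustrative usage (not part of the original file): the extra `norm` / `activation`
# kwargs let a conv + norm + activation block be built in one call. `get_norm` is assumed
# to be available from detectron2.layers; F.relu is torch.nn.functional.relu.
#   from detectron2.layers import get_norm
#   conv = Conv2d(256, 256, kernel_size=3, padding=1, bias=False,
#                 norm=get_norm("GN", 256), activation=F.relu)
#   out = conv(torch.rand(2, 256, 32, 32))   # conv, then norm, then activation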
if TORCH_VERSION > (1, 4):
    ConvTranspose2d = torch.nn.ConvTranspose2d
else:
    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        """
        A wrapper around :class:`torch.nn.ConvTranspose2d` to support zero-size tensors.
        """

        def forward(self, x):
            if x.numel() > 0:
                return super(ConvTranspose2d, self).forward(x)
            # get output shape
            # When the input is empty, we want to return an empty tensor with the "correct"
            # shape, so that the following operations will not panic
            # if they check the shape of the tensor.
            # This computes the height and width of the output tensor.
            output_shape = [
                (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
                for i, p, di, k, d, op in zip(
                    x.shape[-2:],
                    self.padding,
                    self.dilation,
                    self.kernel_size,
                    self.stride,
                    self.output_padding,
                )
            ]
            output_shape = [x.shape[0], self.out_channels] + output_shape
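            # Illustrative worked example (not part of the original file): for a 2x2 kernel
            # with stride=2, padding=0, dilation=1, output_padding=0 and a (0, C, 16, 16) input,
            # each spatial dim is (16 - 1) * 2 - 2*0 + (1*(2 - 1) + 1) + 0 = 32, the standard
            # transposed-convolution output size.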
            # This is to make DDP happy.
            # DDP expects all workers to have gradient w.r.t the same set of parameters.
            _dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
            return _NewEmptyTensorOp.apply(x, output_shape) + _dummy
if TORCH_VERSION > (1, 4):
    BatchNorm2d = torch.nn.BatchNorm2d
else:
    class BatchNorm2d(torch.nn.BatchNorm2d):
        """
        A wrapper around :class:`torch.nn.BatchNorm2d` to support zero-size tensors.
        """

        def forward(self, x):
            if x.numel() > 0:
                return super(BatchNorm2d, self).forward(x)
            # get output shape
            output_shape = x.shape
            return _NewEmptyTensorOp.apply(x, output_shape)
if TORCH_VERSION > (1, 5):
    Linear = torch.nn.Linear
else:
    class Linear(torch.nn.Linear):
        """
        A wrapper around :class:`torch.nn.Linear` to support empty inputs and more features,
        because of https://github.com/pytorch/pytorch/issues/34202
        """

        def forward(self, x):
            if x.numel() == 0:
                output_shape = [x.shape[0], self.weight.shape[0]]

                empty = _NewEmptyTensorOp.apply(x, output_shape)
                if self.training:
                    # This is to make DDP happy.
                    # DDP expects all workers to have gradient w.r.t the same set of parameters.
                    _dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                    return empty + _dummy
                else:
                    return empty

            x = super().forward(x)
            return x
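# Illustrative usage (not part of the original file): the wrapper only matters for empty
# batches, where stock torch.nn.Linear used to fail on older PyTorch, e.g.
#   fc = Linear(1024, 80)
#   out = fc(torch.rand(0, 1024))   # empty output of shape (0, 80) instead of an error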
if TORCH_VERSION > (1, 4):
    interpolate = torch.nn.functional.interpolate
else:
    def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
        """
        A wrapper around :func:`torch.nn.functional.interpolate` to support zero-size tensors.
        """
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners=align_corners
            )

        def _check_size_scale_factor(dim):
            if size is None and scale_factor is None:
                raise ValueError("either size or scale_factor should be defined")
            if size is not None and scale_factor is not None:
                raise ValueError("only one of size or scale_factor should be defined")
            if (
                scale_factor is not None
                and isinstance(scale_factor, tuple)
                and len(scale_factor) != dim
            ):
                raise ValueError(
                    "scale_factor shape must match input shape. "
                    "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
                )

        def _output_size(dim):
            _check_size_scale_factor(dim)
            if size is not None:
                return size
            scale_factors = _ntuple(dim)(scale_factor)
            # math.floor might return float in py2.7
            return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)]

        output_shape = tuple(_output_size(2))
        output_shape = input.shape[:-2] + output_shape
        return _NewEmptyTensorOp.apply(input, output_shape)
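# Illustrative worked example (not part of the original file): for an empty (0, C, 10, 10)
# input with scale_factor=2, _output_size(2) yields [floor(10 * 2), floor(10 * 2)] = [20, 20],
# so the fallback returns an empty tensor of shape (0, C, 20, 20).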
def nonzero_tuple(x):
    """
    An `as_tuple=True` version of `torch.nonzero` to support torchscript,
    because of https://github.com/pytorch/pytorch/issues/38718
    """
    if x.dim() == 0:
        return x.unsqueeze(0).nonzero().unbind(1)
    return x.nonzero().unbind(1)
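# Illustrative usage (not part of the original file): like torch.nonzero(x, as_tuple=True),
# this returns one index tensor per dimension, e.g.
#   mask = torch.tensor([[True, False], [False, True]])
#   rows, cols = nonzero_tuple(mask)   # rows = tensor([0, 1]), cols = tensor([0, 1])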