# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Wrappers around on some nn functions, mainly to support empty tensors. Ideally, add support directly in PyTorch to empty tensors in those functions. These can be removed once https://github.com/pytorch/pytorch/issues/12013 is implemented """ import math from typing import List import torch import torch.nn.functional as F from torch.nn.modules.utils import _ntuple from detectron2.utils.env import TORCH_VERSION def cat(tensors: List[torch.Tensor], dim: int = 0): """ Efficient version of torch.cat that avoids a copy if there is only a single element in a list """ assert isinstance(tensors, (list, tuple)) if len(tensors) == 1: return tensors[0] return torch.cat(tensors, dim) class _NewEmptyTensorOp(torch.autograd.Function): @staticmethod def forward(ctx, x, new_shape): ctx.shape = x.shape return x.new_empty(new_shape) @staticmethod def backward(ctx, grad): shape = ctx.shape return _NewEmptyTensorOp.apply(grad, shape), None class Conv2d(torch.nn.Conv2d): """ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. """ def __init__(self, *args, **kwargs): """ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: Args: norm (nn.Module, optional): a normalization layer activation (callable(Tensor) -> Tensor): a callable activation function It assumes that norm layer is used before activation. """ norm = kwargs.pop("norm", None) activation = kwargs.pop("activation", None) super().__init__(*args, **kwargs) self.norm = norm self.activation = activation def forward(self, x): # torchscript does not support SyncBatchNorm yet # https://github.com/pytorch/pytorch/issues/40507 # and we skip these codes in torchscript since: # 1. currently we only support torchscript in evaluation mode # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. if not torch.jit.is_scripting(): if x.numel() == 0 and self.training: # https://github.com/pytorch/pytorch/issues/12013 assert not isinstance( self.norm, torch.nn.SyncBatchNorm ), "SyncBatchNorm does not support empty inputs!" if x.numel() == 0 and TORCH_VERSION <= (1, 4): assert not isinstance( self.norm, torch.nn.GroupNorm ), "GroupNorm does not support empty inputs in PyTorch <=1.4!" # When input is empty, we want to return a empty tensor with "correct" shape, # So that the following operations will not panic # if they check for the shape of the tensor. # This computes the height and width of the output tensor output_shape = [ (i + 2 * p - (di * (k - 1) + 1)) // s + 1 for i, p, di, k, s in zip( x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride ) ] output_shape = [x.shape[0], self.weight.shape[0]] + output_shape empty = _NewEmptyTensorOp.apply(x, output_shape) if self.training: # This is to make DDP happy. # DDP expects all workers to have gradient w.r.t the same set of parameters. _dummy = sum([x.view(-1)[0] for x in self.parameters()]) * 0.0 return empty + _dummy else: return empty x = F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) if self.norm is not None: x = self.norm(x) if self.activation is not None: x = self.activation(x) return x if TORCH_VERSION > (1, 4): ConvTranspose2d = torch.nn.ConvTranspose2d else: class ConvTranspose2d(torch.nn.ConvTranspose2d): """ A wrapper around :class:`torch.nn.ConvTranspose2d` to support zero-size tensor. """ def forward(self, x): if x.numel() > 0: return super(ConvTranspose2d, self).forward(x) # get output shape # When input is empty, we want to return a empty tensor with "correct" shape, # So that the following operations will not panic # if they check for the shape of the tensor. # This computes the height and width of the output tensor output_shape = [ (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op for i, p, di, k, d, op in zip( x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride, self.output_padding, ) ] output_shape = [x.shape[0], self.out_channels] + output_shape # This is to make DDP happy. # DDP expects all workers to have gradient w.r.t the same set of parameters. _dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 return _NewEmptyTensorOp.apply(x, output_shape) + _dummy if TORCH_VERSION > (1, 4): BatchNorm2d = torch.nn.BatchNorm2d else: class BatchNorm2d(torch.nn.BatchNorm2d): """ A wrapper around :class:`torch.nn.BatchNorm2d` to support zero-size tensor. """ def forward(self, x): if x.numel() > 0: return super(BatchNorm2d, self).forward(x) # get output shape output_shape = x.shape return _NewEmptyTensorOp.apply(x, output_shape) if TORCH_VERSION > (1, 5): Linear = torch.nn.Linear else: class Linear(torch.nn.Linear): """ A wrapper around :class:`torch.nn.Linear` to support empty inputs and more features. Because of https://github.com/pytorch/pytorch/issues/34202 """ def forward(self, x): if x.numel() == 0: output_shape = [x.shape[0], self.weight.shape[0]] empty = _NewEmptyTensorOp.apply(x, output_shape) if self.training: # This is to make DDP happy. # DDP expects all workers to have gradient w.r.t the same set of parameters. _dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 return empty + _dummy else: return empty x = super().forward(x) return x if TORCH_VERSION > (1, 4): interpolate = torch.nn.functional.interpolate else: def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): """ A wrapper around :func:`torch.nn.functional.interpolate` to support zero-size tensor. """ if input.numel() > 0: return torch.nn.functional.interpolate( input, size, scale_factor, mode, align_corners=align_corners ) def _check_size_scale_factor(dim): if size is None and scale_factor is None: raise ValueError("either size or scale_factor should be defined") if size is not None and scale_factor is not None: raise ValueError("only one of size or scale_factor should be defined") if ( scale_factor is not None and isinstance(scale_factor, tuple) and len(scale_factor) != dim ): raise ValueError( "scale_factor shape must match input shape. " "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) ) def _output_size(dim): _check_size_scale_factor(dim) if size is not None: return size scale_factors = _ntuple(dim)(scale_factor) # math.floor might return float in py2.7 return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)] output_shape = tuple(_output_size(2)) output_shape = input.shape[:-2] + output_shape return _NewEmptyTensorOp.apply(input, output_shape) def nonzero_tuple(x): """ A 'as_tuple=True' version of torch.nonzero to support torchscript. because of https://github.com/pytorch/pytorch/issues/38718 """ if x.dim() == 0: return x.unsqueeze(0).nonzero().unbind(1) return x.nonzero().unbind(1)