# mirror of https://github.com/FoundationVision/GLEE
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Wrappers around some nn functions, mainly to support empty tensors.

Ideally, add support for empty tensors in those functions directly in PyTorch.

These can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented.
"""

from typing import List, Optional

import torch
from torch.nn import functional as F


def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Turn a list of integer scalars or integer Tensor scalars into a vector,
    in a way that's both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensors, so the output can trace to the inputs.
    In scripting or eager, `x` should be a list of ints.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            [isinstance(t, torch.Tensor) for t in x]
        ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)
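
# Minimal usage sketch (illustrative values): in eager mode a plain list of
# ints suffices, while under torch.jit.trace the sizes arrive as scalar
# Tensors, so the result stays connected to the traced inputs:
#
#   shapes_to_tensor([2, 3])                      # tensor([2, 3])
#   t = torch.rand(2, 3)
#   shapes_to_tensor([t.shape[0], t.shape[1]])    # same values; traceable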


def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
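
# Minimal usage sketch (illustrative values): a single-element list is
# returned as-is (no copy); anything longer falls through to torch.cat:
#
#   a, b = torch.rand(2, 3), torch.rand(4, 3)
#   cat([a]) is a        # True
#   cat([a, b]).shape    # torch.Size([6, 3])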


def cross_entropy(input, target, *, reduction="mean", **kwargs):
    """
    Same as `torch.nn.functional.cross_entropy`, but returns 0 (instead of nan)
    for empty inputs.
    """
    if target.numel() == 0 and reduction == "mean":
        return input.sum() * 0.0  # connect the gradient
    return F.cross_entropy(input, target, reduction=reduction, **kwargs)
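
# Minimal usage sketch (illustrative values): on an empty batch the wrapper
# returns a zero that still participates in autograd, where plain
# F.cross_entropy would produce NaN under reduction="mean":
#
#   logits = torch.zeros(0, 10, requires_grad=True)
#   labels = torch.zeros(0, dtype=torch.long)
#   cross_entropy(logits, labels)    # tensor(0., grad_fn=<MulBackward0>)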


class _NewEmptyTensorOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return _NewEmptyTensorOp.apply(grad, shape), None
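
# Minimal usage sketch (illustrative values): produce an empty tensor of a new
# shape while keeping the input in the autograd graph; the backward pass
# returns an (empty) gradient with the input's original shape:
#
#   x = torch.rand(0, 3, requires_grad=True)
#   y = _NewEmptyTensorOp.apply(x, [0, 4])    # shape (0, 4), grad_fn set
#   y.sum().backward()                        # x.grad has shape (0, 3)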


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these code paths in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed to export a module to torchscript were added in PyTorch 1.6
        #    or later, and `Conv2d` in those versions already supports empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
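
# Minimal usage sketch (illustrative values): attach a norm and an activation,
# applied in that order after the convolution:
#
#   conv = Conv2d(3, 8, kernel_size=3, padding=1,
#                 norm=torch.nn.BatchNorm2d(8), activation=F.relu)
#   conv(torch.rand(2, 3, 32, 32)).shape    # torch.Size([2, 8, 32, 32])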


ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear


def nonzero_tuple(x):
    """
    An `as_tuple=True` version of `torch.nonzero` that supports torchscript,
    working around https://github.com/pytorch/pytorch/issues/38718.
    """
    if torch.jit.is_scripting():
        if x.dim() == 0:
            return x.unsqueeze(0).nonzero().unbind(1)
        return x.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)
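
# Minimal usage sketch (illustrative values): returns one index tensor per
# dimension, like torch.nonzero(..., as_tuple=True):
#
#   mask = torch.tensor([1, 0, 2])
#   (idx,) = nonzero_tuple(mask)    # idx == tensor([0, 2])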