mmcv/mmcv/onnx/onnx_utils/symbolic_helper.py

326 lines
11 KiB
Python

"""Modified from https://github.com/pytorch/pytorch."""
import warnings
from functools import wraps
from sys import maxsize
import torch
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from torch._C import ListType
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
# Save some builtins as locals, because we'll shadow them below
_sum = sum
def _parse_arg(value, desc):
    """Convert a symbolic argument into the Python form named by ``desc``.

    Args:
        value: Either a ``torch._C.Value`` or an already-plain Python object.
        desc (str): Descriptor of the wanted form. ``'none'`` and ``'v'``
            hand ``value`` back untouched; ``'i'``, ``'f'``, ``'b'`` and
            ``'s'`` convert a constant with ``int`` / ``float`` / ``bool`` /
            ``str``; ``'t'`` returns the constant's payload as is; ``'is'``
            and ``'fs'`` build a list of ints / floats.

    Returns:
        ``value`` itself when nothing has to be parsed, ``None`` when the
        producing node reports ``mustBeNone()``, otherwise the converted
        constant.

    Raises:
        RuntimeError: If ``desc`` is not understood for the node kind, if an
            element of a ``prim::ListConstruct`` is not an
            ``onnx::Constant``, or if the node kind is neither
            ``onnx::Constant`` nor ``prim::ListConstruct``.
    """
    # Nothing to parse: the caller wants the raw object, or it is not a Value.
    if desc == 'none' or desc == 'v' or not _is_value(value):
        return value

    node = value.node()
    if node.mustBeNone():
        return None

    kind = node.kind()
    if kind == 'onnx::Constant':
        payload = node['value']
        # Scalar descriptors map one-to-one onto a builtin conversion.
        for code, convert in (('i', int), ('f', float), ('b', bool),
                              ('s', str)):
            if desc == code:
                return convert(payload)
        if desc == 't':
            return payload
        if desc == 'is':
            return [int(item) for item in payload]
        if desc == 'fs':
            return [float(item) for item in payload]
        raise RuntimeError(
            "ONNX symbolic doesn't know to interpret Constant node")

    if kind == 'prim::ListConstruct':
        if desc != 'is':
            raise RuntimeError(
                "ONNX symbolic doesn't know to interpret ListConstruct node")
        elements = list(node.inputs())
        # Every element must be constant before any of them is converted.
        for element in elements:
            element_kind = element.node().kind()
            if element_kind != 'onnx::Constant':
                raise RuntimeError(
                    "Failed to export an ONNX attribute '" + element_kind +
                    "', since it's not constant, please try to make "
                    'things (e.g., kernel size) static if possible')
        return [int(element.node()['value']) for element in elements]

    raise RuntimeError('Unexpected node type: {}'.format(kind))
def _maybe_get_const(value, desc):
    """Parse ``value`` per ``desc`` only if it is an ``onnx::Constant``.

    Any other input (a non-Value object, or a Value produced by a different
    node kind) is returned unchanged.
    """
    is_constant = (
        _is_value(value) and value.node().kind() == 'onnx::Constant')
    return _parse_arg(value, desc) if is_constant else value
def _maybe_get_scalar(value):
    """Return the constant tensor behind ``value`` when its shape is ``()``.

    In every other case ``value`` is handed back as received.
    """
    candidate = _maybe_get_const(value, 't')
    is_scalar_tensor = (
        isinstance(candidate, torch.Tensor) and candidate.shape == ())
    return candidate if is_scalar_tensor else value
def _get_const(value, desc, arg_name):
    """Parse ``value`` per ``desc``, insisting that a Value be constant.

    Args:
        value: A ``torch._C.Value`` or a plain Python object.
        desc (str): Descriptor forwarded to ``_parse_arg``.
        arg_name (str): Argument name quoted in the error message.

    Raises:
        RuntimeError: If ``value`` is a Value whose node kind is neither
            ``onnx::Constant`` nor ``prim::Constant``.
    """
    constant_kinds = ('onnx::Constant', 'prim::Constant')
    if _is_value(value):
        if value.node().kind() not in constant_kinds:
            template = ('ONNX symbolic expected a constant'
                        ' value of the {} argument, got `{}`')
            raise RuntimeError(template.format(arg_name, value))
    return _parse_arg(value, desc)
def _unpack_list(list_value):
    """Return the inputs of the ``prim::ListConstruct`` behind ``list_value``.

    The producing node is asserted to be of kind ``prim::ListConstruct``.
    """
    producer = list_value.node()
    assert producer.kind() == 'prim::ListConstruct'
    elements = list(producer.inputs())
    return elements
def _is_packed_list(list_value):
    """Check if ``list_value`` is the output of ``prim::ListConstruct``.

    This is usually called before ``_unpack_list`` to ensure the list can be
    unpacked.
    """
    if not _is_value(list_value):
        return False
    return list_value.node().kind() == 'prim::ListConstruct'
def parse_args(*arg_descriptors):
    """Build a decorator that parses a symbolic function's arguments.

    The decorated function is called as ``fn(g, *parsed)`` where every
    positional argument after ``g`` has been run through ``_parse_arg`` with
    the descriptor found at the same position in ``arg_descriptors``. The
    descriptors are also stored on the function as ``_arg_descriptors``.

    Args:
        *arg_descriptors (str): One descriptor per positional argument.

    Returns:
        callable: The decorator.
    """

    def decorator(fn):
        fn._arg_descriptors = arg_descriptors

        def wrapper(g, *args):
            # Trailing arguments may be optional, so receiving fewer
            # arguments than descriptors is fine; more is not.
            assert len(arg_descriptors) >= len(args)
            return fn(g, *(_parse_arg(argument, descriptor)
                           for argument, descriptor in zip(
                               args, arg_descriptors)))

        # In Python 2 functools.wraps chokes on partially applied functions,
        # so fall back to the undecorated wrapper when it fails.
        try:
            decorated = wraps(fn)(wrapper)
        except Exception:
            decorated = wrapper
        return decorated

    return decorator
def _scalar(x):
    """Turn a single-element tensor into a plain Python value.

    The tensor is asserted to hold exactly one element.
    """
    element_count = x.numel()
    assert element_count == 1
    return x.item()
def _if_scalar_type_as(g, self, tensor):
    """Cast ``self`` to the scalar type of ``tensor`` when that is possible.

    ``self`` is returned untouched if it is a ``torch._C.Value`` or if
    ``tensor`` reports no scalar type. Otherwise the method on ``self`` named
    after the lower-cased scalar type is called and its result returned.
    """
    if not isinstance(self, torch._C.Value):
        type_name = tensor.type().scalarType()
        if type_name:
            # e.g. a scalar type of 'Float' resolves to ``self.float()``.
            convert = getattr(self, type_name.lower())
            return convert()
    return self
def _is_none(x):
    """Report whether the node producing ``x`` says it must be ``None``."""
    producer = x.node()
    return producer.mustBeNone()
def _is_value(x):
    """Tell whether ``x`` is an instance of ``torch._C.Value``."""
    value_type = torch._C.Value
    return isinstance(x, value_type)
def _is_tensor_list(x):
    """Tell whether the type of ``x`` is a subtype of a list of tensors."""
    tensor_list_type = ListType.ofTensors()
    return x.type().isSubtypeOf(tensor_list_type)
def _unimplemented(op, msg):
    """Warn that exporting ``op`` failed because ``msg`` is not supported.

    Only a warning is emitted; nothing is raised and ``None`` is returned.
    """
    pieces = ('ONNX export failed on ', op, ' because ', msg,
              ' not supported')
    warnings.warn(''.join(pieces))
def _try_get_scalar_type(*args):
    """Return the scalar type of the first argument able to provide one.

    Arguments whose ``type().scalarType()`` lookup raises ``RuntimeError``
    are skipped. ``None`` is returned when every lookup fails or when no
    argument is given.
    """
    for candidate in args:
        try:
            scalar_type = candidate.type().scalarType()
        except RuntimeError:
            continue
        return scalar_type
    return None
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
    """Emit an ONNX ``TopK`` op with two outputs.

    Args:
        g: Graph object whose ``op`` method creates ONNX nodes.
        input: Value to take the top-k of.
        k: Number of elements. A non-Value ``k`` is wrapped in an int64
            ``Constant`` holding ``[k]``; a Value ``k`` is reshaped with a
            ``Reshape`` op whose target shape is ``[1]``.
        dim: Passed as the ``axis`` attribute.
        largest: Passed as the ``largest`` attribute.
        sorted: Passed as the ``sorted`` attribute.
        out: Not supported; a warning is emitted when it is not ``None`` and
            the op is still created without it.

    Returns:
        The result of ``g.op('TopK', ..., outputs=2)``.
    """
    if out is not None:
        # ``_unimplemented`` appends ' not supported' itself, so only the
        # feature name is passed to avoid a doubled phrase in the warning.
        _unimplemented('TopK', 'Out parameter')
    if not _is_value(k):
        k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
    else:
        k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
    return g.op(
        'TopK',
        input,
        k,
        axis_i=dim,
        largest_i=largest,
        sorted_i=sorted,
        outputs=2)
def _slice_helper(g,
                  input,
                  axes,
                  starts,
                  ends,
                  steps=None,
                  dynamic_slice=False):
    """Slice ``input`` by delegating to the opset 10 ``_slice`` symbolic.

    All arguments are forwarded positionally, in the order they are declared
    here, to ``torch.onnx.symbolic_opset10._slice``.
    """
    # TODO(ruobing): add support for opset<10
    from torch.onnx.symbolic_opset10 import _slice as opset10_slice
    sliced = opset10_slice(g, input, axes, starts, ends, steps,
                           dynamic_slice)
    return sliced
def _unsqueeze_helper(g, input, dim):
    """Unsqueeze ``input`` at ``dim`` via the opset 9 ``unsqueeze`` symbolic."""
    from torch.onnx.symbolic_opset9 import unsqueeze as opset9_unsqueeze
    unsqueezed = opset9_unsqueeze(g, input, dim)
    return unsqueezed
def _interpolate_size_to_scales(g, input, output_size, dim):
    """Build the ``scales`` value for a resize from a requested output size.

    Args:
        g: Graph object whose ``op`` method creates ONNX nodes.
        input: Value being resized.
        output_size: Requested size; parsed as an int list when constant.
        dim (int): Number of entries in the resulting scales; only used when
            ``output_size`` is constant.

    Returns:
        A ``Constant`` op when ``output_size`` is constant (the first two
        scales are ``1.``, the others are output size divided by input size),
        otherwise a ``Concat`` of two ones with the element-wise ratio
        computed in the graph.
    """
    output_size = _maybe_get_const(output_size, 'is')

    if not _is_value(output_size):
        # Static path: compute the ratios in Python.
        ratios = []
        for i in range(0, dim):
            if i < 2:
                ratios.append(1.)
                continue
            position = -(dim - i)
            ratios.append(
                float(output_size[position]) /
                float(input.type().sizes()[position]))
        return g.op(
            'Constant', value_t=torch.tensor(ratios, dtype=torch.float32))

    # Dynamic path: divide the requested size by the input shape with the
    # first two entries sliced off.
    leading = 2
    ones = g.op('Constant', value_t=torch.ones(leading, dtype=torch.float32))
    requested = g.op('Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
    current = _slice_helper(
        g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[leading])
    current = g.op('Cast', current, to_i=cast_pytorch_to_onnx['Float'])
    ratio = g.op('Div', requested, current)
    return g.op('Concat', ones, ratio, axis_i=0)
def _interpolate_get_scales_if_available(g, scales):
    """Build the ``scales`` value from explicit scale arguments, if present.

    Args:
        g: Graph object whose ``op`` method creates ONNX nodes.
        scales: Sequence of scale Values; may be empty.

    Returns:
        ``None`` when ``scales`` is empty, when the first scale is the
        constant ``-1``, or when it is a None node. Otherwise a ``Concat`` of
        two ones followed by the scales.
    """
    if len(scales) == 0:
        return None

    first = scales[0]
    # scales[0] is ListType in Pytorch == 1.7.0
    is_list = first.type().kind() == 'ListType'
    scale_desc = 'fs' if is_list else 'f'
    first_const = _maybe_get_const(first, scale_desc)
    available_scales = first_const != -1 and not _is_none(first)
    if not available_scales:
        return None

    offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))

    if not is_list:
        # for PyTorch < 1.7.0
        # ONNX only supports float for the scales. double -> float.
        casted = [
            g.op(
                'Cast',
                _unsqueeze_helper(g, scale, 0),
                to_i=cast_pytorch_to_onnx['Float']) for scale in scales
        ]
        return g.op('Concat', offsets, *casted, axis_i=0)

    # modify to support PyTorch==1.7.0
    # https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
    scales_list = g.op('Constant', value_t=torch.tensor(first_const))
    return g.op('Concat', offsets, scales_list, axis_i=0)
def _get_interpolate_attributes(g, mode, args):
    """Split ``args`` into the scales value and ``align_corners``.

    For ``mode == 'nearest'`` every entry of ``args`` is treated as a scale
    and ``align_corners`` is ``None``; for any other mode ``args[0]`` is
    ``align_corners`` and the remaining entries are the scales.

    Returns:
        tuple: ``(scales, align_corners)`` where ``scales`` comes from
        ``_interpolate_get_scales_if_available``.
    """
    nearest = mode == 'nearest'
    align_corners = None if nearest else args[0]
    raw_scales = args[0:] if nearest else args[1:]
    return _interpolate_get_scales_if_available(g, raw_scales), align_corners
def _interpolate_get_scales(g, scale_factor, dim):
    """Build the ``scales`` value from a scale factor.

    Args:
        g: Graph object whose ``op`` method creates ONNX nodes.
        scale_factor: Value holding the factor; its type decides the path.
        dim (int): The factor is repeated ``dim - 2`` times when it is not
            of list type.

    Returns:
        A ``Concat`` of two ones followed by the scale factor(s).
    """
    offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))

    if not isinstance(scale_factor.type(), torch._C.ListType):
        unsqueezed = _unsqueeze_helper(g, scale_factor, 0)
        as_float = g.op(
            'Cast', unsqueezed, to_i=cast_pytorch_to_onnx['Float'])
        repeated = [as_float] * (dim - 2)
        return g.op('Concat', offsets, *repeated, axis_i=0)

    # A list-typed factor is concatenated as it is.
    return g.op('Concat', offsets, scale_factor, axis_i=0)
def _size_helper(g, self, dim):
    """Return the entry of the shape of ``self`` selected by ``dim``.

    A ``Shape`` op is created for ``self`` and the opset 9 ``select``
    symbolic is applied to it along axis 0 with index ``dim``.
    """
    shape_value = g.op('Shape', self)
    from torch.onnx.symbolic_opset9 import select as opset9_select
    axis = g.op('Constant', value_t=torch.tensor([0]))
    return opset9_select(g, shape_value, axis, dim)
def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override,
                    name):
    """Check ``divisor_override`` and return the padding as a tuple.

    Args:
        tuple_fn (callable): Called with ``padding``; its result is converted
            to a tuple.
        padding: Padding specification handed to ``tuple_fn``.
        kernel_size: Not used by this helper; kept so existing callers keep
            working.
        stride: Not used by this helper; kept so existing callers keep
            working.
        divisor_override: When truthy, its producing node must be of kind
            ``prim::Constant``.
        name (str): Operator name quoted in the warning.

    Returns:
        tuple | None: ``tuple(tuple_fn(padding))``, or ``None`` (after a
        warning) when ``divisor_override`` is truthy and not produced by a
        ``prim::Constant`` node.
    """
    if divisor_override and divisor_override.node().kind() != 'prim::Constant':
        return _unimplemented(name, 'divisor_override')
    # The former ``stride = kernel_size`` fallback was dropped: the value was
    # assigned to a local and never read, since only the padding is returned.
    return tuple(tuple_fn(padding))
# Metaprogram symbolics for each ATen native specialized cast operator.
# For example, we specify a function named `_cast_uint8_t` that instantiates
# an ONNX cast node with `to` attribute 'UINT8'.
#
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
# Maps PyTorch scalar type names to ONNX ``TensorProtoDataType`` members.
# Within this file the 'Float' entry is passed as the ``to_i`` attribute of
# ONNX ``Cast`` ops.
cast_pytorch_to_onnx = {
    'Byte': torch.onnx.TensorProtoDataType.UINT8,
    'Char': torch.onnx.TensorProtoDataType.INT8,
    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
    'Float': torch.onnx.TensorProtoDataType.FLOAT,
    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
    'Int': torch.onnx.TensorProtoDataType.INT32,
    'Long': torch.onnx.TensorProtoDataType.INT64,
    'Short': torch.onnx.TensorProtoDataType.INT16,
    'Bool': torch.onnx.TensorProtoDataType.BOOL,
    'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
    'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
    'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
}
# Global set storing the quantized operators found in the network.
# This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX.
_quantized_ops = set()