mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
support converting pytorch to onnx (#429)
* support convert nms to onnx op * add comment for NMSop * add corresponding onnx support for mmdet/mmseg/mmediting * remove onnx part from mmcv/__init__
This commit is contained in:
parent
07e4215286
commit
34ee69e0f6
3
mmcv/onnx/__init__.py
Normal file
3
mmcv/onnx/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .symbolic import register_extra_symbolics
|
||||
|
||||
__all__ = ['register_extra_symbolics']
|
312
mmcv/onnx/onnx_utils/symbolic_helper.py
Normal file
312
mmcv/onnx/onnx_utils/symbolic_helper.py
Normal file
@ -0,0 +1,312 @@
|
||||
"""Modified from https://github.com/pytorch/pytorch."""
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from sys import maxsize
|
||||
|
||||
import torch
|
||||
import torch.onnx
|
||||
# This import monkey-patches graph manipulation methods on Graph, used for the
|
||||
# ONNX symbolics
|
||||
import torch.onnx.utils
|
||||
from torch._C import ListType
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
# Helper functions
|
||||
# ---------------------------------------------------------------------------------
|
||||
|
||||
# Save some builtins as locals, because we'll shadown them below
|
||||
_sum = sum
|
||||
|
||||
|
||||
def _parse_arg(value, desc):
|
||||
if desc == 'none':
|
||||
return value
|
||||
if desc == 'v' or not _is_value(value):
|
||||
return value
|
||||
if value.node().mustBeNone():
|
||||
return None
|
||||
if value.node().kind() == 'onnx::Constant':
|
||||
tval = value.node()['value']
|
||||
if desc == 'i':
|
||||
return int(tval)
|
||||
elif desc == 'f':
|
||||
return float(tval)
|
||||
elif desc == 'b':
|
||||
return bool(tval)
|
||||
elif desc == 's':
|
||||
return str(tval)
|
||||
elif desc == 't':
|
||||
return tval
|
||||
elif desc == 'is':
|
||||
return [int(v) for v in tval]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"ONNX symbolic doesn't know to interpret Constant node")
|
||||
elif value.node().kind() == 'prim::ListConstruct':
|
||||
if desc == 'is':
|
||||
for v in value.node().inputs():
|
||||
if v.node().kind() != 'onnx::Constant':
|
||||
raise RuntimeError(
|
||||
"Failed to export an ONNX attribute '" +
|
||||
v.node().kind() +
|
||||
"', since it's not constant, please try to make "
|
||||
'things (e.g., kernel size) static if possible')
|
||||
return [int(v.node()['value']) for v in value.node().inputs()]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"ONNX symbolic doesn't know to interpret ListConstruct node")
|
||||
|
||||
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
|
||||
|
||||
|
||||
def _maybe_get_const(value, desc):
|
||||
if _is_value(value) and value.node().kind() == 'onnx::Constant':
|
||||
return _parse_arg(value, desc)
|
||||
return value
|
||||
|
||||
|
||||
def _maybe_get_scalar(value):
|
||||
value_t = _maybe_get_const(value, 't')
|
||||
if isinstance(value_t, torch.Tensor) and value_t.shape == ():
|
||||
return value_t
|
||||
return value
|
||||
|
||||
|
||||
def _get_const(value, desc, arg_name):
|
||||
if _is_value(value) and value.node().kind() not in ('onnx::Constant',
|
||||
'prim::Constant'):
|
||||
raise RuntimeError('ONNX symbolic expected a constant'
|
||||
' value of the {} argument, got `{}`'.format(
|
||||
arg_name, value))
|
||||
return _parse_arg(value, desc)
|
||||
|
||||
|
||||
def _unpack_list(list_value):
|
||||
list_node = list_value.node()
|
||||
assert list_node.kind() == 'prim::ListConstruct'
|
||||
return list(list_node.inputs())
|
||||
|
||||
|
||||
# Check if list_value is output from prim::ListConstruct
|
||||
# This is usually called before _unpack_list to ensure the list can be
|
||||
# unpacked.
|
||||
def _is_packed_list(list_value):
|
||||
return _is_value(
|
||||
list_value) and list_value.node().kind() == 'prim::ListConstruct'
|
||||
|
||||
|
||||
def parse_args(*arg_descriptors):
|
||||
|
||||
def decorator(fn):
|
||||
fn._arg_descriptors = arg_descriptors
|
||||
|
||||
def wrapper(g, *args):
|
||||
# some args may be optional, so the length may be smaller
|
||||
assert len(arg_descriptors) >= len(args)
|
||||
args = [
|
||||
_parse_arg(arg, arg_desc)
|
||||
for arg, arg_desc in zip(args, arg_descriptors)
|
||||
]
|
||||
return fn(g, *args)
|
||||
|
||||
# In Python 2 functools.wraps chokes on partially applied functions, so
|
||||
# we need this as a workaround
|
||||
try:
|
||||
wrapper = wraps(fn)(wrapper)
|
||||
except Exception:
|
||||
pass
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _scalar(x):
|
||||
"""Convert a scalar tensor into a Python value."""
|
||||
assert x.numel() == 1
|
||||
return x.item()
|
||||
|
||||
|
||||
def _if_scalar_type_as(g, self, tensor):
|
||||
"""Convert self into the same type of tensor, as necessary."""
|
||||
if isinstance(self, torch._C.Value):
|
||||
return self
|
||||
|
||||
scalar_type = tensor.type().scalarType()
|
||||
if scalar_type:
|
||||
ty = scalar_type.lower()
|
||||
return getattr(self, ty)()
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _is_none(x):
|
||||
return x.node().mustBeNone()
|
||||
|
||||
|
||||
def _is_value(x):
|
||||
return isinstance(x, torch._C.Value)
|
||||
|
||||
|
||||
def _is_tensor_list(x):
|
||||
return x.type().isSubtypeOf(ListType.ofTensors())
|
||||
|
||||
|
||||
def _unimplemented(op, msg):
|
||||
warnings.warn('ONNX export failed on ' + op + ' because ' + msg +
|
||||
' not supported')
|
||||
|
||||
|
||||
def _try_get_scalar_type(*args):
|
||||
for arg in args:
|
||||
try:
|
||||
return arg.type().scalarType()
|
||||
except RuntimeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
|
||||
if out is not None:
|
||||
_unimplemented('TopK', 'Out parameter is not supported')
|
||||
if not _is_value(k):
|
||||
k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
|
||||
else:
|
||||
k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
|
||||
return g.op(
|
||||
'TopK',
|
||||
input,
|
||||
k,
|
||||
axis_i=dim,
|
||||
largest_i=largest,
|
||||
sorted_i=sorted,
|
||||
outputs=2)
|
||||
|
||||
|
||||
def _slice_helper(g,
|
||||
input,
|
||||
axes,
|
||||
starts,
|
||||
ends,
|
||||
steps=None,
|
||||
dynamic_slice=False):
|
||||
# TODO(ruobing): add support for opset<10
|
||||
from torch.onnx.symbolic_opset10 import _slice
|
||||
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
|
||||
|
||||
|
||||
def _unsqueeze_helper(g, input, dim):
|
||||
from torch.onnx.symbolic_opset9 import unsqueeze
|
||||
return unsqueeze(g, input, dim)
|
||||
|
||||
|
||||
def _interpolate_size_to_scales(g, input, output_size, dim):
|
||||
output_size = _maybe_get_const(output_size, 'is')
|
||||
if _is_value(output_size):
|
||||
offset = 2
|
||||
offsets = g.op(
|
||||
'Constant', value_t=torch.ones(offset, dtype=torch.float32))
|
||||
dividend = g.op(
|
||||
'Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
|
||||
divisor = _slice_helper(
|
||||
g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[offset])
|
||||
divisor = g.op('Cast', divisor, to_i=cast_pytorch_to_onnx['Float'])
|
||||
scale_dims = g.op('Div', dividend, divisor)
|
||||
scales = g.op('Concat', offsets, scale_dims, axis_i=0)
|
||||
else:
|
||||
scales_constant = [
|
||||
1. if i < 2 else float(output_size[-(dim - i)]) /
|
||||
float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
|
||||
]
|
||||
scales = g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor(scales_constant, dtype=torch.float32))
|
||||
return scales
|
||||
|
||||
|
||||
def _interpolate_get_scales_if_available(g, scales):
|
||||
if len(scales) == 0:
|
||||
return None
|
||||
available_scales = _maybe_get_const(scales[0], 'f') != -1 and not _is_none(
|
||||
scales[0])
|
||||
|
||||
if not available_scales:
|
||||
return None
|
||||
|
||||
scales_list = []
|
||||
for scale in scales:
|
||||
unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
|
||||
# ONNX only supports float for the scales. double -> float.
|
||||
unsqueezed_scale = g.op(
|
||||
'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
|
||||
scales_list.append(unsqueezed_scale)
|
||||
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
|
||||
scales = g.op('Concat', offsets, *scales_list, axis_i=0)
|
||||
return scales
|
||||
|
||||
|
||||
def _get_interpolate_attributes(g, mode, args):
|
||||
if mode == 'nearest':
|
||||
align_corners = None
|
||||
scales = args[0:]
|
||||
else:
|
||||
align_corners = args[0]
|
||||
scales = args[1:]
|
||||
scales = _interpolate_get_scales_if_available(g, scales)
|
||||
return scales, align_corners
|
||||
|
||||
|
||||
def _interpolate_get_scales(g, scale_factor, dim):
|
||||
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
|
||||
if isinstance(scale_factor.type(), torch._C.ListType):
|
||||
return g.op('Concat', offsets, scale_factor, axis_i=0)
|
||||
else:
|
||||
scale_factor = _unsqueeze_helper(g, scale_factor, 0)
|
||||
scale_factor = g.op(
|
||||
'Cast', scale_factor, to_i=cast_pytorch_to_onnx['Float'])
|
||||
scales = [scale_factor for i in range(dim - 2)]
|
||||
scale_factor = g.op('Concat', offsets, *scales, axis_i=0)
|
||||
return scale_factor
|
||||
|
||||
|
||||
def _size_helper(g, self, dim):
|
||||
full_shape = g.op('Shape', self)
|
||||
from torch.onnx.symbolic_opset9 import select
|
||||
return select(g, full_shape, g.op('Constant', value_t=torch.tensor([0])),
|
||||
dim)
|
||||
|
||||
|
||||
def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override,
|
||||
name):
|
||||
if divisor_override and divisor_override.node().kind() != 'prim::Constant':
|
||||
return _unimplemented(name, 'divisor_override')
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
padding = tuple(tuple_fn(padding))
|
||||
return padding
|
||||
|
||||
|
||||
# Metaprogram symbolics for each ATen native specialized cast operator.
|
||||
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
|
||||
# ONNX cast node with `to` attribute 'UINT8'
|
||||
#
|
||||
# TODO: remove these once we support Type's in the JIT IR and we can once again
|
||||
# use the unified toType operator
|
||||
cast_pytorch_to_onnx = {
|
||||
'Byte': torch.onnx.TensorProtoDataType.UINT8,
|
||||
'Char': torch.onnx.TensorProtoDataType.INT8,
|
||||
'Double': torch.onnx.TensorProtoDataType.DOUBLE,
|
||||
'Float': torch.onnx.TensorProtoDataType.FLOAT,
|
||||
'Half': torch.onnx.TensorProtoDataType.FLOAT16,
|
||||
'Int': torch.onnx.TensorProtoDataType.INT32,
|
||||
'Long': torch.onnx.TensorProtoDataType.INT64,
|
||||
'Short': torch.onnx.TensorProtoDataType.INT16,
|
||||
'Bool': torch.onnx.TensorProtoDataType.BOOL,
|
||||
'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
|
||||
'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
|
||||
'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
|
||||
}
|
||||
|
||||
# Global set to store the list of quantized operators in the network.
|
||||
# This is currently only used in the conversion of quantized ops from PT
|
||||
# -> C2 via ONNX.
|
||||
_quantized_ops = set()
|
327
mmcv/onnx/symbolic.py
Normal file
327
mmcv/onnx/symbolic.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Modified from https://github.com/pytorch/pytorch."""
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.modules.utils import _pair, _single, _triple
|
||||
from torch.onnx.symbolic_helper import parse_args
|
||||
from torch.onnx.symbolic_registry import register_op
|
||||
|
||||
from .onnx_utils import symbolic_helper as sym_help
|
||||
|
||||
|
||||
def _interpolate(name, dim, interpolate_mode):
|
||||
|
||||
def symbolic_fn(g, input, output_size, *args):
|
||||
scales, align_corners = sym_help._get_interpolate_attributes(
|
||||
g, interpolate_mode, args)
|
||||
align_corners = sym_help._maybe_get_scalar(align_corners)
|
||||
transformation_mode = 'asymmetric' \
|
||||
if interpolate_mode == 'nearest' \
|
||||
else 'align_corners' if align_corners else 'pytorch_half_pixel'
|
||||
empty_tensor = g.op(
|
||||
'Constant', value_t=torch.tensor([], dtype=torch.float32))
|
||||
|
||||
if scales is None:
|
||||
input_size = g.op('Shape', input)
|
||||
input_size_beg = sym_help._slice_helper(
|
||||
g, input_size, axes=[0], ends=[2], starts=[0])
|
||||
output_size = g.op(
|
||||
'Cast',
|
||||
output_size,
|
||||
to_i=sym_help.cast_pytorch_to_onnx['Long'])
|
||||
output_size = g.op('Concat', input_size_beg, output_size, axis_i=0)
|
||||
scales = g.op(
|
||||
'Constant', value_t=torch.tensor([], dtype=torch.float32))
|
||||
return g.op(
|
||||
'Resize',
|
||||
input,
|
||||
empty_tensor,
|
||||
# roi only takes effect whith
|
||||
# coordinate_transformation_mode="tf_crop_and_resize"
|
||||
scales, # scales is not needed since we are sending out_size
|
||||
output_size,
|
||||
coordinate_transformation_mode_s=transformation_mode,
|
||||
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
|
||||
mode_s=interpolate_mode, # nearest, linear, or cubic
|
||||
nearest_mode_s='floor') # only valid when mode="nearest"
|
||||
else:
|
||||
return g.op(
|
||||
'Resize',
|
||||
input,
|
||||
empty_tensor,
|
||||
# roi only takes effect with
|
||||
# coordinate_transformation_mode="tf_crop_and_resize"
|
||||
scales, # scales is not needed since we are sending out_size
|
||||
coordinate_transformation_mode_s=transformation_mode,
|
||||
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
|
||||
mode_s=interpolate_mode, # nearest, linear, or cubic
|
||||
nearest_mode_s='floor') # only valid when mode="nearest"
|
||||
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, 'nearest')
|
||||
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, 'nearest')
|
||||
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, 'nearest')
|
||||
upsample_linear1d = _interpolate('upsample_linear1d', 3, 'linear')
|
||||
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, 'linear')
|
||||
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, 'linear')
|
||||
upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, 'cubic')
|
||||
|
||||
|
||||
@parse_args('v', 'v', 'i', 'i', 'i', 'none')
|
||||
def topk(g, self, k, dim, largest, sorted, out=None):
|
||||
return sym_help._topk_helper(
|
||||
g, self, k, dim, largest=largest, sorted=sorted, out=out)
|
||||
|
||||
|
||||
def masked_select(g, self, mask):
|
||||
from torch.onnx.symbolic_opset9 import nonzero, expand_as
|
||||
index = nonzero(g, expand_as(g, mask, self))
|
||||
return g.op('GatherND', self, index)
|
||||
|
||||
|
||||
def _prepare_onnx_paddings(g, dim, pad):
|
||||
pad_len = torch.onnx.symbolic_opset9.size(
|
||||
g, pad, g.op('Constant', value_t=torch.tensor([0])))
|
||||
# Set extension = [0] * (dim * 2 - len(pad))
|
||||
extension = g.op(
|
||||
'Sub',
|
||||
g.op('Mul',
|
||||
g.op('Constant', value_t=torch.tensor(dim, dtype=torch.int64)),
|
||||
g.op('Constant', value_t=torch.tensor(2, dtype=torch.int64))),
|
||||
pad_len)
|
||||
pad = g.op('Cast', pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
|
||||
paddings = g.op(
|
||||
'Concat',
|
||||
pad,
|
||||
g.op(
|
||||
'ConstantOfShape',
|
||||
extension,
|
||||
value_t=torch.tensor([0], dtype=torch.int64)),
|
||||
axis_i=0)
|
||||
paddings = g.op('Reshape', paddings,
|
||||
g.op('Constant', value_t=torch.tensor([-1, 2])))
|
||||
paddings = g.op(
|
||||
'Transpose',
|
||||
torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
|
||||
perm_i=[1, 0])
|
||||
paddings = g.op('Reshape', paddings,
|
||||
g.op('Constant', value_t=torch.tensor([-1])))
|
||||
padding_c = g.op(
|
||||
'Cast', paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
|
||||
return padding_c
|
||||
|
||||
|
||||
def constant_pad_nd(g, input, padding, value=None):
|
||||
mode = 'constant'
|
||||
value = sym_help._maybe_get_scalar(value)
|
||||
value = sym_help._if_scalar_type_as(g, value, input)
|
||||
pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
|
||||
return g.op('Pad', input, pad, value, mode_s=mode)
|
||||
|
||||
|
||||
def reflection_pad(g, input, padding):
|
||||
mode = 'reflect'
|
||||
paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
|
||||
return g.op('Pad', input, paddings, mode_s=mode)
|
||||
|
||||
|
||||
reflection_pad1d = reflection_pad
|
||||
reflection_pad2d = reflection_pad
|
||||
reflection_pad3d = reflection_pad
|
||||
|
||||
|
||||
def _avg_pool(name, tuple_fn):
|
||||
|
||||
@parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
|
||||
def symbolic_fn(g,
|
||||
input,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override=None):
|
||||
padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size,
|
||||
stride, divisor_override, name)
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
if count_include_pad:
|
||||
input = g.op(
|
||||
'Pad',
|
||||
input,
|
||||
g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor(((0, ) * 2 + padding) * 2)),
|
||||
mode_s='constant')
|
||||
padding = (0, ) * len(padding)
|
||||
output = g.op(
|
||||
'AveragePool',
|
||||
input,
|
||||
kernel_shape_i=tuple_fn(kernel_size),
|
||||
strides_i=tuple_fn(stride),
|
||||
pads_i=padding * 2,
|
||||
ceil_mode_i=ceil_mode)
|
||||
return output
|
||||
|
||||
return symbolic_fn
|
||||
|
||||
|
||||
avg_pool1d = _avg_pool('avg_pool1d', _single)
|
||||
avg_pool2d = _avg_pool('avg_pool2d', _pair)
|
||||
avg_pool3d = _avg_pool('avg_pool3d', _triple)
|
||||
|
||||
|
||||
def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d,
|
||||
padding_d, stride_d):
|
||||
# Input is always 4-D (N, C, H, W)
|
||||
# Calculate indices of sliding blocks along spatial dimension
|
||||
# Slide kernel over input each dim d:
|
||||
# each dimension d ranges from 0 to
|
||||
# input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
|
||||
# with steps = stride
|
||||
|
||||
blocks_d = g.op('Add', input_d,
|
||||
g.op('Constant', value_t=torch.tensor(padding_d * 2)))
|
||||
blocks_d = g.op(
|
||||
'Sub', blocks_d,
|
||||
g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor(dilation_d * (kernel_size_d - 1))))
|
||||
|
||||
# Stride kernel over input and find starting indices along dim d
|
||||
blocks_d_indices = g.op('Range', g.op('Constant', value_t=torch.tensor(0)),
|
||||
blocks_d,
|
||||
g.op('Constant', value_t=torch.tensor(stride_d)))
|
||||
|
||||
# Apply dilation on kernel and find its indices along dim d
|
||||
kernel_grid = np.arange(0, kernel_size_d * dilation_d, dilation_d)
|
||||
kernel_grid = g.op('Constant', value_t=torch.tensor([kernel_grid]))
|
||||
|
||||
# Broadcast and add kernel staring positions (indices) with
|
||||
# kernel_grid along dim d, to get block indices along dim d
|
||||
blocks_d_indices = g.op(
|
||||
'Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1]
|
||||
kernel_mask = g.op('Reshape', kernel_grid,
|
||||
g.op('Constant', value_t=torch.tensor([-1, 1])))
|
||||
block_mask = g.op('Add', blocks_d_indices, kernel_mask)
|
||||
|
||||
return block_mask
|
||||
|
||||
|
||||
def _get_im2col_padded_input(g, input, padding_h, padding_w):
|
||||
# Input is always 4-D tensor (N, C, H, W)
|
||||
# Padding tensor has the following format: (padding_h, padding_w)
|
||||
# Reshape the padding to follow ONNX format:
|
||||
# (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
|
||||
pad = g.op(
|
||||
'Constant', value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
|
||||
return g.op('Pad', input, pad)
|
||||
|
||||
|
||||
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
|
||||
batch_dim = size(g, input, g.op('Constant', value_t=torch.tensor(0)))
|
||||
channel_dim = size(g, input, g.op('Constant', value_t=torch.tensor(1)))
|
||||
channel_unfolded = g.op(
|
||||
'Mul', channel_dim,
|
||||
g.op('Constant', value_t=torch.tensor(kernel_h * kernel_w)))
|
||||
|
||||
return g.op(
|
||||
'Concat',
|
||||
g.op('Unsqueeze', batch_dim, axes_i=[0]),
|
||||
g.op('Unsqueeze', channel_unfolded, axes_i=[0]),
|
||||
g.op('Constant', value_t=torch.tensor([-1])),
|
||||
axis_i=0)
|
||||
|
||||
|
||||
def size(g, self, dim=None):
|
||||
if dim is None:
|
||||
return g.op('Shape', self)
|
||||
return sym_help._size_helper(g, self, dim)
|
||||
|
||||
|
||||
@parse_args('v', 'is', 'is', 'is', 'is')
|
||||
def im2col(g, input, kernel_size, dilation, padding, stride):
|
||||
# Input is always 4-D tensor (N, C, H, W)
|
||||
# All other args are int[2]
|
||||
|
||||
input_h = size(g, input, g.op('Constant', value_t=torch.tensor(2)))
|
||||
input_w = size(g, input, g.op('Constant', value_t=torch.tensor(3)))
|
||||
|
||||
stride_h, stride_w = stride[0], stride[1]
|
||||
padding_h, padding_w = padding[0], padding[1]
|
||||
dilation_h, dilation_w = dilation[0], dilation[1]
|
||||
kernel_h, kernel_w = kernel_size[0], kernel_size[1]
|
||||
|
||||
blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h,
|
||||
dilation_h, padding_h,
|
||||
stride_h)
|
||||
blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w,
|
||||
dilation_w, padding_w,
|
||||
stride_w)
|
||||
|
||||
output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
|
||||
padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
|
||||
|
||||
output = g.op('Gather', padded_input, blocks_row_indices, axis_i=2)
|
||||
output = g.op('Gather', output, blocks_col_indices, axis_i=4)
|
||||
output = g.op('Transpose', output, perm_i=[0, 1, 2, 4, 3, 5])
|
||||
return g.op('Reshape', output, output_shape)
|
||||
|
||||
|
||||
@parse_args('v', 'i')
|
||||
def one_hot(g, self, num_classes):
|
||||
values = g.op('Constant', value_t=torch.LongTensor([0, 1]))
|
||||
depth = g.op('Constant', value_t=torch.LongTensor([num_classes]))
|
||||
return g.op('OneHot', self, depth, values, axis_i=-1)
|
||||
|
||||
|
||||
@parse_args('v', 'i', 'none')
|
||||
def softmax(g, input, dim, dtype=None):
|
||||
input_dim = input.type().dim()
|
||||
if input_dim:
|
||||
# TODO: remove this as onnx opset 11 spec allows negative axes
|
||||
if dim < 0:
|
||||
dim = input_dim + dim
|
||||
if input_dim == dim + 1:
|
||||
softmax = g.op('Softmax', input, axis_i=dim)
|
||||
if dtype and dtype.node().kind() != 'prim::Constant':
|
||||
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
|
||||
softmax = g.op(
|
||||
'Cast',
|
||||
softmax,
|
||||
to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
|
||||
return softmax
|
||||
|
||||
max_value = g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1)
|
||||
input = g.op('Sub', input, max_value)
|
||||
exp = g.op('Exp', input)
|
||||
sum = g.op('ReduceSum', exp, axes_i=[dim])
|
||||
softmax = g.op('Div', exp, sum)
|
||||
if dtype and dtype.node().kind() != 'prim::Constant':
|
||||
parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
|
||||
softmax = g.op(
|
||||
'Cast', softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
|
||||
return softmax
|
||||
|
||||
|
||||
def register_extra_symbolics(opset=11):
|
||||
register_op('one_hot', one_hot, '', opset)
|
||||
register_op('im2col', im2col, '', opset)
|
||||
register_op('topk', topk, '', opset)
|
||||
register_op('softmax', softmax, '', opset)
|
||||
register_op('constant_pad_nd', constant_pad_nd, '', opset)
|
||||
register_op('reflection_pad1d', reflection_pad1d, '', opset)
|
||||
register_op('reflection_pad2d', reflection_pad2d, '', opset)
|
||||
register_op('reflection_pad3d', reflection_pad3d, '', opset)
|
||||
register_op('avg_pool1d', avg_pool1d, '', opset)
|
||||
register_op('avg_pool2d', avg_pool2d, '', opset)
|
||||
register_op('avg_pool3d', avg_pool3d, '', opset)
|
||||
register_op('masked_select', masked_select, '', opset)
|
||||
register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
|
||||
register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
|
||||
register_op('upsample_nearest3d', upsample_nearest3d, '', opset)
|
||||
register_op('upsample_linear1d', upsample_linear1d, '', opset)
|
||||
register_op('upsample_bilinear2d', upsample_bilinear2d, '', opset)
|
||||
register_op('upsample_trilinear3d', upsample_trilinear3d, '', opset)
|
||||
register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
|
@ -1,5 +1,8 @@
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
|
||||
|
||||
from mmcv.utils import deprecated_api_warning
|
||||
from ..utils import ext_loader
|
||||
@ -7,6 +10,34 @@ from ..utils import ext_loader
|
||||
ext_module = ext_loader.load_ext('_ext', ['nms', 'softnms', 'nms_match'])
|
||||
|
||||
|
||||
# This function is modified from: https://github.com/pytorch/vision/
|
||||
class NMSop(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, bboxes, scores, iou_threshold, offset):
|
||||
inds = ext_module.nms(
|
||||
bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
|
||||
return inds
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, bboxes, scores, iou_threshold, offset):
|
||||
boxes = unsqueeze(g, bboxes, 0)
|
||||
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
|
||||
max_output_per_class = g.op(
|
||||
'Constant', value_t=torch.tensor([sys.maxsize], dtype=torch.long))
|
||||
iou_threshold = g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor([iou_threshold], dtype=torch.float))
|
||||
nms_out = g.op('NonMaxSuppression', boxes, scores,
|
||||
max_output_per_class, iou_threshold)
|
||||
return squeeze(
|
||||
g,
|
||||
select(
|
||||
g, nms_out, 1,
|
||||
g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))),
|
||||
1)
|
||||
|
||||
|
||||
@deprecated_api_warning({'iou_thr': 'iou_threshold'})
|
||||
def nms(boxes, scores, iou_threshold, offset=0):
|
||||
"""Dispatch to either CPU or GPU NMS implementations.
|
||||
@ -75,11 +106,10 @@ def nms(boxes, scores, iou_threshold, offset=0):
|
||||
select = ext_module.nms(*indata_list, **indata_dict)
|
||||
inds = order.masked_select(select)
|
||||
else:
|
||||
inds = ext_module.nms(
|
||||
boxes,
|
||||
scores,
|
||||
iou_threshold=float(iou_threshold),
|
||||
offset=int(offset))
|
||||
if torch.onnx.is_in_onnx_export() and offset == 0:
|
||||
# ONNX only support offset == 1
|
||||
boxes[:, -2:] -= 1
|
||||
inds = NMSop.apply(boxes, scores, iou_threshold, offset)
|
||||
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
|
||||
if is_numpy:
|
||||
dets = dets.cpu().numpy()
|
||||
|
Loading…
x
Reference in New Issue
Block a user