fix layer norm (#1015)
parent fbe7586415
commit ea7706cbfd
@@ -2,14 +2,64 @@
 # Modified from:
 # https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
 
 import torch
 from torch.onnx.symbolic_helper import parse_args
 
 from mmdeploy.core import SYMBOLIC_REWRITER
 from mmdeploy.utils import Backend
 
+
+@SYMBOLIC_REWRITER.register_symbolic(
+    'layer_norm',
+    is_pytorch=True,
+    arg_descriptors=['v', 'is', 'v', 'v', 'f', 'i'])
+def layer_norm__default(ctx, g, input, normalized_shape, weight, bias, eps,
+                        cudnn_enable):
+    """Symbolic function for `layer_norm`.
+
+    Layer norm with torch<=1.12 might lead to wrong output shapes. Add
+    keepdims=1 to each ReduceMean node to correct the shape.
+    """
+    import torch.onnx.symbolic_helper as sym_help
+    from torch.onnx.symbolic_opset9 import add, mul, pow, sqrt, sub
+    if sym_help._operator_export_type == \
+            torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+        return g.op(
+            'ATen',
+            input,
+            weight,
+            bias,
+            normalized_shape_i=normalized_shape,
+            eps_f=eps,
+            cudnn_enable_i=cudnn_enable,
+            operator_s='layer_norm')
+
+    axes = [-i for i in range(len(normalized_shape), 0, -1)]
+
+    two_cst = sym_help._generate_wrapped_number(g, 2.)
+    eps_cst = sym_help._generate_wrapped_number(g, eps)
+
+    mean = g.op('ReduceMean', input, axes_i=axes, keepdims_i=1)
+    numerator = sub(g, input, mean)
+    # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the
+    # layer_norm formula
+    variance = g.op(
+        'ReduceMean', pow(g, numerator, two_cst), axes_i=axes, keepdims_i=1)
+    denominator = sqrt(g, add(g, variance, eps_cst))
+
+    layer_norm = g.op('Div', numerator, denominator)
+
+    if not (weight is None or sym_help._is_none(weight)):
+        layer_norm = mul(g, layer_norm, weight)
+    if not (bias is None or sym_help._is_none(bias)):
+        layer_norm = add(g, layer_norm, bias)
+
+    return layer_norm
+
+
 @parse_args('v', 'is', 'v', 'v', 'f', 'i')
-def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
+def _layer_norm_ncnn(g, input, normalized_shape, weight, bias, eps,
+                     cudnn_enable):
     """Symbolic function for `layer_norm`.
 
     PyTorch does not support exporting layer_norm to ONNX by default. We add the
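The docstring of `layer_norm__default` captures the whole bug: without `keepdims=1`, each `ReduceMean` drops the reduced axes, and the subtraction that follows then broadcasts to the wrong shape. A minimal eager-mode sketch of the same decomposition (illustrative only, not part of this commit) makes the failure mode concrete:

```python
import torch

x = torch.rand(2, 1, 4)  # same shape as the new test below
normalized_shape = (4, )
axes = [-i for i in range(len(normalized_shape), 0, -1)]  # -> [-1]

# keepdim=True mirrors keepdims_i=1 on the ReduceMean nodes above.
mean = x.mean(dim=axes, keepdim=True)  # shape (2, 1, 1)
var = ((x - mean)**2).mean(dim=axes, keepdim=True)
y = (x - mean) / torch.sqrt(var + 1e-5)  # shape (2, 1, 4), as expected

# With keepdim=False the mean has shape (2, 1); broadcasting `x - mean`
# then yields shape (2, 2, 4) instead of (2, 1, 4) -- exactly the
# wrong-shape symptom this commit guards against.
bad_mean = x.mean(dim=axes, keepdim=False)  # shape (2, 1)
assert (x - bad_mean).shape == (2, 2, 4)

# The keepdim=True decomposition agrees with PyTorch's own layer_norm.
expected = torch.nn.functional.layer_norm(x, normalized_shape, eps=1e-5)
assert torch.allclose(y, expected, atol=1e-5)
```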
@@ -29,4 +79,4 @@ def layer_norm__ncnn(ctx, *args):
 
     Add support for exporting layer_norm to ONNX.
     """
-    return layer_norm(*args)
+    return _layer_norm_ncnn(*args)
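This hunk only renames the helper; the backend-specific wrapper `layer_norm__ncnn` that calls it sits outside the visible diff. For orientation, a hypothetical reconstruction of that wrapper, inferred from the hunk header, the `Backend` import, and the `layer_norm__default` registration above (the decorator arguments are an assumption; only the body and docstring appear in this diff):

```python
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend


# Hypothetical reconstruction -- the decorator arguments are assumed
# from the `Backend` import and the `layer_norm__default` registration;
# only the body and docstring are visible in this diff.
@SYMBOLIC_REWRITER.register_symbolic(
    'layer_norm', is_pytorch=True, backend=Backend.NCNN.value)
def layer_norm__ncnn(ctx, *args):
    """Symbolic function for `layer_norm`.

    Add support for exporting layer_norm to ONNX.
    """
    return _layer_norm_ncnn(*args)
```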
@@ -165,3 +165,17 @@ def test_hardsigmoid():
     model = torch.nn.Hardsigmoid().eval()
     nodes = get_model_onnx_nodes(model, x)
     assert nodes[0].op_type == 'HardSigmoid'
+
+
+@pytest.mark.usefixtures('prepare_symbolics')
+def test_layer_norm():
+    x = torch.rand(2, 1, 4)
+    model = torch.nn.LayerNorm(4).eval()
+    torch.onnx.export(model, x, onnx_file, opset_version=11)
+    onnx_model = onnx.load(onnx_file)
+    graph = onnx_model.graph
+    output = graph.output[0]
+    dim = output.type.tensor_type.shape.dim
+    assert dim[0].dim_value == 2
+    assert dim[1].dim_value == 1
+    assert dim[2].dim_value == 4
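The new test pins down the regression at the graph level: before the fix, the exported graph's declared output dims no longer matched the input. A numerical check with onnxruntime would complement the shape assertions (a sketch, assuming onnxruntime is installed; 'layer_norm.onnx' stands in for the `onnx_file` fixture path used by the test):

```python
import numpy as np
import onnxruntime as ort
import torch

x = torch.rand(2, 1, 4)
model = torch.nn.LayerNorm(4).eval()
torch.onnx.export(model, x, 'layer_norm.onnx', opset_version=11)

sess = ort.InferenceSession(
    'layer_norm.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
ort_out = sess.run(None, {input_name: x.numpy()})[0]

# The exported graph should agree with eager PyTorch in value, not only
# in the declared output dims that the test asserts.
torch_out = model(x).detach().numpy()
assert ort_out.shape == torch_out.shape == (2, 1, 4)
assert np.allclose(ort_out, torch_out, atol=1e-5)
```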