fix layer norm (#1015)

pull/1090/head
q.yao 2022-09-21 16:24:47 +08:00 committed by GitHub
parent fbe7586415
commit ea7706cbfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 2 deletions

View File

@ -2,14 +2,64 @@
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
import torch
from torch.onnx.symbolic_helper import parse_args
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend
@SYMBOLIC_REWRITER.register_symbolic(
'layer_norm',
is_pytorch=True,
arg_descriptors=['v', 'is', 'v', 'v', 'f', 'i'])
def layer_norm__default(ctx, g, input, normalized_shape, weight, bias, eps,
cudnn_enable):
"""Symbolic function for `layer_norm`
Layer norm with torch<=1.12 might lead to wrong output shapes. Add
keepdims=1 to each ReduceMean node to correct the shape.
"""
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_opset9 import add, mul, pow, sqrt, sub
if sym_help._operator_export_type == \
torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.op(
'ATen',
input,
weight,
bias,
normalized_shape_i=normalized_shape,
eps_f=eps,
cudnn_enable_i=cudnn_enable,
operator_s='layer_norm')
axes = [-i for i in range(len(normalized_shape), 0, -1)]
two_cst = sym_help._generate_wrapped_number(g, 2.)
eps_cst = sym_help._generate_wrapped_number(g, eps)
mean = g.op('ReduceMean', input, axes_i=axes, keepdims_i=1)
numerator = sub(g, input, mean)
# variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the
# layer_norm formula
variance = g.op(
'ReduceMean', pow(g, numerator, two_cst), axes_i=axes, keepdims_i=1)
denominator = sqrt(g, add(g, variance, eps_cst))
layer_norm = g.op('Div', numerator, denominator)
if not (weight is None or sym_help._is_none(weight)):
layer_norm = mul(g, layer_norm, weight)
if not (bias is None or sym_help._is_none(bias)):
layer_norm = add(g, layer_norm, bias)
return layer_norm
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
def _layer_norm_ncnn(g, input, normalized_shape, weight, bias, eps,
cudnn_enable):
"""Symbolic function for `layer_norm`.
PyTorch does not support export layer_norm to ONNX by default. We add the
@ -29,4 +79,4 @@ def layer_norm__ncnn(ctx, *args):
Add support to layer_norm to ONNX.
"""
return layer_norm(*args)
return _layer_norm_ncnn(*args)

View File

@ -165,3 +165,17 @@ def test_hardsigmoid():
model = torch.nn.Hardsigmoid().eval()
nodes = get_model_onnx_nodes(model, x)
assert nodes[0].op_type == 'HardSigmoid'
@pytest.mark.usefixtures('prepare_symbolics')
def test_layer_norm():
x = torch.rand(2, 1, 4)
model = torch.nn.LayerNorm(4).eval()
torch.onnx.export(model, x, onnx_file, opset_version=11)
onnx_model = onnx.load(onnx_file)
graph = onnx_model.graph
output = graph.output[0]
dim = output.type.tensor_type.shape.dim
assert dim[0].dim_value == 2
assert dim[1].dim_value == 1
assert dim[2].dim_value == 4