fix layer norm when torch>=1.12 (#1168)
parent f389a68dd4
commit c35099ef0c
@@ -2,7 +2,6 @@
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py

import torch
from torch.onnx.symbolic_helper import parse_args

from mmdeploy.core import SYMBOLIC_REWRITER
@@ -22,17 +21,6 @@ def layer_norm__default(ctx, g, input, normalized_shape, weight, bias, eps,
    """
    import torch.onnx.symbolic_helper as sym_help
    from torch.onnx.symbolic_opset9 import add, mul, pow, sqrt, sub
    if sym_help._operator_export_type == \
            torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op(
            'ATen',
            input,
            weight,
            bias,
            normalized_shape_i=normalized_shape,
            eps_f=eps,
            cudnn_enable_i=cudnn_enable,
            operator_s='layer_norm')

    axes = [-i for i in range(len(normalized_shape), 0, -1)]
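For context: the ONNX_ATEN_FALLBACK branch above reads sym_help._operator_export_type, and torch>=1.12 moved that flag into torch.onnx._globals.GLOBALS, which is the likely reason this rewriter needed a fix. Below is a minimal, illustrative sketch of a version-tolerant check combined with the standard layer-norm decomposition into primitive ONNX ops; the helper name _operator_export_type, the function name layer_norm_sketch, and the overall structure are assumptions for illustration, not the code this commit actually lands.

# Illustrative sketch only -- not the change introduced by this commit.
import torch


def _operator_export_type():
    # torch >= 1.12 keeps the export type in torch.onnx._globals.GLOBALS
    # (assumption based on the upstream torch source); older releases expose
    # it as torch.onnx.symbolic_helper._operator_export_type.
    try:
        from torch.onnx._globals import GLOBALS
        return GLOBALS.operator_export_type
    except ImportError:
        from torch.onnx import symbolic_helper as sym_help
        return sym_help._operator_export_type


def layer_norm_sketch(g, input, normalized_shape, weight, bias, eps,
                      cudnn_enable):
    # Hypothetical symbolic function:
    # y = (x - mean) / sqrt(var + eps) * weight + bias
    if _operator_export_type() == \
            torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op(
            'ATen',
            input,
            weight,
            bias,
            normalized_shape_i=normalized_shape,
            eps_f=eps,
            cudnn_enable_i=cudnn_enable,
            operator_s='layer_norm')

    # Reduce over the trailing `normalized_shape` dimensions.
    axes = [-i for i in range(len(normalized_shape), 0, -1)]
    two = g.op('Constant', value_t=torch.tensor(2.0))
    eps_cst = g.op('Constant', value_t=torch.tensor(float(eps)))

    mean = g.op('ReduceMean', input, axes_i=axes, keepdims_i=1)
    numerator = g.op('Sub', input, mean)
    variance = g.op(
        'ReduceMean', g.op('Pow', numerator, two), axes_i=axes, keepdims_i=1)
    denominator = g.op('Sqrt', g.op('Add', variance, eps_cst))
    out = g.op('Div', numerator, denominator)

    # None checks for weight/bias are omitted for brevity.
    out = g.op('Mul', out, weight)
    out = g.op('Add', out, bias)
    return out

Keeping the ATen fallback lets builds that link the ATen ops use the fused layer-norm kernel, while the decomposition path exports to plain ONNX ops that any backend can run.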