fix layer norm when torch>=1.12 ()

pull/1045/head
hanrui1sensetime 2022-10-11 14:29:13 +08:00 committed by GitHub
parent f389a68dd4
commit c35099ef0c
1 changed file with 0 additions and 12 deletions
mmdeploy/pytorch/ops


@@ -2,7 +2,6 @@
 # Modified from:
 # https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
-import torch
 from torch.onnx.symbolic_helper import parse_args

 from mmdeploy.core import SYMBOLIC_REWRITER
@@ -22,17 +21,6 @@ def layer_norm__default(ctx, g, input, normalized_shape, weight, bias, eps,
     """
     import torch.onnx.symbolic_helper as sym_help
     from torch.onnx.symbolic_opset9 import add, mul, pow, sqrt, sub
-    if sym_help._operator_export_type == \
-            torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
-        return g.op(
-            'ATen',
-            input,
-            weight,
-            bias,
-            normalized_shape_i=normalized_shape,
-            eps_f=eps,
-            cudnn_enable_i=cudnn_enable,
-            operator_s='layer_norm')
     axes = [-i for i in range(len(normalized_shape), 0, -1)]
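Context for the deletion: torch 1.12 moved the export-type flag out of torch.onnx.symbolic_helper, so reading the private sym_help._operator_export_type attribute fails on newer torch, which is presumably why this branch had to go. With it removed, the symbolic always emits the decomposed graph (mean/variance reduction over the trailing normalized_shape axes, then scale and shift, built from the imported add/mul/pow/sqrt/sub ops). Below is a minimal sketch of that arithmetic in plain PyTorch so it can be checked numerically against the builtin op; the helper name layer_norm_decomposed is hypothetical and exists only for this illustration, it is not part of the commit.

# Hypothetical illustration of the decomposition the remaining symbolic
# emits as ONNX graph nodes; written in eager PyTorch for verification.
import torch
import torch.nn.functional as F


def layer_norm_decomposed(x, normalized_shape, weight=None, bias=None,
                          eps=1e-5):
    # Same axis computation as the rewriter:
    # axes = [-i for i in range(len(normalized_shape), 0, -1)]
    axes = [-i for i in range(len(normalized_shape), 0, -1)]
    mean = x.mean(dim=axes, keepdim=True)
    # LayerNorm uses the biased variance (divide by N, not N - 1).
    var = ((x - mean) ** 2).mean(dim=axes, keepdim=True)
    out = (x - mean) / torch.sqrt(var + eps)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out


x = torch.randn(2, 3, 4)
assert torch.allclose(
    layer_norm_decomposed(x, (4, )),
    F.layer_norm(x, (4, )),
    atol=1e-5)

The deleted branch instead produced a backend-specific 'ATen' node that only runtimes with ATen-op support could consume, so always taking the decomposed path is also the more portable choice.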