support export hardsigmoid in torch<=1.8 (#169)

* support export hardsigmoid in torch<=1.8

* fix lint
pull/187/head
q.yao 2022-02-24 16:10:42 +08:00 committed by GitHub
parent 486d45e739
commit e9ee21fc1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 1 deletions

View File

@ -3,6 +3,7 @@ from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
adaptive_avg_pool2d__default,
adaptive_avg_pool3d__default)
from .grid_sampler import grid_sampler__default
from .hardsigmoid import hardsigmoid__default
from .instance_norm import instance_norm__tensorrt
from .lstm import generic_rnn__ncnn
from .squeeze import squeeze__default
@ -10,5 +11,6 @@ from .squeeze import squeeze__default
__all__ = [
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'instance_norm__tensorrt', 'generic_rnn__ncnn', 'squeeze__default'
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default'
]

View File

@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
'hardsigmoid', is_pytorch=True, arg_descriptors=['v'])
def hardsigmoid__default(ctx, g, self):
"""Support export hardsigmoid This rewrite enable export hardsigmoid in
torch<=1.8.2."""
return g.op('HardSigmoid', self, alpha_f=1 / 6)

View File

@ -116,3 +116,10 @@ class TestSqueeze:
nodes = get_model_onnx_nodes(model, x)
assert nodes[0].attribute[0].ints == [0]
assert nodes[0].op_type == 'Squeeze'
def test_hardsigmoid():
x = torch.rand(1, 2, 3, 4)
model = torch.nn.Hardsigmoid().eval()
nodes = get_model_onnx_nodes(model, x)
assert nodes[0].op_type == 'HardSigmoid'