support export hardsigmoid in torch<=1.8 (#169)
* support export hardsigmoid in torch<=1.8 * fix lintpull/187/head
parent
486d45e739
commit
e9ee21fc1d
|
@ -3,6 +3,7 @@ from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
|
|||
adaptive_avg_pool2d__default,
|
||||
adaptive_avg_pool3d__default)
|
||||
from .grid_sampler import grid_sampler__default
|
||||
from .hardsigmoid import hardsigmoid__default
|
||||
from .instance_norm import instance_norm__tensorrt
|
||||
from .lstm import generic_rnn__ncnn
|
||||
from .squeeze import squeeze__default
|
||||
|
@ -10,5 +11,6 @@ from .squeeze import squeeze__default
|
|||
__all__ = [
|
||||
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
|
||||
'adaptive_avg_pool3d__default', 'grid_sampler__default',
|
||||
'instance_norm__tensorrt', 'generic_rnn__ncnn', 'squeeze__default'
|
||||
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
|
||||
'squeeze__default'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# Modified from:
|
||||
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
|
||||
from mmdeploy.core import SYMBOLIC_REWRITER
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'hardsigmoid', is_pytorch=True, arg_descriptors=['v'])
|
||||
def hardsigmoid__default(ctx, g, self):
|
||||
"""Support export hardsigmoid This rewrite enable export hardsigmoid in
|
||||
torch<=1.8.2."""
|
||||
return g.op('HardSigmoid', self, alpha_f=1 / 6)
|
|
@ -116,3 +116,10 @@ class TestSqueeze:
|
|||
nodes = get_model_onnx_nodes(model, x)
|
||||
assert nodes[0].attribute[0].ints == [0]
|
||||
assert nodes[0].op_type == 'Squeeze'
|
||||
|
||||
|
||||
def test_hardsigmoid():
|
||||
x = torch.rand(1, 2, 3, 4)
|
||||
model = torch.nn.Hardsigmoid().eval()
|
||||
nodes = get_model_onnx_nodes(model, x)
|
||||
assert nodes[0].op_type == 'HardSigmoid'
|
||||
|
|
Loading…
Reference in New Issue