mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #401 from hwangdeyu/deyu/add_HardSwishJitAutoFn_operator
add HardSwishJitAutoFn operator export to onnx
This commit is contained in:
commit
ea36a78cff
@ -152,6 +152,13 @@ class HardSwishJitAutoFn(torch.autograd.Function):
|
|||||||
x = ctx.saved_tensors[0]
|
x = ctx.saved_tensors[0]
|
||||||
return hard_swish_jit_bwd(x, grad_output)
|
return hard_swish_jit_bwd(x, grad_output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def symbolic(g, self):
|
||||||
|
input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
|
||||||
|
hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
|
||||||
|
hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
|
||||||
|
return g.op("Mul", self, hardtanh_)
|
||||||
|
|
||||||
|
|
||||||
def hard_swish_me(x, inplace=False):
|
def hard_swish_me(x, inplace=False):
|
||||||
return HardSwishJitAutoFn.apply(x)
|
return HardSwishJitAutoFn.apply(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user