mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add memory efficient Swish impl
This commit is contained in:
parent
187ecbafbe
commit
a9eb484835
@ -371,11 +371,30 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
return arch_args
|
||||
|
||||
|
||||
def swish(x, inplace=False):
    """Apply the Swish activation, ``x * sigmoid(x)``, elementwise.

    Args:
        x: Input tensor.
        inplace: When true, the product is written back into ``x`` and ``x``
            itself is returned; otherwise a new tensor is returned and ``x``
            is left untouched.

    Returns:
        Tensor holding ``x * sigmoid(x)``.
    """
    # The gate is evaluated before any in-place write so it always reflects
    # the original values of ``x``.
    gate = x.sigmoid()
    return x.mul_(gate) if inplace else x * gate
|
||||
# Selects which `swish` implementation is defined below: the custom autograd
# function (True) or the plain elementwise expression (False).
_USE_SWISH_OPT = True
if _USE_SWISH_OPT:
    class SwishAutoFn(torch.autograd.Function):
        """ Memory Efficient Swish
        From: https://blog.ceshine.net/post/pytorch-memory-swish/

        Only the input ``x`` is saved for the backward pass; ``sigmoid(x)``
        is recomputed there instead of being kept alive by autograd.
        """
        @staticmethod
        def forward(ctx, x):
            """Return ``x * sigmoid(x)`` and stash ``x`` for backward."""
            result = x.mul(torch.sigmoid(x))
            ctx.save_for_backward(x)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            """Return the gradient of Swish w.r.t. its input, scaled by ``grad_output``."""
            # `saved_tensors` is the supported accessor for values stored via
            # `save_for_backward`; `saved_variables` is deprecated.
            x = ctx.saved_tensors[0]
            sigmoid_x = torch.sigmoid(x)
            # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x))
            #                 = s(x) * (1 + x * (1 - s(x)))
            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))

    def swish(x, inplace=False):
        """Apply Swish through `SwishAutoFn`.

        `inplace` is accepted for signature compatibility but is not used on
        this path: the input is never modified.
        """
        return SwishAutoFn.apply(x)
else:
    def swish(x, inplace=False):
        """Apply Swish, ``x * sigmoid(x)``; writes into ``x`` when `inplace` is true."""
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
||||
|
||||
|
||||
def sigmoid(x, inplace=False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user