exportable Hardswish() implementation

pull/823/head
Glenn Jocher 2020-08-22 16:59:31 -07:00
parent fd71fe8451
commit 71209a6099
1 changed files with 7 additions and 0 deletions

View File

@ -10,6 +10,13 @@ class Swish(nn.Module): #
return x * torch.sigmoid(x)
class Hardswish(nn.Module): # alternative to nn.Hardswish() for export
@staticmethod
def forward(x):
# return x * F.hardsigmoid(x)
return x * F.hardtanh(x + 3, 0., 6.) / 6.
class MemoryEfficientSwish(nn.Module):
class F(torch.autograd.Function):
@staticmethod