exportable Hardswish() implementation
parent
fd71fe8451
commit
71209a6099
utils
|
@ -10,6 +10,13 @@ class Swish(nn.Module): #
|
|||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Hardswish(nn.Module): # alternative to nn.Hardswish() for export
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
# return x * F.hardsigmoid(x)
|
||||
return x * F.hardtanh(x + 3, 0., 6.) / 6.
|
||||
|
||||
|
||||
class MemoryEfficientSwish(nn.Module):
|
||||
class F(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue