diff --git a/utils/activations.py b/utils/activations.py index 879f7b421..58225c6de 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -10,6 +10,13 @@ class Swish(nn.Module): # return x * torch.sigmoid(x) +class Hardswish(nn.Module): # alternative to nn.Hardswish() for export + @staticmethod + def forward(x): + # return x * F.hardsigmoid(x) + return x * F.hardtanh(x + 3, 0., 6.) / 6. + + class MemoryEfficientSwish(nn.Module): class F(torch.autograd.Function): @staticmethod