# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'Mish',
    'Swish',
    'MemoryEfficientSwish',
    'GELU']


class Mish(nn.Module):
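    """Mish activation: ``x * tanh(softplus(x))`` (Misra, 2019)."""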
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
        return x * (torch.tanh(F.softplus(x)))


class Swish(nn.Module):
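    """Swish activation: ``x * sigmoid(x)``, also known as SiLU."""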
    def forward(self, x):
        return x * torch.sigmoid(x)


class SwishImplementation(torch.autograd.Function):
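    """Swish as a custom autograd Function: only the input is saved for the
    backward pass and the sigmoid is recomputed there, rather than autograd
    retaining the intermediate tensors of a plain ``x * torch.sigmoid(x)``.
    """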
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
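    """Swish computed through ``SwishImplementation`` to reduce the memory
    retained for the backward pass during training.
    """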
    def forward(self, x):
        return SwishImplementation.apply(x)


class GELU(nn.Module):
    """
    Tanh approximation of GELU. Paper Section 3.4, last paragraph, notes that
    BERT uses GELU rather than ReLU.
    """

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
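

if __name__ == '__main__':
    # Minimal smoke test (illustrative addition, not part of the upstream file):
    # checks that every exported activation runs forward and backward,
    # including the custom autograd path used by MemoryEfficientSwish.
    x = torch.randn(4, requires_grad=True)
    for act in (Mish(), Swish(), MemoryEfficientSwish(), GELU()):
        loss = act(x).sum()
        loss.backward()
        print(type(act).__name__, x.grad)
        x.grad = None  # clear accumulated gradients between activations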