mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix inplace arg compat for GELU and PreLU via activation factory
This commit is contained in:
parent
fd962c4b4a
commit
5f4b6076d8
@ -119,3 +119,27 @@ class HardMish(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return hard_mish(x, self.inplace)
|
return hard_mish(x, self.inplace)
|
||||||
|
|
||||||
|
|
||||||
|
class PReLU(nn.PReLU):
|
||||||
|
"""Applies PReLU (w/ dummy inplace arg)
|
||||||
|
"""
|
||||||
|
def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
|
||||||
|
super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.prelu(input, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
|
||||||
|
return F.gelu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
|
||||||
|
"""
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(GELU, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.gelu(input)
|
||||||
|
@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict(
|
|||||||
relu6=F.relu6,
|
relu6=F.relu6,
|
||||||
leaky_relu=F.leaky_relu,
|
leaky_relu=F.leaky_relu,
|
||||||
elu=F.elu,
|
elu=F.elu,
|
||||||
prelu=F.prelu,
|
|
||||||
celu=F.celu,
|
celu=F.celu,
|
||||||
selu=F.selu,
|
selu=F.selu,
|
||||||
gelu=F.gelu,
|
gelu=gelu,
|
||||||
sigmoid=sigmoid,
|
sigmoid=sigmoid,
|
||||||
tanh=tanh,
|
tanh=tanh,
|
||||||
hard_sigmoid=hard_sigmoid,
|
hard_sigmoid=hard_sigmoid,
|
||||||
@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict(
|
|||||||
relu6=nn.ReLU6,
|
relu6=nn.ReLU6,
|
||||||
leaky_relu=nn.LeakyReLU,
|
leaky_relu=nn.LeakyReLU,
|
||||||
elu=nn.ELU,
|
elu=nn.ELU,
|
||||||
prelu=nn.PReLU,
|
prelu=PReLU,
|
||||||
celu=nn.CELU,
|
celu=nn.CELU,
|
||||||
selu=nn.SELU,
|
selu=nn.SELU,
|
||||||
gelu=nn.GELU,
|
gelu=GELU,
|
||||||
sigmoid=Sigmoid,
|
sigmoid=Sigmoid,
|
||||||
tanh=Tanh,
|
tanh=Tanh,
|
||||||
hard_sigmoid=HardSigmoid,
|
hard_sigmoid=HardSigmoid,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user