mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix inplace arg compat for GELU and PreLU via activation factory
This commit is contained in:
parent
fd962c4b4a
commit
5f4b6076d8
@ -119,3 +119,27 @@ class HardMish(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return hard_mish(x, self.inplace)
|
||||
|
||||
|
||||
class PReLU(nn.PReLU):
|
||||
"""Applies PReLU (w/ dummy inplace arg)
|
||||
"""
|
||||
def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
|
||||
super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.prelu(input, self.weight)
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
|
||||
return F.gelu(x)
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
|
||||
"""
|
||||
def __init__(self, inplace: bool = False):
|
||||
super(GELU, self).__init__()
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.gelu(input)
|
||||
|
@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict(
|
||||
relu6=F.relu6,
|
||||
leaky_relu=F.leaky_relu,
|
||||
elu=F.elu,
|
||||
prelu=F.prelu,
|
||||
celu=F.celu,
|
||||
selu=F.selu,
|
||||
gelu=F.gelu,
|
||||
gelu=gelu,
|
||||
sigmoid=sigmoid,
|
||||
tanh=tanh,
|
||||
hard_sigmoid=hard_sigmoid,
|
||||
@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict(
|
||||
relu6=nn.ReLU6,
|
||||
leaky_relu=nn.LeakyReLU,
|
||||
elu=nn.ELU,
|
||||
prelu=nn.PReLU,
|
||||
prelu=PReLU,
|
||||
celu=nn.CELU,
|
||||
selu=nn.SELU,
|
||||
gelu=nn.GELU,
|
||||
gelu=GELU,
|
||||
sigmoid=Sigmoid,
|
||||
tanh=Tanh,
|
||||
hard_sigmoid=HardSigmoid,
|
||||
|
Loading…
x
Reference in New Issue
Block a user