mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support native silu activation (aka swish). An optimized ver is available in PyTorch 1.7.
This commit is contained in:
parent
da6cd2cc1f
commit
e90edce438
@ -6,9 +6,14 @@ from .activations_jit import *
|
|||||||
from .activations_me import *
|
from .activations_me import *
|
||||||
from .config import is_exportable, is_scriptable, is_no_jit
|
from .config import is_exportable, is_scriptable, is_no_jit
|
||||||
|
|
||||||
|
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
|
||||||
|
# will use native version if present. Eventually, the custom Swish layers will be removed
|
||||||
|
# and only native 'silu' will be used.
|
||||||
|
_has_silu = 'silu' in dir(torch.nn.functional)
|
||||||
|
|
||||||
_ACT_FN_DEFAULT = dict(
|
_ACT_FN_DEFAULT = dict(
|
||||||
swish=swish,
|
silu=F.silu if _has_silu else swish,
|
||||||
|
swish=F.silu if _has_silu else swish,
|
||||||
mish=mish,
|
mish=mish,
|
||||||
relu=F.relu,
|
relu=F.relu,
|
||||||
relu6=F.relu6,
|
relu6=F.relu6,
|
||||||
@ -26,7 +31,8 @@ _ACT_FN_DEFAULT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_FN_JIT = dict(
|
_ACT_FN_JIT = dict(
|
||||||
swish=swish_jit,
|
silu=F.silu if _has_silu else swish_jit,
|
||||||
|
swish=F.silu if _has_silu else swish_jit,
|
||||||
mish=mish_jit,
|
mish=mish_jit,
|
||||||
hard_sigmoid=hard_sigmoid_jit,
|
hard_sigmoid=hard_sigmoid_jit,
|
||||||
hard_swish=hard_swish_jit,
|
hard_swish=hard_swish_jit,
|
||||||
@ -34,7 +40,8 @@ _ACT_FN_JIT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_FN_ME = dict(
|
_ACT_FN_ME = dict(
|
||||||
swish=swish_me,
|
silu=F.silu if _has_silu else swish_me,
|
||||||
|
swish=F.silu if _has_silu else swish_me,
|
||||||
mish=mish_me,
|
mish=mish_me,
|
||||||
hard_sigmoid=hard_sigmoid_me,
|
hard_sigmoid=hard_sigmoid_me,
|
||||||
hard_swish=hard_swish_me,
|
hard_swish=hard_swish_me,
|
||||||
@ -42,7 +49,8 @@ _ACT_FN_ME = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_LAYER_DEFAULT = dict(
|
_ACT_LAYER_DEFAULT = dict(
|
||||||
swish=Swish,
|
silu=nn.SiLU if _has_silu else Swish,
|
||||||
|
swish=nn.SiLU if _has_silu else Swish,
|
||||||
mish=Mish,
|
mish=Mish,
|
||||||
relu=nn.ReLU,
|
relu=nn.ReLU,
|
||||||
relu6=nn.ReLU6,
|
relu6=nn.ReLU6,
|
||||||
@ -60,7 +68,8 @@ _ACT_LAYER_DEFAULT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_LAYER_JIT = dict(
|
_ACT_LAYER_JIT = dict(
|
||||||
swish=SwishJit,
|
silu=nn.SiLU if _has_silu else SwishJit,
|
||||||
|
swish=nn.SiLU if _has_silu else SwishJit,
|
||||||
mish=MishJit,
|
mish=MishJit,
|
||||||
hard_sigmoid=HardSigmoidJit,
|
hard_sigmoid=HardSigmoidJit,
|
||||||
hard_swish=HardSwishJit,
|
hard_swish=HardSwishJit,
|
||||||
@ -68,7 +77,8 @@ _ACT_LAYER_JIT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_LAYER_ME = dict(
|
_ACT_LAYER_ME = dict(
|
||||||
swish=SwishMe,
|
silu=nn.SiLU if _has_silu else SwishMe,
|
||||||
|
swish=nn.SiLU if _has_silu else SwishMe,
|
||||||
mish=MishMe,
|
mish=MishMe,
|
||||||
hard_sigmoid=HardSigmoidMe,
|
hard_sigmoid=HardSigmoidMe,
|
||||||
hard_swish=HardSwishMe,
|
hard_swish=HardSwishMe,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user