Small SwiGLU tweak: remove the default LayerNorm (norm_layer) arg in the unpacked variant, add a packed alias (SwiGLUPacked) for GluMlp

This commit is contained in:
Ross Wightman 2023-05-08 12:28:00 -07:00
parent cb3f9c23bb
commit 3fdb31de2e

View File

@@ -97,6 +97,9 @@ class GluMlp(nn.Module):
return x return x
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
class SwiGLU(nn.Module): class SwiGLU(nn.Module):
""" SwiGLU """ SwiGLU
NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
@@ -108,7 +111,7 @@ class SwiGLU(nn.Module):
hidden_features=None, hidden_features=None,
out_features=None, out_features=None,
act_layer=nn.SiLU, act_layer=nn.SiLU,
norm_layer=nn.LayerNorm, norm_layer=None,
bias=True, bias=True,
drop=0., drop=0.,
): ):