引言
今天介绍LLAMA模型引入的关于激活函数的改进——SwiGLU1,该激活函数取得了不错的效果,得到了广泛地应用。
SwiGLU是GLU的一种变体,其中包含了GLU和Swish激活函数。
GLU
GLU(Gated Linear Units,门控线性单元)2引入了两个不同的线性层,其中一个首先经过sigmoid函数,其结果将和另一个线性层的输出进行逐元素相乘作为最终的输出:
GLU
(
x
,
W
,
V
,
b
,
c
)
=
σ
(
x
W
+
b
)
⊗
(
x
V
+
c
)
(1)
\text{GLU}(x,W,V,b,c) = \sigma(xW+b) \otimes (xV+c) \tag 1
GLU(x,W,V,b,c)=σ(xW+b)⊗(xV+c)(1)
这里
W
,
V
W,V
W,V以及
b
,
c
b,c
b,c分别是这两个线性层的参数;
σ
(
x
W
+
b
)
\sigma(xW+b)
σ(xW+b)作为门控,控制
x
V
+
c
xV+c
xV+c的输出。
这里使用 σ \sigma σ作为激活函数,修改改激活函数得到的变体通常能带来更好的性能表现,比如SwiGLU修改激活函数为Swish。我们来看下Swish激活函数。
Swish
Swish3激活函数的形式为:
Swish
β
(
x
)
=
x
σ
(
β
x
)
(2)
\text{Swish}_\beta(x) = x \sigma(\beta x) \tag 2
Swishβ(x)=xσ(βx)(2)
其中
σ
(
x
)
\sigma(x)
σ(x)是Sigmoid函数;
β
\beta
β是一个可学习的参数。
可以通过下面的代码画出Swish激活函数在不同参数 β \beta β下的图像:
import numpy as np
import matplotlib.pyplot as plt
def swish(x, beta):
return x / (1 + np.exp(-beta*x))
x = np.linspace(-10, 10, 100)
betas = [0.1, 1.0, 10.0]
plt.figure(figsize=(10, 6))
for beta in betas:
y = swish(x, beta)
plt.plot(x, y, label=f'beta={beta}')
plt.legend()
plt.title('Swish Activation Function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.grid(True)
plt.show()
可以看到3,当 β \beta β趋近于 0 0 0时,Swish函数趋近于线性函数 y = x 2 y=x^2 y=x2;当 β \beta β趋近于无穷大时,Swish函数趋近于ReLU函数;当 β \beta β取值为 1 1 1时,Swish函数是光滑且非单调的,等价于参考4中介绍的SiLU。
Swish与ReLU之间最显著的区别是当 x < 0 x < 0 x<0时Swish的非单调“凸起”3。
SwiGLU
如前文所述,将公式(1)中GLU的激活函数改为Swish即变成了所谓的SwiGLU激活函数1:
SwiGLU
(
x
,
W
,
V
)
=
Swish
β
(
x
W
)
⊗
(
x
V
)
(3)
\text{SwiGLU}(x,W,V) = \text{Swish}_\beta(xW) \otimes (xV) \tag{3}
SwiGLU(x,W,V)=Swishβ(xW)⊗(xV)(3)
这里省略了偏置项。
代码实现
参考LLaMA,全连接层使用带有SwiGLU激活函数的FFN(Position-wise Feed-Forward Network)的公式如下1:
FFN
SwiGLU
(
x
,
W
,
V
,
W
2
)
=
(
Swish
1
(
x
W
)
⊗
x
V
)
W
2
(4)
\text{FFN}_{\text{SwiGLU}}(\pmb x,W,V,W_2) = (\text{Swish}_1(\pmb xW) \otimes \pmb xV)W_2 \tag 4
FFNSwiGLU(x,W,V,W2)=(Swish1(xW)⊗xV)W2(4)
这里的Swish函数可以被SiLU函数替代:
SiLU
(
x
)
=
x
σ
(
x
)
\text{SiLU}(\pmb x) = \pmb x \sigma(\pmb x)
SiLU(x)=xσ(x)
即:
FFN
SwiGLU
(
x
,
W
,
V
,
W
2
)
=
(
SiLU
(
x
W
)
⊗
x
V
)
W
2
(5)
\text{FFN}_{\text{SwiGLU}}(\pmb x,W,V,W_2) = (\text{SiLU}(\pmb xW) \otimes \pmb xV)W_2 \tag 5
FFNSwiGLU(x,W,V,W2)=(SiLU(xW)⊗xV)W2(5)
import torch
from torch import nn
import torch.nn.functional as F
class FeedForward(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int) -> None:
super().__init__()
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (batch_size, seq_len, hidden_size)
# w1(x) -> (batch_size, seq_len, intermediate_size)
# w1(x) -> (batch_size, seq_len, intermediate_size)
# w2(*) -> (batch_size, seq_len, hidden_size)
return self.w2(F.silu(self.w1(x)) * self.w3(x))
这里w1,w2,w3
分别对应公式(5)中的
W
,
W
2
,
V
W,W_2,V
W,W2,V。
注意维度,其中w1,w3
将x
转换到维度intermediate_size
,然后w2
转换回hidden_size
。
参考
[论文翻译]GLU Variants Improve Transformer ↩︎ ↩︎ ↩︎
[论文笔记]Language Modeling with Gated Convolutional Networks ↩︎
[论文笔记]SEARCHING FOR ACTIVATION FUNCTIONS ↩︎ ↩︎ ↩︎
[论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS) ↩︎