Transformer中的 Add & Norm
flyfish
Add
同一个意思 Residual connections,Skip Connections
Norm
包括Post layer normalization和Pre layer normalization
Post layer normalization:Transformer 论文中使用的方式,将 Layer normalization 放在 Skip Connections 之间
class PoswiseFeedForwardNet(nn.Module):
def __init__(self, d_ff=2048):
super(PoswiseFeedForwardNet, self).__init__()
# 定义一维卷积层 1,用于将输入映射到更高维度
self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=d_ff, kernel_size=1)
# 定义一维卷积层 2,用于将输入映射回原始维度
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_embedding, kernel_size=1)
# 定义层归一化
self.layer_norm = nn.LayerNorm(d_embedding)
def forward(self, inputs):
#------------------------- 维度信息 --------------------------------
# inputs [batch_size, len_q, embedding_dim]
#----------------------------------------------------------------
residual = inputs # 保留残差连接
# 在卷积层 1 后使用 ReLU 激活函数
output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))
#------------------------- 维度信息 --------------------------------
# output [batch_size, d_ff, len_q]
#----------------------------------------------------------------
# 使用卷积层 2 进行降维
output = self.conv2(output).transpose(1, 2)
#------------------------- 维度信息 --------------------------------
# output [batch_size, len_q, embedding_dim]
#----------------------------------------------------------------
# 与输入进行残差链接,并进行层归一化
output = self.layer_norm(output + residual)
#------------------------- 维度信息 --------------------------------
# output [batch_size, len_q, embedding_dim]
#----------------------------------------------------------------
return output # 返回加入残差连接后层归一化的结果
Pre layer normalization:将 Layer Normalization 放置于 Skip Connections 的范围内。
(常用方式)
class FeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.gelu = nn.GELU()
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, x):
x = self.linear_1(x)
x = self.gelu(x)
x = self.linear_2(x)
x = self.dropout(x)
return x
class EncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
self.attention = MultiHeadAttention(config)
self.feed_forward = FeedForward(config)
def forward(self, x, mask=None):
# Apply layer normalization and then copy input into query, key, value
hidden_state = self.layer_norm_1(x)
# Apply attention with a skip connection
x = x + self.attention(hidden_state, hidden_state, hidden_state, mask=mask)
# Apply feed-forward layer with a skip connection
x = x + self.feed_forward(self.layer_norm_2(x))
return x
初始化
定义两个LayerNorm层用于归一化输入,
一个MultiHeadAttention模块负责自注意力机制,以及一个FeedForward模块。
在前向传播过程中:
1 对输入先做LayerNorm得到标准化后的hidden_state。
2 使用自注意力机制对hidden_state进行处理,并与原始输入相加,实现残差连接。
3 再次对上一步的结果进行LayerNorm,并通过FeedForward层进行处理。
4将FeedForward层输出与经过自注意力机制后的结果相加,再次使用残差连接。
5 返回最终处理后的编码器层输出x。