240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)
今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。
Viterbi算法
在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。
取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:
从第0个至第𝑖个Token对应概率最大的序列,只需要考虑从第0个至第𝑖−1个Token对应概率最大的序列,以及从第𝑖个至第𝑖−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。
由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。
# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
# emissions: (seq_length, batch_size, num_tags) 发射概率矩阵
# mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度
# trans: 转移概率矩阵
# start_trans: 初始状态转移概率向量
# end_trans: 终止状态转移概率向量
seq_length = mask.shape[0] # 获取序列长度
# 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率
score = start_trans + emissions[0]
history = () # 初始化历史路径记录
# 遍历序列中的每个时间步
for i in range(1, seq_length):
# 扩展维度以便广播运算
broadcast_score = score.expand_dims(2)
broadcast_emission = emissions[i].expand_dims(1)
# 计算所有可能的转移分数
next_score = broadcast_score + trans + broadcast_emission
# 找出当前Token对应的最大分数标签,并保存
indices = next_score.argmax(axis=1)
history += (indices,) # 保存历史路径信息
# 取出最大分数
next_score = next_score.max(axis=1)
# 更新分数矩阵,只更新mask为True的部分
score = mnp.where(mask[i].expand_dims(1), next_score, score)
# 加上终止状态转移概率
score += end_trans
# 返回最终的分数矩阵和历史路径信息
return score, history
# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):
# score: 最终得分矩阵
# history: 历史路径信息
# seq_length: 每个样本的实际序列长度
batch_size = seq_length.shape[0] # 获取批次大小
seq_ends = seq_length - 1 # 计算每个样本的最后一个Token位置
# 初始化最佳标签序列列表
best_tags_list = []
# 对批次中的每个样本进行解码
for idx in range(batch_size):
# 找出使最后一个Token对应的预测概率最大的标签
best_last_tag = score[idx].argmax(axis=0)
best_tags = [int(best_last_tag.asnumpy())] # 添加最佳标签到序列
# 从历史路径信息中反向追踪,找到每个Token的最佳标签
for hist in reversed(history[:seq_ends[idx]]):
best_last_tag = hist[idx][best_tags[-1]]
best_tags.append(int(best_last_tag.asnumpy()))
# 将逆序的标签序列反转,得到正序的最优标签序列
best_tags.reverse()
best_tags_list.append(best_tags) # 添加到结果列表
# 返回最优标签序列列表
return best_tags_list
CRF层
完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length
参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask
方法。
综合上述代码,使用nn.Cell
进行封装,最后实现完整的CRF层如下:
# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform
# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):
"""
根据序列的实际长度和最大长度生成mask矩阵。
参数:
seq_length: 实际序列长度张量。
max_length: 序列的最大长度。
batch_first: 是否将批次放在第一维度。
返回:
mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。
"""
# 生成从0到max_length的范围向量
range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)
# 创建mask矩阵,shape为(seq_length.shape + (1,))
result = range_vector < seq_length.view(seq_length.shape + (1,))
# 转换数据类型并根据batch_first参数调整维度顺序
if batch_first:
return result.astype(ms.int64)
return result.astype(ms.int64).swapaxes(0, 1)
# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):
def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
"""
初始化CRF模型。
参数:
num_tags: 标签数量。
batch_first: 是否将批次放在第一维度。
reduction: 损失函数的缩减方式。
"""
# 检查标签数量是否有效
if num_tags <= 0:
raise ValueError(f'无效的标签数量: {num_tags}')
super().__init__()
# 检查reduction参数是否有效
if reduction not in ('none', 'sum', 'mean', 'token_mean'):
raise ValueError(f'无效的缩减方式: {reduction}')
self.num_tags = num_tags # 标签数量
self.batch_first = batch_first # 批次是否在第一维度
self.reduction = reduction # 损失函数缩减方式
# 初始化起始和结束状态转移权重
self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
# 初始化状态间转移权重
self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')
def construct(self, emissions, tags=None, seq_length=None):
"""
CRF模型的前向传播方法。
参数:
emissions: 发射概率张量。
tags: 真实标签张量。
seq_length: 序列长度张量。
返回:
如果tags为None,则返回解码结果;否则返回损失值。
"""
if tags is None:
return self._decode(emissions, seq_length)
return self._forward(emissions, tags, seq_length)
def _forward(self, emissions, tags=None, seq_length=None):
"""
计算损失值。
参数:
emissions: 发射概率张量。
tags: 真实标签张量。
seq_length: 序列长度张量。
返回:
损失值。
"""
# 根据batch_first参数调整emissions和tags的维度顺序
if self.batch_first:
batch_size, max_length = tags.shape
emissions = emissions.swapaxes(0, 1)
tags = tags.swapaxes(0, 1)
else:
max_length, batch_size = tags.shape
# 如果seq_length未给出,则假设所有序列都是最大长度
if seq_length is None:
seq_length = mnp.full((batch_size,), max_length, ms.int64)
# 生成mask矩阵
mask = sequence_mask(seq_length, max_length)
# 计算分子部分(真实路径的得分)
numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)
# 计算分母部分(所有可能路径的得分总和)
denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
# 计算对数似然比
llh = denominator - numerator
# 根据reduction参数选择损失值的缩减方式
if self.reduction == 'none':
return llh
elif self.reduction == 'sum':
return llh.sum()
elif self.reduction == 'mean':
return llh.mean()
return llh.sum() / mask.astype(emissions.dtype).sum()
def _decode(self, emissions, seq_length=None):
"""
解码方法,用于预测最优标签序列。
参数:
emissions: 发射概率张量。
seq_length: 序列长度张量。
返回:
最优标签序列。
"""
# 根据batch_first参数调整emissions的维度顺序
if self.batch_first:
batch_size, max_length = emissions.shape[:2]
emissions = emissions.swapaxes(0, 1)
else:
batch_size, max_length = emissions.shape[:2]
# 如果seq_length未给出,则假设所有序列都是最大长度
if seq_length is None:
seq_length = mnp.full((batch_size,), max_length, ms.int64)
# 生成mask矩阵
mask = sequence_mask(seq_length, max_length)
# 使用维特比算法解码最优路径
return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
打卡图片: