240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

news2024/11/15 8:09:00

240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。

Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第𝑖个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:

请添加图片描述

从第0个至第𝑖个Token对应概率最大的序列,只需要考虑从第0个至第𝑖−1个Token对应概率最大的序列,以及从第𝑖个至第𝑖−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。

由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。

# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags) 发射概率矩阵
    # mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度
    # trans: 转移概率矩阵
    # start_trans: 初始状态转移概率向量
    # end_trans: 终止状态转移概率向量

    seq_length = mask.shape[0]  # 获取序列长度

    # 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率
    score = start_trans + emissions[0]
    history = ()  # 初始化历史路径记录

    # 遍历序列中的每个时间步
    for i in range(1, seq_length):
        # 扩展维度以便广播运算
        broadcast_score = score.expand_dims(2)
        broadcast_emission = emissions[i].expand_dims(1)
        
        # 计算所有可能的转移分数
        next_score = broadcast_score + trans + broadcast_emission

        # 找出当前Token对应的最大分数标签,并保存
        indices = next_score.argmax(axis=1)
        history += (indices,)  # 保存历史路径信息

        # 取出最大分数
        next_score = next_score.max(axis=1)
        
        # 更新分数矩阵,只更新mask为True的部分
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    # 加上终止状态转移概率
    score += end_trans

    # 返回最终的分数矩阵和历史路径信息
    return score, history


# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):
    # score: 最终得分矩阵
    # history: 历史路径信息
    # seq_length: 每个样本的实际序列长度

    batch_size = seq_length.shape[0]  # 获取批次大小
    seq_ends = seq_length - 1  # 计算每个样本的最后一个Token位置
    
    # 初始化最佳标签序列列表
    best_tags_list = []

    # 对批次中的每个样本进行解码
    for idx in range(batch_size):
        # 找出使最后一个Token对应的预测概率最大的标签
        best_last_tag = score[idx].argmax(axis=0)
        best_tags = [int(best_last_tag.asnumpy())]  # 添加最佳标签到序列

        # 从历史路径信息中反向追踪,找到每个Token的最佳标签
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(int(best_last_tag.asnumpy()))

        # 将逆序的标签序列反转,得到正序的最优标签序列
        best_tags.reverse()
        best_tags_list.append(best_tags)  # 添加到结果列表

    # 返回最优标签序列列表
    return best_tags_list

CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform

# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):
    """
    根据序列的实际长度和最大长度生成mask矩阵。
    
    参数:
    seq_length: 实际序列长度张量。
    max_length: 序列的最大长度。
    batch_first: 是否将批次放在第一维度。
    
    返回:
    mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。
    """
    # 生成从0到max_length的范围向量
    range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)
    # 创建mask矩阵,shape为(seq_length.shape + (1,))
    result = range_vector < seq_length.view(seq_length.shape + (1,))
    # 转换数据类型并根据batch_first参数调整维度顺序
    if batch_first:
        return result.astype(ms.int64)
    return result.astype(ms.int64).swapaxes(0, 1)


# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):
    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        """
        初始化CRF模型。
        
        参数:
        num_tags: 标签数量。
        batch_first: 是否将批次放在第一维度。
        reduction: 损失函数的缩减方式。
        """
        # 检查标签数量是否有效
        if num_tags <= 0:
            raise ValueError(f'无效的标签数量: {num_tags}')
        super().__init__()
        # 检查reduction参数是否有效
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'无效的缩减方式: {reduction}')
        self.num_tags = num_tags  # 标签数量
        self.batch_first = batch_first  # 批次是否在第一维度
        self.reduction = reduction  # 损失函数缩减方式
        # 初始化起始和结束状态转移权重
        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        # 初始化状态间转移权重
        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')

    def construct(self, emissions, tags=None, seq_length=None):
        """
        CRF模型的前向传播方法。
        
        参数:
        emissions: 发射概率张量。
        tags: 真实标签张量。
        seq_length: 序列长度张量。
        
        返回:
        如果tags为None,则返回解码结果;否则返回损失值。
        """
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)

    def _forward(self, emissions, tags=None, seq_length=None):
        """
        计算损失值。
        
        参数:
        emissions: 发射概率张量。
        tags: 真实标签张量。
        seq_length: 序列长度张量。
        
        返回:
        损失值。
        """
        # 根据batch_first参数调整emissions和tags的维度顺序
        if self.batch_first:
            batch_size, max_length = tags.shape
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape
        
        # 如果seq_length未给出,则假设所有序列都是最大长度
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        
        # 生成mask矩阵
        mask = sequence_mask(seq_length, max_length)
        
        # 计算分子部分(真实路径的得分)
        numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)
        # 计算分母部分(所有可能路径的得分总和)
        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
        # 计算对数似然比
        llh = denominator - numerator
        
        # 根据reduction参数选择损失值的缩减方式
        if self.reduction == 'none':
            return llh
        elif self.reduction == 'sum':
            return llh.sum()
        elif self.reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        """
        解码方法,用于预测最优标签序列。
        
        参数:
        emissions: 发射概率张量。
        seq_length: 序列长度张量。
        
        返回:
        最优标签序列。
        """
        # 根据batch_first参数调整emissions的维度顺序
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            batch_size, max_length = emissions.shape[:2]
        
        # 如果seq_length未给出,则假设所有序列都是最大长度
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        
        # 生成mask矩阵
        mask = sequence_mask(seq_length, max_length)
        
        # 使用维特比算法解码最优路径
        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

打卡图片:

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1922666.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android Gantt View 安卓实现项目甘特图

需要做一个项目管理工具&#xff0c;其中使用到了甘特图。发现全网甘特图解决方案比较少&#xff0c;于是自动动手丰衣足食。 前面我用 Python和 Node.js 前端都做过&#xff0c;这次仅仅是移植到 Android上面。 其实甘特图非常简单&#xff0c;开发也不难&#xff0c;如果我…

PCL从理解到应用【04】Octree 原理分析 | 案例分析 | 代码实现

前言 Octree 作为一种高效的空间分割数据结构&#xff0c;具有重要的应用价值。 本文将深入分析 Octree 的原理&#xff0c;通过多个实际案例帮助读者全面理解其功能和应用&#xff0c;包括最近邻搜索、半径搜索、盒子搜索以及点云压缩&#xff08;体素化&#xff09;。 特性…

MongoDB - 查询操作符:比较查询、逻辑查询、元素查询、数组查询

文章目录 1. 构造数据2. MongoDB 比较查询操作符1. $eq 等于1.1 等于指定值1.2 嵌入式文档中的字段等于某个值1.3 数组元素等于某个值1.4 数组元素等于数组值 2. $ne 不等于3. $gt 大于3.1 匹配文档字段3.2 根据嵌入式文档字段执行更新 4. $gte 大于等于5. $lt 小于6. $lte 小于…

(Vue+SpringBoot+elementUi+WangEditer)仿论坛项目

项目使用到的技术与库 1.前端 Vue2 elementUi Cookie WangEditer 2.后端 SpringBoot Mybatis-Plus 3.数据库 MySql 一、效果展示 1.1主页效果&#xff1a; 1.2 文章编辑页面&#xff1a; 1.3 成功发布文章 1.4 文章关键字搜索提示 1.5 文章查询结果展示 1.6 文章内容及交互展示…

统信UOS服务器操作系统离线安装postgresql数据库

原文链接&#xff1a;统信UOS服务器离线安装postgresql数据库 Hello&#xff0c;大家好啊&#xff01;今天给大家带来一篇关于在统信UOS服务器操作系统上离线安装PostgreSQL数据库的文章。PostgreSQL是一款功能强大的开源对象-关系型数据库管理系统。由于某些环境中无法直接访问…

免费开源的工业物联网(IoT)解决方案

什么是 IoT&#xff1f; 物联网 (IoT) 是指由实体设备、车辆、电器和其他实体对象组成的网络&#xff0c;这些实体对象内嵌传感器、软件和网络连接&#xff0c;可以收集和共享数据。 IoT 设备&#xff08;也称为“智能对象”&#xff09;范围广泛&#xff0c;包括智能恒温器等…

SpringBoot+Vue(2)excel后台管理页面

一、需求 SpringBootVue写excel后台管理页面&#xff08;二级页面打开展示每一个excel表&#xff0c;数据库存储字段为“下载、删除、文件详情、是否共享、共享详情”&#xff09; 二、解答 后端(Spring Boot) 1. 项目设置 使用Spring Initializr创建一个新的Spring Boot项目…

深入理解 Elasticsearch 分页技术

原文链接&#xff1a;https://zhuanlan.zhihu.com/p/609576187 Elasticsearch 是一款分布式的搜索引擎&#xff0c;提供了灵活的分页技术。本文主要介绍 Elasticsearch&#xff08;简称 ES&#xff09; 的几种分页技术&#xff0c;并深入分析各种分页技术的优缺点及应用场景。 …

基于AT89C51单片机篮球计时计分器的设计(含文档、源码与proteus仿真,以及系统详细介绍)

本篇文章论述的是基于AT89C51单片机篮球计时计分器的设计的详情介绍&#xff0c;如果对您有帮助的话&#xff0c;还请关注一下哦&#xff0c;如果有资源方面的需要可以联系我。 目录 绪论 原理图 ​编辑 仿真图 系统总体设计图 代码实现 系统论文 资源下载 绪论 本次…

内网服务器通过squid代理访问外网

一、背景 现在要对172.16.58.158服务器进行openssh升级操作,我用之前写好的升级脚本执行后,发现没有备份旧的ssh程序文件,然后还卸载了oenssl-devel,然后我发现其他服务器ssh该服务器失败。同时脚本执行时报错“ configure: error: *** zlib.h missing - please install first …

windows查看局域网所有设备ip

windows如何查看局域网所有设备ip 操作方法 一 . 在搜索栏里输入cmd 二 .在命令行黑窗口输入arp -a 三 . 最上面显示的动态地址就是所有设备ip

day20、21、22补卡

235. 二叉搜索树的最近公共祖先 这道题的解题思路&#xff0c;我想了一会都没想出来&#xff0c;看了题解想&#xff1a;对于二叉搜索树&#xff0c;当我们从上向下去递归遍历&#xff0c;第一次遇到 cur节点是数值在[q, p]区间中&#xff0c;那么cur就是 q和p的最近公共祖先。…

Database数据库 vs Data Warehouse数据仓库 vs Data Mart数据集市 vs Data Lake数据湖

1.DATABASE 数据库 数据库是一个结构化的数据集合&#xff0c;用于存储、管理和检索数据。数据库设计用于支持事务处理&#xff08;OLTP&#xff0c;Online Transaction Processing&#xff09;和日常操作。 数据库通常由数据库管理系统&#xff08;DBMS&#xff09;控制&…

webRtc架构与目录结构

整体架构 目录结构 webrtc Modules目录

基于PCIe总线架构的2路1GSPS AD、4路1GSPS DA信号处理平台(100%国产化)

板卡概述 PCIE723-165是基于PCIE总线架构的2通道1GSPS采样率14位分辨率、4通道1GSPS采样率16位分辨率信号处理平台&#xff0c;该板卡采用国产16nm FPGA作为实时处理器&#xff0c;支持2路高速采集以及4路高速数据回放&#xff0c;板载2组DDR4 SDRAM大容量数据缓存&#xff0c;…

宝兰德参编金融智能体标准,深耕大模型场景化落地

随着数智化浪潮的不断推进&#xff0c;人工智能技术正深刻影响着金融服务的模式和流程&#xff0c;金融智能体在大模型的加持下&#xff0c;业务场景的应用能力得到强化。然而&#xff0c;作为新型技术&#xff0c;金融智能体在隐私保护、透明性、数据泄露等方面仍存在诸多风险…

图片存储问题总结

参考博客&#xff1a; https://blog.csdn.net/BUPT_Kwong/article/details/100972964 今天发现图片保存的一个神奇的问题&#xff0c;就是说原始的jpg图片打开后&#xff0c;重新保存成jpg格式&#xff0c;会发现这个结果不是很对的 example from PIL import Image import n…

房屋出租管理系统小程序需求分析及功能介绍

房屋租赁管理系统适用于写字楼、办公楼、厂区、园区、商城、公寓等商办商业不动产的租赁管理及租赁营销&#xff1b;提供资产管理&#xff0c;合同管理&#xff0c;租赁管理&#xff0c; 物业管理&#xff0c;门禁管理等一体化的运营管理平台&#xff0c;提高项目方管理运营效率…

【qt】QTcpSocket相关的信号

QTcpSocket可以在这里找到相关的信号 进行信号槽的关联 connect():这个信号在connectToHost()被调用并且连接已经成功建立之后发出 disconnected():该信号在套接字断开连接时发出 stateChanged(QAbstractSocket::SocketState socketState):每当QAbstractSocket的状态发生变化…

基于Adaboost的数据分类算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于Adaboost的数据分类算法matlab仿真,分别对比线性分类和非线性分类两种方式。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 &#xff08;完整程序…