240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

news2024/11/16 21:23:35

240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)

最后一天咯,做第四部分。

BiLSTM+CRF模型

在实现CRF后,我们设计一个双向LSTM+CRF的模型来进行命名实体识别任务的训练。模型结构如下:

nn.Embedding -> nn.LSTM -> nn.Dense -> CRF

其中LSTM提取序列特征,经过Dense层变换获得发射概率矩阵,最后送入CRF层。具体实现如下:

# 定义双向LSTM结合CRF的序列标注模型
class BiLSTM_CRF(nn.Cell):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        """
        初始化BiLSTM_CRF模型。

        参数:
        vocab_size: 词汇表大小。
        embedding_dim: 词嵌入维度。
        hidden_dim: LSTM隐藏层维度。
        num_tags: 标签种类数量。
        padding_idx: 填充索引,默认为0。
        """
        super().__init__()
        # 初始化词嵌入层
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # 初始化双向LSTM层
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        # 初始化从LSTM输出到标签的全连接层
        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')
        # 初始化条件随机场层
        self.crf = CRF(num_tags, batch_first=True)

    def construct(self, inputs, seq_length, tags=None):
        """
        模型的前向传播方法。

        参数:
        inputs: 输入序列,形状为(batch_size, seq_length)。
        seq_length: 序列长度,形状为(batch_size,)。
        tags: 真实标签,形状为(batch_size, seq_length),可选。

        返回:
        crf_outs: CRF层的输出,如果输入了真实标签则为损失值,否则为解码后的标签序列。
        """
        # 通过词嵌入层获取词向量表示
        embeds = self.embedding(inputs)
        # 通过双向LSTM层获取序列特征
        outputs, _ = self.lstm(embeds, seq_length=seq_length)
        # 通过全连接层转换LSTM输出到标签空间
        feats = self.hidden2tag(outputs)

        # 通过CRF层计算损失或解码
        crf_outs = self.crf(feats, tags, seq_length)
        return crf_outs

完成模型设计后,我们生成两句例子和对应的标签,并构造词表和标签表。

# 设置词嵌入维度和隐藏层维度
embedding_dim = 16
hidden_dim = 32

# 定义训练数据集,每条数据包含一个分词后的句子和相应的实体标签
training_data = [
    (
        "清 华 大 学 坐 落 于 首 都 北 京".split(),  # 分词后的句子
        "B I I I O O O O O B I".split()  # 相应的实体标签
    ),
    (
        "重 庆 是 一 个 魔 幻 城 市".split(),  # 分词后的句子
        "B I O O O O O O O".split()  # 相应的实体标签
    )
]

# 初始化词典,用于映射词到索引
word_to_idx = {}
# 添加特殊填充词到词典
word_to_idx['<pad>'] = 0
# 遍历训练数据,构建词到索引的映射
for sentence, tags in training_data:
    for word in sentence:
        # 如果词不在词典中,则添加到词典
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)

# 初始化标签到索引的映射
tag_to_idx = {"B": 0, "I": 1, "O": 2}

len(word_to_idx)

接下来实例化模型,选择优化器并将模型和优化器送入Wrapper。

由于CRF层已经进行了NLLLoss的计算,因此不需要再设置Loss。

# 实例化BiLSTM_CRF模型,传入词汇表大小、词嵌入维度、隐藏层维度以及标签种类数量
model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))

# 初始化随机梯度下降优化器,设置学习率为0.01,权重衰减为1e-4
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)

# 使用MindSpore的value_and_grad函数创建一个函数,它会同时计算模型的损失值和梯度
# 第二个参数设置为None表示不保留反向图,第三个参数是优化器的参数列表
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)

def train_step(data, seq_length, label):
    """
    训练步骤函数,执行一次前向传播和反向传播更新模型参数。

    参数:
    data: 输入数据,形状为(batch_size, seq_length)。
    seq_length: 序列长度,形状为(batch_size,)。
    label: 真实标签,形状为(batch_size, seq_length)。

    返回:
    loss: 当前批次的损失值。
    """
    # 使用grad_fn计算损失值和梯度
    loss, grads = grad_fn(data, seq_length, label)
    # 使用优化器更新模型参数
    optimizer(grads)
    # 返回损失值
    return loss

将生成的数据打包成Batch,按照序列最大长度,对长度不足的序列进行填充,分别返回输入序列、输出标签和序列长度构成的Tensor。

def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    """
    准备序列数据,包括填充和转换成张量。

    参数:
    seqs: 一个包含句子和对应标签的元组列表。
    word_to_idx: 词到索引的映射字典。
    tag_to_idx: 标签到索引的映射字典。

    返回:
    seq_outputs: 填充后的序列数据张量。
    label_outputs: 填充后的标签数据张量。
    seq_length: 序列的真实长度列表。
    """
    seq_outputs, label_outputs, seq_length = [], [], []
    # 获取最长序列长度
    max_len = max([len(i[0]) for i in seqs])

    for seq, tag in seqs:
        # 记录序列的真实长度
        seq_length.append(len(seq))
        # 将序列中的词转换为索引
        idxs = [word_to_idx[w] for w in seq]
        # 将标签转换为索引
        labels = [tag_to_idx[t] for t in tag]
        # 对序列进行填充
        idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])
        # 对标签进行填充,用'O'的索引填充
        labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])
        # 添加填充后的序列和标签到列表
        seq_outputs.append(idxs)
        label_outputs.append(labels)

    # 将列表转换为MindSpore张量
    return ms.Tensor(seq_outputs, ms.int64), \
           ms.Tensor(label_outputs, ms.int64), \
           ms.Tensor(seq_length, ms.int64)

# 调用prepare_sequence函数处理训练数据,并获取处理后的数据、标签和序列长度
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)

# 打印处理后的数据、标签和序列长度的形状,以确认数据转换是否正确
print(data.shape, label.shape, seq_length.shape)

对模型进行预编译后,训练500个step。

训练流程可视化依赖tqdm库,可使用pip install tqdm命令安装。

from tqdm import tqdm

# 定义训练步骤的总数,用于进度条的设置
steps = 500

# 使用tqdm创建一个进度条,总进度为steps
with tqdm(total=steps) as t:
    for i in range(steps):
        # 执行单步训练,这里假设train_step是一个已定义的训练函数
        # 参数data为训练数据,seq_length为序列长度,label为标签
        loss = train_step(data, seq_length, label)
        
        # 更新进度条的附带信息,显示当前的损失值
        t.set_postfix(loss=loss)
        
        # 更新进度条,表示完成了一步训练
        t.update(1)

最后我们来观察训练500个step后的模型效果,首先使用模型预测可能的路径得分以及候选序列。

# 调用模型进行预测或评估,返回得分和历史记录
score, history = model(data, seq_length)

# 输出得分,用于查看模型的表现或决策
score

使用后处理函数进行预测得分的后处理。

predict = post_decode(score, history, seq_length)
predict

最后将预测的index序列转换为标签序列,打印输出结果,查看效果。

# 通过索引和标签的映射关系,构建标签到索引的反向映射
idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}

def sequence_to_tag(sequences, idx_to_tag):
    """
    将序列中的索引转换为对应的标签。
    
    参数:
    sequences: 一个包含标签索引的序列列表。
    idx_to_tag: 一个字典,用于将索引映射到对应的标签。
    
    返回:
    一个列表,其中每个元素是输入序列中索引转换为标签后的结果。
    """
    # 初始化一个空列表,用于存储转换后的标签序列
    outputs = []
    # 遍历输入的序列列表
    for seq in sequences:
        # 对每个序列,将索引转换为标签,并添加到输出列表中
        outputs.append([idx_to_tag[i] for i in seq])
    # 返回转换后的标签序列列表
    return outputs

sequence_to_tag(predict, idx_to_tag)

打卡照片:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924109.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端练习小项目——方向感应名片

前言&#xff1a;在学习完HTML和CSS之后&#xff0c;我们就可以开始做一些小项目了&#xff0c;本篇文章所讲的小项目为——方向感应名片 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主页秋刀鱼不做梦-CSDN博客 在开始学习之前&#xff0c;先让我们看一…

C++客户端Qt开发——开发环境

一、QT开发环境 1.安装三个部分 ①C编译器&#xff08;gcc&#xff0c;cl.exe……) ②QT SDK SDK-->软件开发工具包 比如&#xff0c;windows版本QT SDK里已经内置了C的编译器&#xff08;内置编译器是mingw&#xff0c;windows版本的gcc/g&#xff09; ③QT的集成开发…

KnoBo:医书学习知识,辅助图像分析,解决分布外性能下降和可解释性问题

KnoBo&#xff1a;从医书中学习知识&#xff0c;辅助图像分析&#xff0c;解决分布外性能下降问题 提出背景KnoBo 流程图KnoBo 详解问题构成结构先验瓶颈预测器参数先验 解法拆解逻辑链对比 CLIP、Med-CLIPCLIPMed-CLIPKnoBo 训练细节预训练过程OpenCLIP的微调 构建医学语料库文…

说说执行一条查询SQL语句时,期间发生了什么?

执行一条查询SQL语句时&#xff0c;期间发生了什么&#xff1f; 前言说说执行一条查询SQL语句时&#xff0c;发生了什么&#xff1f;连接器权限验证断开连接长连接 查询缓存查询缓存的问题 解析器词法分析语法分析 执行 SQL预处理器优化器执行器主键索引查询全表扫描索引下推 总…

轻薄鼠标的硬核选购攻略,很多人都在“高性价比”鼠标上栽跟头了

轻薄款设计的鼠标是目前鼠标市场的出货大头&#xff0c; 也是价格最卷的一类鼠标。 比游戏鼠标或许更卷一些。 这和当前的移动办公趋势关系很大。 这类鼠标主要跟笔记本和iPad搭配。 核心的使用场景是办公。 因此轻薄和静音是这类鼠标的核心卖点。 同时用户并不愿意付出太…

代码随想录算法训练营第三十二天|1049.最后一块石头的重量II、494.目标和、474.一和零

1049.最后一块石头的重量II 有一堆石头&#xff0c;每块石头的重量都是正整数。 每一回合&#xff0c;从中选出任意两块石头&#xff0c;然后将它们一起粉碎。假设石头的重量分别为 x 和 y&#xff0c;且 x < y。那么粉碎的可能结果如下&#xff1a; 如果 x y&#xff0c;那…

期货交易记录20240713

文章目录 期货交易系统构建步骤一、选品二、心态历练三、何时开仓3.1、开仓纪律3.2、开仓时机3.3、开仓小技巧 四、持仓纪律五、接下来的计划 2024年7月13号&#xff0c;期货交易第5篇记录。 交易记录&#xff1a;半个月多没记录了&#xff0c;这段时间分别尝试做了菜粕、棕榈油…

9.6 栅格图层符号化唯一值着色渲染

文章目录 前言多波段彩色渲染唯一值着色QGis设置为唯一值着色二次开发代码实现唯一值着色 总结 前言 介绍栅格图层数据渲染之唯一值着色渲染说明&#xff1a;文章中的示例代码均来自开源项目qgis_cpp_api_apps 多波段彩色渲染唯一值着色 以“with_color_table.tif”数据为例…

【嵌入式DIY实例-ESP8266篇】-LCD ST7789显示DS1307 RTC时间数据

LCD ST7789显示DS1307 RTC时间数据 文章目录 LCD ST7789显示DS1307 RTC时间数据1、硬件准备与接线2、代码实现本文将介绍如何使用 ESP8266 NodeMCU 板和 DS1307 RTC 集成电路构建简单的实时时钟和日历 (RTCC),其中时间和日期打印在 ST7789 TFT 显示模块上。 ST7789 TFT 模块包…

Open-TeleVision——通过VR沉浸式感受人形机器人视野:兼备远程控制和深度感知能力

前言 7.3日&#xff0c;我司七月在线(集AI大模型职教、应用开发、机器人解决方案为一体的科技公司)的「大模型机器人(具身智能)线下营」群里的一学员发了《Open-TeleVision: Teleoperation with Immersive Active Visual Feedback》这篇论文的链接&#xff0c;我当时快速看了一…

UML/SysML建模工具更新情况(2024年7月)(1)

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 工具最新版本&#xff1a;Enterprise Architect 17.0 BETA 更新时间&#xff1a;2024年7月2日 工具简介 性价比很高&#xff0c;目前最流行的UML建模工具。还包含需求管理、项目估算…

AIGC专栏13——ComfyUI 插件编写细节解析-以EasyAnimateV3为例

AIGC专栏13——ComfyUI 插件编写细节解析-以EasyAnimateV3为例 学习前言什么是ComfyUI相关地址汇总ComfyUIEasyAnimateV3 节点例子复杂例子-以EasyAnimateV3为例节点文件必要库的导入载入模型节点定义Image to Video节点定义节点名称映射 __init__.py文件插件导入comfyUI 学习前…

被动的机器人非线性MPC控制

MPC是一种基于数学模型的控制策略&#xff0c;它通过预测系统在未来一段时间内的行为&#xff0c;并求解优化问题来确定当前的控制输入&#xff0c;以实现期望的控制目标。对于非线性系统&#xff0c;MPC可以采用非线性模型进行预测和优化&#xff0c;这种方法被称为非线性模型…

JS实现:统计字符出现频率/计算文字在文本中的出现次数

要实现这个功能&#xff0c;JavaScript 一个非常强大的方法&#xff0c;那就是reduce() reduce() 它用于将数组的所有元素减少到一个单一的值。这个值可以是任何类型&#xff0c;包括但不限于数字、字符串、对象或数组。 reduce() 方法接收一个回调函数作为参数&#xff0c;这个…

【C++】设计一套基于C++与C#的视频播放软件

在开发一款集视频播放与丰富交互功能于一体的软件时&#xff0c;结合C的高性能与C#在界面开发上的便捷性&#xff0c;是一个高效且实用的选择。以下&#xff0c;我们将概述这样一个系统的架构设计、关键技术点以及各功能模块的详细实现思路。 一、系统架构设计 1. 架构概览 …

截图神器Snipaste

这是我作为测试这么些年来用过的最好用的截图工具&#xff0c;让你将截图贴回到屏幕上&#xff0c;最好用的截图工具&#xff0c;推荐给同事深受好评。 snipaste是一个简单但强大的截图工具&#xff0c;也可以让你将截图贴回到屏幕上。下载打开Snipaste,按下F2来开始截图&…

【已解决】sudo: apt: command not found 或者apt-get: command not found解决方案

一、问题 在CentOS7.5运行apt-get install supervisor遇到如下报错 二、原因 CentOS的软件安装工具不是apt-get &#xff0c;而是yum&#xff0c;应该使用如下命令&#xff1a; yum install supervisor 后面命令换为yum就可以了 三、扩展&#xff1a; 一般来说linux系统…

MVC架构

MVC架构 MVC架构在软件开发中通常指的是一种设计模式&#xff0c;它将应用程序分为三个主要组成部分&#xff1a;模型&#xff08;Model&#xff09;、视图&#xff08;View&#xff09;和控制器&#xff08;Controller&#xff09;。这种分层结构有助于组织代码&#xff0c;使…

AR0132AT 1/3 英寸 CMOS 数字图像传感器可提供百万像素 HDR 图像处理(器件编号包含:AR0132AT6R、AR0132AT6C)

AR0132AT 1/3 英寸 CMOS 数字图像传感器&#xff0c;带 1280H x 960V 有效像素阵列。它能在线性或高动态模式下捕捉图像&#xff0c;且带有卷帘快门读取。它包含了多种复杂的摄像功能&#xff0c;如自动曝光控制、开窗&#xff0c;以及视频和单帧模式。它适用于低光度和高动态范…

《GroupViT: Semantic Segmentation Emerges from Text Supervision》论文解读

会议&#xff1a;CVPR 年份&#xff1a;2022 代码&#xff1a;https://github.com/NVlabs/GroupViT 研究背景与动机&#xff1a; 传统深度学习系统中&#xff0c;图像区域的Grouping通常是隐式通过像素级识别标签的自上而下监督来实现的。作者提出将Grouping机制重新引入深…