Mindspore框架CRF条件随机场概率图模型实现文本序列命名实体标注|(三)双向LSTM+CRF模型构建实现

news2025/1/8 22:29:19

Mindspore框架CRF条件随机场概率图模型实现文本序列命名实体标注|(一)序列标注与条件随机场的关系
Mindspore框架CRF条件随机场概率图模型实现文本序列命名实体标注|(二)CRF模型构建
Mindspore框架CRF条件随机场概率图模型实现文本序列命名实体标注|(三)双向LSTM+CRF模型构建实现


Mindspore框架CRF条件随机场概率图模型实现文本序列命名实体标注|(三)双向LSTM+CRF模型构建

一、双向LSTM+CRF

BI-LSTM-CRF模型:优势在于它结合了双向LSTM的能力来捕获长距离的双向上下文依赖性,并通过CRF层来精确地建模标签之间的约束关系(CRF层能够确保识别出的实体标签在整个序列中保持一致性),从而在复杂的序列标注任务中提供了显著的性能提升。
在实现CRF后,我们设计一个双向LSTM+CRF的模型来进行命名实体识别任务的训练。
模型结构如下:

nn.Embedding -> nn.LSTM -> nn.Dense -> CRF

在这里插入图片描述

其中LSTM提取序列特征,经过Dense层变换获得发射概率矩阵,最后送入CRF层。具体实现如下:

class BiLSTM_CRF(nn.Cell):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')
        self.crf = CRF(num_tags, batch_first=True)

    def construct(self, inputs, seq_length, tags=None):
        embeds = self.embedding(inputs)
        outputs, _ = self.lstm(embeds, seq_length=seq_length)
        feats = self.hidden2tag(outputs)

        crf_outs = self.crf(feats, tags, seq_length)
        return crf_outs

二、构造词表和标签表

构建一个简易训练集

embedding_dim = 16
hidden_dim = 32

training_data = [(
    "清 华 大 学 坐 落 于 首 都 北 京".split(),
    "B I I I O O O O O B I".split()
), (
    "重 庆 是 一 个 魔 幻 城 市".split(),
    "B I O O O O O O O".split()
),(
    "北 京 大 学 坐 落 于 首 都 北 京".split(),
    "B I I I O O O O O B I".split()
), (
    "南 京 大 学 坐 落 于 故 都 南 京".split(),
    "B I I I O O O O O B I".split()
)]

word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)

tag_to_idx = {"B": 0, "I": 1, "O": 2}  # 定义标签-序列

预测时使用:序列转标签

idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}

def sequence_to_tag(sequences, idx_to_tag):
    outputs = []
    for seq in sequences:
        outputs.append([idx_to_tag[i] for i in seq])
    return outputs

测试输出len(word_to_idx)结果:
在这里插入图片描述

将生成的数据打包成Batch,按照序列最大长度,对长度不足的序列进行填充,分别返回输入序列、输出标签和序列长度构成的Tensor。

def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    seq_outputs, label_outputs, seq_length = [], [], []
    max_len = max([len(i[0]) for i in seqs])

    for seq, tag in seqs:
        seq_length.append(len(seq))
        idxs = [word_to_idx[w] for w in seq]
        labels = [tag_to_idx[t] for t in tag]
        idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])
        labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])
        seq_outputs.append(idxs)
        label_outputs.append(labels)

    return ms.Tensor(seq_outputs, ms.int64), \
            ms.Tensor(label_outputs, ms.int64), \
            ms.Tensor(seq_length, ms.int64)
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
print(data.shape, label.shape, seq_length.shape)

((4, 11), (4, 11), (4,))

三、训练双向LSTM+CRF模型

模型初始化:

model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)

def train_step(data, seq_length, label):
    loss, grads = grad_fn(data, seq_length, label)
    optimizer(grads)
    return loss

训练模型:

from tqdm import tqdm

steps = 500
with tqdm(total=steps) as t:
    for i in range(steps):
        loss = train_step(data, seq_length, label)
        t.set_postfix(loss=loss)
        t.update(1)

四、模型预测

score, history = model(data, seq_length)  
# 打印实体命名预测结果
res = sequence_to_tag(predict, idx_to_tag)
print(res)

预测:
在这里插入图片描述

输出:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1947377.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL4.索引及视图

1.建库 create database mydb15_indexstu; use mydb15_indexstu;2.建表 2.1 student表学&#xff08;sno&#xff09;号为主键&#xff0c;姓名&#xff08;sname&#xff09;不能重名&#xff0c;性别&#xff08;ssex&#xff09;仅能输入男或女&#xff0c;默认所在系别&a…

BSV区块链在人工智能时代的数字化转型中的角色

​​发表时间&#xff1a;2024年6月13日 企业数字化转型已有约30年的历史&#xff0c;而人工智能&#xff08;以下简称AI&#xff09;将这种转型提升到了一个全新的高度。这并不难理解&#xff0c;因为AI终于使企业能够发挥其潜力&#xff0c;实现更宏大的目标。然而&#xff0…

Flowable-整合系统角色人员

一、整合角色人员 在设置审批人和候选人的时候&#xff0c;我们希望能够选取业务数据库中的用户&#xff0c;也是是集成系统中的用户和角色&#xff0c;来帮助我们快速选择审批人&#xff0c;不再需要通过flowable的用户。 1. 加载系统用户 在设置候选用户的时候&#xff0c…

手把手教你本地运行Meta最新大模型:Llama3.1,可是它说自己是ChatGPT?

就在昨晚&#xff0c;Meta发布了可以与OpenAI掰手腕的最新开源大模型&#xff1a;Llama 3.1。 该模型共有三个版本&#xff1a; 8B70B405B 对于这次发布&#xff0c;Meta已经在超过150个涵盖广泛语言范围的基准数据集上评估了性能。此外&#xff0c;Meta还进行了广泛的人工评…

ADS 使用教程(十六)Using Sliders for Data Processing

上一篇&#xff1a;ADS 使用教程&#xff08;十五&#xff09;Multi-Dimensional Data Processing in ADS 在这一节&#xff0c;我们来谈论一下如何在进行多维数据处理时使用滑块&#xff08;Sliders&#xff09;来进行数据处理和分析。通过该方法&#xff0c;我们可以通过拖动…

SkyWalking入门搭建【apache-skywalking-apm-10.0.0】

Java学习文档 视频讲解 文章目录 一、准备二、服务启动2-1、Nacos启动2-2、SkyWalking服务端启动2-3、SkyWalking控制台启动2-4、自定义服务接入 SkyWalking 三、常用监控3-1、服务请求通过率3-2、服务请求拓扑图3-3、链路 四、日志配置五、性能剖析六、数据持久化6-1、MySQL持…

相机的内参与外参

目录 一、相机的内参二、相机的外参 一、相机的内参 如下图所示是相机的针孔模型示意图&#xff1a; 光心O所处平面是相机坐标系(O&#xff0c;P)&#xff0c;像素平面所在坐标系为像素坐标系(O’&#xff0c;P’)。 焦距f&#xff1a;O到O’的距离 相机的内参表示的是相机坐标…

运算符的运算顺序

【单目算术位关系&#xff0c;逻辑三目后赋值】 ![在这里插入图片描述] (https://i-blog.csdnimg.cn/direct/e4c8f4e22b5044a48154bf7378e3b3b3.png)

【React】useState:状态更新规则详解

文章目录 一、基本用法二、直接修改状态 vs 使用 setState 更新状态三、对象状态的更新四、深层次对象的更新五、函数式更新六、优化性能的建议 在 React 中&#xff0c;useState 是一个非常重要的 Hook&#xff0c;用于在函数组件中添加状态管理功能。正确理解和使用 useState…

同步库存扣减到数据库

跟接上文: 并发情况下的库存扣减 一 库存扣减 public DefaultTreeFactory.TreeActionEntity logic(String userId, Long strategyId, Integer awardId, String ruleValue, Date endDateTime) {log.info("规则过滤-库存扣减 userId:{} strategyId:{} awardId:{}", u…

结合el-upload上传组件,验证文件格式及大小

结合el-upload上传组件&#xff0c;验证文件格式及大小 效果如下&#xff1a; 代码如下&#xff1a; upgradeFirmwareInfo.vue页面 <template><div><el-dialog title"新增固件升级包" :visible.sync"dialogFormVisible"top"7vh&qu…

电商数据集成之电商商品信息采集系统架构设计||电商API接口

一、引言 本架构设计文档旨在阐述基于 Selenium 的电商商品信息采集系统的整体架构&#xff0c;包括系统视图、逻辑视图、物理视图、开发视图和进程视图&#xff0c;并提供一个简单的采集电商商品信息的 demo。该系统通过模拟浏览器行为&#xff0c;实现对电商商品信息的自…

力扣高频SQL 50题(基础版)第九题

文章目录 力扣高频SQL 50题&#xff08;基础版&#xff09;第九题197. 上升的温度题目说明思路分析实现过程准备数据实现方式结果截图总结 力扣高频SQL 50题&#xff08;基础版&#xff09;第九题 197. 上升的温度 题目说明 Weather ---------------------- | Column Name…

从防范到防御异常场景处理机制终于闭环了

前言 为什么要做异常场景自动化监控&#xff1f; 怎么做&#xff1f; 做这件事情的意义! 前言 QA在异常场景的处理上&#xff0c;之前更多地侧重于防范&#xff0c;通过风险评估、异常测试等手段降低异常发生的概率。异常场景数据的自动化监控更多地侧重于防御&#xff0c…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第四十九章 平台总线总结回顾

i.MX8MM处理器采用了先进的14LPCFinFET工艺&#xff0c;提供更快的速度和更高的电源效率;四核Cortex-A53&#xff0c;单核Cortex-M4&#xff0c;多达五个内核 &#xff0c;主频高达1.8GHz&#xff0c;2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

matlab 小数取余 rem 和 mod有 bug

目录 前言Matlab取余函数1 mod 函数1.1 命令行输入1.2 命令行输出 2 rem 函数2.1 命令行输入2.2 命令行输出 分析原因注意 前言 在 Matlab 代码中mod(0.11, 0.1) < 0.01 判断为真&#xff0c;mod(1.11, 0.1) < 0.01判断为假&#xff0c;导致出现意料外的结果。 结果发现…

CCS(Code Composer Studio 10.4.0)编译软件中文乱码怎么解决

如果是所有文件都出现了中文乱码这时建议直接在窗口首选项中修改&#xff1a;选择"Window" -> "Preferences"&#xff0c;找到"General" -> "Workspace"&#xff0c;将"Text file encoding"选项设置为"Other&quo…

Android lmkd机制详解

目录 一、lmkd介绍 二、lmkd实现原理 2.1 工作原理图 2.2 初始化 2.3 oom_adj获取 2.4 监听psi事件及处理 2.5 进程选取与查杀 2.5.1 进程选取 2.5.2 进程查杀 三、关键系统属性 四、核心数据结构 五、代码时序 一、lmkd介绍 Android lmkd采用epoll方式监听linux内…

关于私域终局的几个观察 【后续加图】

01 私域流量的本质 每个人对「私域」的定义都不一样。 最初&#xff0c;大家认为私域其实就是把微商往高级了说&#xff0c;于是 2020 年我看到许多朋友一窝蜂到朋友圈卖货&#xff0c;冒出了一个段子叫「我们都活成了微商」。 最开始&#xff0c;我也以为「私域」只是微信…

微服务实现全链路灰度发布

一、实现步骤 再请求 Header 中打上标签&#xff0c;例如再 Header 中添加 "gray-tag: true" &#xff0c;其表示要进行灰度测试&#xff08;访问灰度服务&#xff09;&#xff0c;而其他则访问正式服务。在负载均衡器 Spring Cloud LoadBalancer 中&#xff0c;拿到…