人工智能(pytorch)搭建模型7-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别

news2025/1/13 8:10:29

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型7-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别,BiLSTM+CRF 模型是一种常用的序列标注算法,可用于词性标注、分词、命名实体识别等任务。本文利用pytorch搭建一个BiLSTM+CRF模型,并给出数据样例,通过一个简单的命名实体识别(NER)任务来演示模型的训练和预测过程。文章将分为以下几个部分:

1. BiLSTM+CRF模型的介绍
2. BiLSTM+CRF模型的数学原理
3. 数据准备
4. 模型搭建
5. 训练与评估
6. 预测
7. 总结

1. BiLSTM+CRF模型的介绍

BiLSTM+CRF模型结合了双向长短时记忆网络(BiLSTM)和条件随机场(CRF)两种技术。BiLSTM用于捕捉序列中的上下文信息,而CRF用于解决标签之间的依赖关系。实际上,BiLSTM用于为每个输入序列生成一个特征向量,然后将这些特征向量输入到CRF层,以便为序列中的每个元素分配一个标签。BiLSTM 和 CRF 结合在一起,使模型即可以像 CRF 一样考虑序列前后之间的关联性,又可以拥有 LSTM 的特征抽取及拟合能力。

2.BiLSTM+CRF模型的数学原理

假设我们有一个序列 x = ( x 1 , x 2 , . . . , x n ) \boldsymbol{x} = (x_1, x_2, ..., x_n) x=(x1,x2,...,xn),其中 x i x_i xi 是第 i i i 个位置的输入特征。我们要对每个位置进行标注,即为每个位置 i i i 预测一个标签 y i y_i yi。标签集合为 Y = y 1 , y 2 , . . . , y n \mathcal{Y}={y_1, y_2, ..., y_n} Y=y1,y2,...,yn,其中 y i ∈ L y_i \in \mathcal{L} yiL L \mathcal{L} L 表示标签的类别集合。

BiLSTM用于从输入序列中提取特征,它由两个方向的LSTM组成,分别从前向后和从后向前处理输入序列。在时间步 t t t,BiLSTM的输出为 h t ∈ R 2 d h_t \in \mathbb{R}^{2d} htR2d,其中 d d d 是LSTM的隐藏状态维度。具体来说,前向LSTM从左至右处理输入序列 x \boldsymbol{x} x,输出隐状态序列 h → = ( h 1 → , h 2 → , . . . , h n → ) \overrightarrow{h}=(\overrightarrow{h_1},\overrightarrow{h_2},...,\overrightarrow{h_n}) h =(h1 ,h2 ,...,hn ),其中 h t → \overrightarrow{h_t} ht 表示在时间步 t t t 时前向LSTM的隐藏状态;后向LSTM从右至左处理输入序列 x \boldsymbol{x} x,输出隐状态序列 h ← = ( h 1 ← , h 2 ← , . . . , h n ← ) \overleftarrow{h}=(\overleftarrow{h_1},\overleftarrow{h_2},...,\overleftarrow{h_n}) h =(h1 ,h2 ,...,hn ),其中 h t ← \overleftarrow{h_t} ht 表示在时间步 t t t 时后向LSTM的隐藏状态。则每个位置 i i i 的特征表示为 h i = [ h i → ; h i ← ] h_i=[\overrightarrow{h_i};\overleftarrow{h_i}] hi=[hi ;hi ],其中 [ ⋅ ; ⋅ ] [\cdot;\cdot] [;] 表示向量拼接操作。

CRF用于建模标签之间的关系,并进行全局优化。CRF模型定义了一个由 Y \mathcal{Y} Y 构成的联合分布 P ( y ∣ x ) P(\boldsymbol{y}|\boldsymbol{x}) P(yx),其中 y = ( y 1 , y 2 , . . . , y n ) \boldsymbol{y} = (y_1, y_2, ..., y_n) y=(y1,y2,...,yn) 表示标签序列。具体来说,CRF模型将标签序列的概率分解为多个位置的条件概率的乘积,即

P ( y ∣ x ) = ∏ i = 1 n ψ i ( y i ∣ x ) ∏ i = 1 n − 1 ψ i , i + 1 ( y i , y i + 1 ∣ x ) P(\boldsymbol{y}|\boldsymbol{x})=\prod_{i=1}^{n}\psi_i(y_i|\boldsymbol{x}) \prod_{i=1}^{n-1}\psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) P(yx)=i=1nψi(yix)i=1n1ψi,i+1(yi,yi+1x)

其中 ψ i ( y i ∣ x ) \psi_i(y_i|\boldsymbol{x}) ψi(yix) 表示在位置 i i i 时预测标签为 y i y_i yi 的条件概率, ψ i , i + 1 ( y i , y i + 1 ∣ x ) \psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) ψi,i+1(yi,yi+1x) 表示预测标签为 y i y_i yi y i + 1 y_{i+1} yi+1 的联合概率。这些条件概率和联合概率可以用神经网络来建模,其中输入为位置 i i i 的特征表示 h i h_i hi

CRF模型的全局优化问题可以通过对数似然函数最大化来实现,即

max ⁡ y log ⁡ P ( y ∣ x ) = ∑ i = 1 n log ⁡ ψ i ( y i ∣ x ) ∑ i = 1 n − 1 log ⁡ ψ i , i + 1 ( y i , y i + 1 ∣ x ) \max_{\boldsymbol{y}}\log P(\boldsymbol{y}|\boldsymbol{x}) = \sum_{i=1}^{n}\log\psi_i(y_i|\boldsymbol{x}) \sum_{i=1}^{n-1}\log\psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) ymaxlogP(yx)=i=1nlogψi(yix)i=1n1logψi,i+1(yi,yi+1x)
其中 y \boldsymbol{y} y 是所有可能的标签序列。可以使用动态规划算法(如维特比算法)来求解全局最优标签序列。

综上所述,BiLSTM+CRF模型的数学原理可以表示为:

P ( y ∣ x ) = ∏ i = 1 n ψ i ( y i ∣ x ) ∏ i = 1 n − 1 ψ i , i + 1 ( y i , y i + 1 ∣ x ) P(\boldsymbol{y}|\boldsymbol{x}) = \prod_{i=1}^{n}\psi_i(y_i|\boldsymbol{x}) \prod_{i=1}^{n-1}\psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) P(yx)=i=1nψi(yix)i=1n1ψi,i+1(yi,yi+1x)

其中

ψ i ( y i ∣ x ) = exp ⁡ ( W o T h i + b o T y i ) ∑ y i ′ ∈ L exp ⁡ ( W o T h i + b o T y i ′ ) \psi_i(y_i|\boldsymbol{x}) = \frac{\exp(\boldsymbol{W}_o^{T}\boldsymbol{h}_i + \boldsymbol{b}_o^{T}\boldsymbol{y}i)}{\sum{y_i'\in\mathcal{L}}\exp(\boldsymbol{W}_o^{T}\boldsymbol{h}_i + \boldsymbol{b}_o^{T}\boldsymbol{y}_i')} ψi(yix)=yiLexp(WoThi+boTyi)exp(WoThi+boTyi)

ψ i , i + 1 ( y i , y i + 1 ∣ x ) = exp ⁡ ( W t T y i , i + 1 ) ∑ y i ′ ∈ L ∑ y i + 1 ′ ∈ L exp ⁡ ( W t T y i ′ , i + 1 ′ ) \psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) = \frac{\exp(\boldsymbol{W}t^{T}\boldsymbol{y}{i,i+1})}{\sum_{y_i'\in\mathcal{L}}\sum_{y_{i+1}'\in\mathcal{L}}\exp(\boldsymbol{W}t^{T}\boldsymbol{y}{i',i+1}')} ψi,i+1(yi,yi+1x)=yiLyi+1Lexp(WtTyi,i+1)exp(WtTyi,i+1)

其中 W o \boldsymbol{W}_o Wo b o \boldsymbol{b}_o bo 是输出层的参数, W t \boldsymbol{W}_t Wt 是转移矩阵, h i \boldsymbol{h}_i hi 是位置 i i i 的特征表示, y i \boldsymbol{y}i yi 是位置 i i i 的标签表示, y i , i + 1 \boldsymbol{y}{i,i+1} yi,i+1 是位置 i i i i + 1 i+1 i+1 的标签联合表示。

在这里插入图片描述

3. 数据准备

下面我将使用一个简单的命名实体识别(NER)任务来演示模型的训练和预测过程。数据集包含了一些句子,每个句子中的单词都被标记为“B-PER”(人名开始)、“I-PER”(人名中间)、“B-LOC”(地名开始)、“I-LOC”(地名中间)或“O”(其他)。

数据样例:

John B-PER
lives O
in O
New B-LOC
York I-LOC
. O

4. 模型搭建

首先,我们需要安装PyTorch库:

pip install torch

接下来,我们将使用PyTorch搭建BiLSTM+CRF模型。完整的模型代码如下:

import torch
import torch.nn as nn
import torch.optim as optim

from TorchCRF import CRF

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, _ = self.lstm(embeds)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def loss(self, sentence, tags):
        feats = self.forward(sentence)
        return -self.crf(torch.unsqueeze(feats, 0), tags)

    def predict(self, sentence):
        feats = self.forward(sentence)
        return self.crf.decode(torch.unsqueeze(feats, 0))

5. 训练与评估

接下来,我们将使用训练数据对模型进行训练,并在每个epoch后打印损失值和准确率。

def train(model, optimizer, data):
    for epoch in range(10):
        total_loss = 0
        total_correct = 0
        total_count = 0
        for sentence, tags in data:
            model.zero_grad()
            loss = model.loss(sentence, tags)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            prediction = model.predict(sentence)
            total_correct += sum([1 for p, t in zip(prediction, tags) if p == t])
            total_count += len(tags)

        print(f"Epoch {epoch + 1}: Loss = {total_loss / len(data)}, Accuracy = {total_correct / total_count}")

6. 预测

最后,我们将使用训练好的模型对新的句子进行预测。

def predict(model, sentence):
    prediction = model.predict(sentence)
    return [p for p in prediction]

7. 总结

用训练好的模型对新的句子进行预测。

def predict(model, sentence):
    prediction = model.predict(sentence)
    return [p for p in prediction]

7. 总结

本文介绍了如何使用PyTorch搭建一个BiLSTM+CRF模型,并通过一个简单的命名实体识别(NER)任务来演示模型的训练和预测过程。希望这篇文章能帮助你理解BiLSTM+CRF模型的原理,并为你的实际项目提供参考作用哦。

更新精彩的模型搭建与应用请持续关注哦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/601775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件安全概述

软件定义是:计算机程序、规则和可能相关的文档。 软件是程序、数据和文档的集合体。 零日漏洞、零日攻击 零日漏洞是指未被公开披露的软件漏洞,没有给软件的作者或厂商以时间去为漏洞打补丁或是给出建议解决方案,从而攻击者能够利用这种漏洞破…

ROS:话题消息(Message)的定义与使用

目录 一、话题模型二、自定义话题消息2.1定义msg文件2.2在package.xml中添加功能包依赖2.3在CMakeLists.txt中添加编译选项2.4编译生成C头文件或Python库 三、创建代码并编译运行(C)3.1创建代码3.2编译 四、运行 一、话题模型 自定义一个消息类型“Pers…

python接口测试之测试报告

在本文章中,主要使用jenkins和编写的自动化测试代码,来生成漂亮的测试报告,关于什么是CI这些我就不详细的介绍了,这里我们主要是实战为主。 首先搭建java的环境,这个这里不做介绍。搭建好java的环境后,在h…

Python:Python编程:从入门到实践__超清版:Python标准库:线程

Python线程与安全 实现线程安全有多重方式,常见的包括:锁,条件变量,原子操作,线程本地存储等。 💚 1. 锁2. 条件变量3. 通过 join 阻塞当前线程4. 采用 sleep 来休眠一段时间5. 原子操作5.1 使用 threading…

【I2C】Linux I2C子系统分析

文章目录 一、I2C体系架构二、主要的结构体1. i2c_adapter2. i2c_algorithm3. i2c_driver4. i2c_client4.1 方式一:通过I2C bus number静态方式来创建4.2 方式二:通过Device Tree来创建4.3 方式三:直接通过i2c_new_device来创建4.3 方式四&am…

openEuler22.03制作openstack平台使用的镜像

系列文章目录 第一章 openEuler22.03制作openstack平台使用的镜像 文章目录 系列文章目录前言一、virt-manager上的准备工作1、网卡类型切换为virtio2、IDE驱动设置成Virtio3、Display设置成vnc3、虚拟机系统分区 二、安装普通工具包三、安装云化工具包1、安装工具包2、修改配…

数字化转型,企业为什么要转型?如何转型?

数字化转型是利用数字化技术(例如云计算、大数据、人工智能、物联网、区块链等)和能力来驱动组织商业模式创新和商业生态系统重构的途径和方法即是数字化转型。其目的是实现企业业务的转型、创新、增长。 核心强调了两点,其一是数字化技术的应…

每日一练 | 华为认证真题练习Day51

1、如下图所示,IPSec传输模式中AH的头部应该插入到以下哪个位置? A. 1 B. 2 C. 3 D. 4 2、以下哪种远程登录方式最安全? A. Telnet B. Stelnet v100 C. Stelnet v2 D. Stelnet v1 3、以下业务模块的ACL默认动作为permit的是&#xff1…

玩转 ChatGPT,看这条就够了,Prompt 最全中文合集

Prompt 最全中文合集 玩转 ChatGPT,看这条就够了! 🚀 简化流程:ChatGPT Shortcut 提供了快捷指令表,可以快速筛选和搜索适用于不同场景的提示词,帮助用户简化使用流程。 💻 提高生产力&#…

CSDN打出各种数学符号和数学公式

目录 1、基本四则运算2、指数对数3、根号、省略号、向量4、大(小)于等于号5、特殊符号、希腊字母符号6、累加累乘7、矩阵8、更改公式中的颜色 我们在用CSDN打出各种数学符号和数学公式时,需要学习一些关于LaTex的语法,在此做一个记…

java数组学习

一、数组的概述 1.数组的理解:数组(Array),是多个相同类型数据按一定顺序排列的集合, 并使用一个名字命名,并通过编号的方式对这些数据进行统一管理。 2.数组相关的概念: >数组名 >元素 >角标、下标、索引 >数组的长度:元素…

联通云数据库CUDB:基于openGauss打造新一代自主创新云原生数据库

总体概述 联通云彰显央企担当,围绕国家对信息技术基础软件的政策要求,开展数据库自主研发。在openGauss开源社区版软件基础上,聚焦政企市场,坚持内核创新,完善工具生态,基于海量云存储能力、存算分离架构…

React中的懒加载以及在Ice中实践

您好,如果喜欢我的文章,可以关注我的公众号「量子前端」,将不定期关注推送前端好文~ 前言 对于页面性能优化,组件懒加载是个比较不错的方案,并且在整个项目打包后,如果未做代码分割,构建出的文…

代理ip的优势、用途及注意事项

随着互联网的高速发展,代理ip的名气和地位也随着水涨船高。那么是什么让它们被我们所知悉的呢?下面我们就代理ip的优势、用途和注意事项来分析一下它为什么能迎合着互联网的发展而壮大自己的。 一、优势 每一个脱颖而出的产品必然有它的优势,…

Axure教程—菜单(中继器)

本文将教大家如何用AXURE中的中继器制作菜单(自动折叠其他菜单) 一、效果介绍 如图: 预览地址:https://iuek50.axshare.com 下载地址:https://download.csdn.net/download/weixin_43516258/87854640?spm1001.2014.30…

知识图谱简介

什么是知识图谱? 参考:知识图谱1、知识图谱2 本质上,知识图谱主要目标是用来描述真实世界中存在的各种实体和概念,以及他们之间的关系,因此可以认为是一种语义网络。 主要作用:通过数据,建立图…

智能自动化助力业务升级:探究低代码开发和业务流程自动化

当我们开始探索业务流程自动化(BPA)时,就证明我们已经真正进入到企业数字化转型的核心领域了——企业越来越关注如何通过创新技术来提高效率、降低成本并实现业务流程的自动化。在这个背景下,低代码开发平台和业务流程自动化成为了…

vue 滚动加载

在 Vue中,如果一个组件是一个 button,那就可以直接调用 input ()方法,将组件的 button放入到v-ui中。 然而在v-ui中,一个组件可能不止一个 button,而这些 button还需要从浏览器加载到 DOM树中。…

一个投喂ChatGPT大内容的小技巧

大家好,我是五竹。心血来潮整理了一份手册:《ChatGPT学习指南》并且将为小白们持续更新和GPT相关的资源和教程,专注于打造一部最好的GPT入门指南,欢迎大家转发、收藏、点赞支持!谨防失联! 至今还有很多人都…

渗透测试适合小白学习吗会让人感觉到无聊吗?

渗透测试是一项复杂的技能,需要具备扎实的计算机知识,对网络和系统安全有深入的理解和认识。对于初学者来说,建议先学习计算机网络、操作系统、编程语言等相关基础知识,了解渗透测试的概念、流程和常用工具。同时,需要…