【NLP相关】attention的代码实现

news2024/9/23 3:29:50

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

attention机制

(封面图由ERNIE-ViLG AI 作画大模型生成)

【NLP相关】attention的代码实现

Attention模型是现今机器学习领域中非常热门的模型之一,它可以用于自然语言处理、计算机视觉、语音识别等领域。本文将介绍Attention模型的代码实现。

1. attention机制的原理

首先,我们需要了解Attention模型的基本概念。Attention是一种机制,它可以用于选择和加权输入序列的不同部分,从而使得模型更加关注那些对输出结果更加重要的部分。在自然语言处理任务中,输入序列通常是由一些词语组成的,而输出序列通常是一个标签或者一句话。Attention模型可以帮助我们更好地理解输入序列中的每一个词语对输出序列的影响。关于attention的详细介绍,可以参见我的另一篇博客:深入理解attention机制(产生、发展、原理、应用和代码实现)

2. attention机制的代码实现

(1)基于PyTorch实现

接下来,我们将介绍如何使用PyTorch实现一个基本的Attention模型。我们假设输入序列是一个由 n n n个词语组成的序列,输出序列是一个由 m m m个标签组成的序列。首先,我们需要定义一个包含两个线性变换的网络层,分别用于将输入序列和输出序列的维度映射到一个相同的维度空间。代码如下所示:

class AttentionLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super(AttentionLayer, self).__init__()
        self.input_proj = nn.Linear(input_size, output_size, bias=False)
        self.output_proj = nn.Linear(output_size, output_size, bias=False)

在定义完网络层之后,我们需要实现Attention的计算过程。在本文中,我们将使用加性Attention的计算方式。具体来说,我们需要计算每一个输入词语与输出标签之间的相似度,然后将相似度进行归一化处理,最终得到一个由 n n n个归一化的权重组成的向量。代码如下所示:

    def forward(self, inputs, outputs):
        inputs = self.input_proj(inputs) # (batch_size, n, input_size) -> (batch_size, n, output_size)
        outputs = self.output_proj(outputs) # (batch_size, m, output_size) -> (batch_size, m, output_size)
        scores = torch.bmm(inputs, outputs.transpose(1, 2)) # (batch_size, n, output_size) * (batch_size, output_size, m) -> (batch_size, n, m)
        weights = F.softmax(scores, dim=1) # (batch_size, n, m)
        return weights

在代码中,我们首先将输入序列和输出序列分别进行线性变换,并计算它们之间的相似度。然后,我们使用softmax函数将相似度进行归一化处理,从而得到一个 n × m n \times m n×m的归一化权重矩阵。

最后,我们可以将Attention计算的结果与输入序列相乘,得到一个由 m m m个加权输入向量组成的向量。代码如下所示:

class AttentionLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super(AttentionLayer, self).__init__()
        self.input_proj = nn.Linear(input_size, output_size, bias=False)
        self.output_proj = nn.Linear(output_size, output_size, bias=False)
    
    def forward(self, inputs, outputs):
        inputs = self.input_proj(inputs) # (batch_size, n, input_size) -> (batch_size, n, output_size)
        outputs = self.output_proj(outputs) # (batch_size, m, output_size) -> (batch_size, m, output_size)
        scores = torch.bmm(inputs, outputs.transpose(1, 2)) # (batch_size, n, output_size) * (batch_size, output_size, m) -> (batch_size, n, m)
        weights = F.softmax(scores, dim=1) # (batch_size, n, m)
        context = torch.bmm(weights.transpose(1, 2), inputs) # (batch_size, m, n) * (batch_size, n, output_size) -> (batch_size, m, output_size)
        return context

在代码中,我们将归一化权重矩阵和输入序列进行矩阵乘法运算,得到一个由 m m m个加权输入向量组成的向量。这个向量就是Attention模型的输出结果。

至此,我们已经完成了Attention模型的代码实现。当然,这只是一个基本的Attention模型,它还可以通过增加更多的层来提升性能,比如Multi-Head Attention等。同时,在使用Attention模型时还需要考虑到一些细节问题,比如输入序列的长度不一定相同、输出序列的长度也不一定相同等。因此,Attention模型的具体实现方式还需要根据具体的任务来进行设计和调整。

(2)TensorFlow实现

在TensorFlow中,我们可以使用tf.keras.layers.Attention层来实现Attention机制。下面,我们将使用一个示例来演示如何在TensorFlow中使用Attention机制。

首先,我们需要导入必要的库和数据集。在这个示例中,我们将使用IMDB电影评论情感分类数据集,这是一个二元分类任务,我们需要将评论分为积极或消极两种情感。

import tensorflow as tf
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, LSTM, Bidirectional, Attention
from tensorflow.keras.models import Model
import numpy as np

# 加载IMDB数据集
max_features = 20000
maxlen = 200
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = pad_sequences(x_train, padding='post', maxlen=maxlen)
x_test = pad_sequences(x_test, padding='post', maxlen=maxlen)

接下来,我们将使用Keras函数式API构建一个双向LSTM模型,并在其中加入Attention层。

# 构建模型
input_layer = Input(shape=(maxlen,))
embedding_layer = Embedding(max_features, 128)(input_layer)
lstm_layer = Bidirectional(LSTM(64, return_sequences=True))(embedding_layer)
attention_layer = Attention()([lstm_layer, lstm_layer])
flatten_layer = Flatten()(attention_layer)
output_layer = Dense(1, activation='sigmoid')(flatten_layer)
model = Model(inputs=input_layer, outputs=output_layer)

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

在这个模型中,我们首先使用Embedding层将输入序列转换为向量表示,然后将其输入到一个双向LSTM层中。接下来,我们使用Attention层将LSTM层的输出与自身进行注意力计算,得到每个时间步的权重。最后,我们将加权后的输出进行展平,并通过一个全连接层得到二元分类的输出。

最后,我们可以训练和评估这个模型。

# 训练模型
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_data=(x_test, y_test))

# 评估模型
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print('Test accuracy:', accuracy)

3. attention模型的应用

(1)机器翻译

机器翻译是自然语言处理领域的一个重要任务,而Attention模型在机器翻译中的应用尤为广泛。在传统的机器翻译模型中,通常使用固定长度的向量来表示输入序列,并将其输入到一个循环神经网络(RNN)中进行处理。但是,这种方法存在一个问题,就是当输入序列很长时,模型会出现信息丢失的情况,无法捕捉到关键的上下文信息。为了解决这个问题,Attention模型被引入到了机器翻译中。在Attention模型中,我们不仅考虑到输入序列中每个词的信息,还将每个词的权重也作为输入,使得模型可以更加关注到重要的词汇信息。通过这种方式,Attention模型可以更加准确地进行翻译,并且在处理长文本时也可以避免信息丢失的问题。

(2)文本分类

在文本分类中,Attention模型可以帮助我们更好地捕捉到文本中的关键信息。传统的文本分类模型通常使用固定长度的向量来表示输入文本,并将其输入到一个全连接层中进行分类。但是,这种方法也存在信息丢失的问题,无法捕捉到文本中的重要信息。为了解决这个问题,Attention模型被引入到了文本分类中。在Attention模型中,我们将每个词的向量表示作为输入,并使用注意力机制来确定每个词的重要程度。通过这种方式,Attention模型可以更加准确地分类文本,并且在处理长文本时也可以避免信息丢失的问题。

(3)图像标注

在图像标注中,Attention模型可以帮助我们更好地理解图像中的内容,并生成更加准确的图像描述。传统的图像标注模型通常使用固定长度的向量来表示图像,并将其输入到一个循环神经网络中进行处理。但是,这种方法也存在信息丢失的问题,无法捕捉到图像中的重要信息。为了解决这个问题,Attention模型被引入到了图像标注中。在Attention模型中,我们将每个图像区域的向量表示作为输入,并使用注意力机制来确定每个区域的重要程度。通过这种方式,Attention模型可以更加准确地理解图像中的内容,并生成更加准确的图像描述。此外,Attention模型还可以帮助我们在图像标注中实现多模态输入,即将图像和文本结合起来进行标注,从而提高标注的准确性。

(4)文本生成

在文本生成任务中,Attention模型可以帮助我们更好地生成连贯、准确的文本。传统的文本生成模型通常使用循环神经网络来生成文本,但是在生成过程中,模型可能会出现重复、模糊等问题。为了解决这个问题,Attention模型被引入到了文本生成中。在Attention模型中,我们不仅使用循环神经网络来生成文本,还将每个词的向量表示作为输入,并使用注意力机制来确定每个词的生成概率。通过这种方式,Attention模型可以更加准确地生成文本,并且避免出现重复、模糊等问题。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

英飞凌Tricore实战系列导读

本文框架 1.系列概述1.1 外设理论及应用介绍1.2 基于TC3xx的MCAL各外设配置开发1.3 基于TC3xx的Davinci工程开发1.4 项目中问题排查经验分享1.5 其他相关话题分享2. 目前已发布系列文章汇总1.系列概述 英飞凌TC3xx以其强大的性能,扩展性,存储及安全性能在汽车电子中扮演着越…

0.基于知识图谱的知识建模详细方案

0.前言 知识图谱以结构化的形式描述客观世界中概念、实体及其关系。实体是客观世界中的事物,概念是对具有相同属性的事物的概括和抽象。本体是知识图谱的知识表示基础,可以形式化表示为,O={C,H, P,A,I},C 为概念集合,如事物性概念和事件类概念,H是概念的上下位关系集…

基于高灵敏度可编程线性霍尔传感器的液压缸位移检测装置(位移传感器)

目前国内工业自动化水平不断提高,工程机械、机床设备等面向自动化的同时,对控制系统的精度要求也越来越严格。在所有使用到液压的机械和设备中,液压缸是必不可少的部件,并且承载着重要的液压转换功能。而液压缸的位移检测在精确控…

Hexo+live2d | 如何把live2d老婆放进自己的博客

参考:Hexo添加Live2D看板娘最新教程live2d-widgetlive2d-widget-models网页/博客Hexo添加live2d游戏角色看板娘,简易添加,碧蓝航线等live2d新型游戏角色模型(moc3)live2d-moc3jsdelivr方法1可以直接去看参考文章的第一部分的第一篇…

【LeetCode】剑指 Offer 22. 链表中倒数第k个节点 p136 -- Java Version

题目链接:https://leetcode.cn/problems/lian-biao-zhong-dao-shu-di-kge-jie-dian-lcof/ 1. 题目介绍(22. 链表中倒数第k个节点) 输入一个链表,输出该链表中倒数第k个节点。为了符合大多数人的习惯,本题从1开始计数&…

C语言-基础了解-21-C输入/输出

C输入/输出 一、标准文件 C 语言把所有的设备都当作文件。所以设备(比如显示器)被处理的方式与文件相同。以下三个文件会在程序执行时自动打开,以便访问键盘和屏幕 文件指针是访问文件的方式,本节将讲解如何从键盘上读取值以及…

葫芦娃自动化平台-appium实战

一、配置jar包 右键用压缩工具打开jar 修改application.yml文件 name: test666 改为自己名字 libPath、drivenPath:改成自己电脑已有路径 修改appium.properties 文件,方式和上面一样 #APP名称 注意带apk后续 例:test.apk 此文件请放置在项目根目录下 …

Property ‘sqlSessionFactory‘ or ‘sqlSessionTemplate‘ are required 问题处理

错误日志: 解决办法: 1.查看配置文件是否正确 # MyBatis-plus配置 mybatis-plus:# 配置Mapper.xml映射文件mapper-locations: classpath*:/mapper/*Mapper.xml# 配置MyBatis数据返回类型别名(默认别名是类名)type-aliases-packa…

POSTGRESQL SQL 语句案例,一场由LIMIT 1 引发的“奇怪异像”

开头还是介绍一下群,如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题,有需求都可以加群群内有各大数据库行业大咖,CTO,可以解决你的问题。加群请联系 liuaustin3 ,在新加的朋友会分到2群。最近一段工作…

JS 数组拍平(多种方法)

前言数组拍平(扁平化数组)是面试中常考查的知识点,之前发过一篇基础版的数组拍平 传送门通过之后的学习,又了解了一些其他的数组拍平的操作,详情请看下面的文章flatflat 方法可以直接用于数组拍平,不传入参…

Elasticsearch集群搭建、优化及实践

文章目录一、Elasticsearch集群1、Elasticsearch集群概念2、Elasticsearch集群安装3、安装Kibana4、测试集群状态二、Elasticsearch优化1、磁盘选择2、分片策略3、内存设置三、Elasticsearch实践1、ES自动补全2、创建索引3、安装logstash导入数据4、项目搭建5、自动补全功能6、…

nginx七大核心应用场景详解 解决生产中的实际问题 二次开发扩展

nginx七大核心应用场景详解 & 解决生产中的实际问题1、nginx的安装与简单配置1.1、环境准备1.2、nginx基本操作指令:1.3、安装成系统服务1.4、conf 配置文件说明2、虚拟主机2.1、nginx多主机配置2.2、二级域名与短网址解析3、基于反向代理的负载均衡3.1、跳转到…

【Mysql索引】五种索引类型

目录前期准备实操实战主键索引 primary唯一索引 unique执行效果普通索引 noraml执行效果全文索引 fulltext执行效果组合索引执行效果前期准备 创建一个表,如果有测试的表也可以用(把主键id设置为自增)如果给字段添加的值有中文,需…

vercel和netlify部署代码并解决接口代理转发的问题(和Nginx功能一样)

前言 部署过程就不说了,部署完成后是这样子的 然后访问链接,无法访问 解决 依次点击 Settings–>Domains,在输入框中输入你的域名并点击 Add 按钮。 以此域名为例子demo.gshopfront.dreamlove.top为例,点击添加 我们前往域名管理系统,记录下绿色的值以腾讯云的…

Android Surface 分析+源码

Surface的创建涉及三个部分: App 进程 App需要将自己的内容显示在屏幕上,所以App负责发起Surface创建请求,创建好Surface后, 就可以直接可以在canvas上画图等,最终都会保存到Surface里的buffer里,最后由SurfaceFlinge…

使用docker搭建WordPress博客

文章目录一、前置条件二、拉取镜像三、安装工具及依赖1.安装2.升级pip3.获取docker-compose方法一另一种方法(推荐)四、创建容器,启用镜像,映射端口五、访问并设置一、前置条件 主机:20.0.0.142 安装docker docker的安…

SQL-刷题技巧-删除重复记录

一. 原题呈现 牛客 SQL236. 删除emp_no重复的记录,只保留最小的id对应的记录。 描述: 删除emp_no重复的记录,只保留最小的id对应的记录。 drop table if exists titles_test; CREATE TABLE titles_test (id int(11) not null primary key…

Matlab论文插图绘制模板第80期—羽状图(Feather)

在之前的文章中,分享了很多Matlab线图的绘制模板: 进一步,再来分享一种特殊的线图:羽状图。 先来看一下成品效果: 特别提示:Matlab论文插图绘制模板系列,旨在降低大家使用Matlab进行科研绘图的…

2.2 分治法的基本思想

分治法的基本思想是将一个规模为n的问题分解为化个规模较小的子问题&#xff0c;这些子问题互相独立且与原问题相同。递归地解这些子问题&#xff0c;然后将各子问题的解合并得到原问题的解。它的一般的算法设计模式如下&#xff1a;divide-and-conquer(P) { if ( P < no) a…

SRS4.0 源码分析- RTC模块相关类

前言 本文介绍SRS4.0涉及RTC模块的C类&#xff0c;主要包括RTC Server和Session相关的。 SrsGoApiRtcPlay 处理webrtc client的播放请求&#xff0c;解析client的offer&#xff0c;并且生成server的answer&#xff0c;并且为这次请求创建一个session。SrsRtcServer 监听udp端…