python-pytorch 下批量seq2seq+Bahdanau Attention实现问答1.0.000

news2025/3/15 14:52:10

python-pytorch 下批量seq2seq+Bahdanau Attention实现简单问答1.0.000

    • 前言
    • 原理看图
    • 数据准备
    • 分词、index2word、word2index、vocab_size
    • 输入模型的数据构造
    • 注意力模型
    • decoder的编写
    • 关于损失函数和优化器
    • 在预测时
    • 完整代码
    • 参考

前言

前面实现了 luong的dot 、general、concat注意力实现简单问答,这里参考官方文档,实现了python-pytorch 下批量seq2seq+Bahdanau Attention实现问答

原理看图

在这里插入图片描述
这里模型选择和官方不一样,官方选择的是GRU,我更喜欢使用LSTM,解码器和编码器都是如此。
意思大致思路是:

  1. 计算encoder的encoder_outputs、encoder_hn、encoder_cn
  2. 使用encoder_outputs、encoder_hn计算新的向量和注意力
  3. 在deconder中,以SOS单字开始,循环句子最大长度,在循环中,使用新的向量和单字SOS做cat计算得到decoder的LSTM输入数据,将该LSTM存起来,最后做cat计算得到decoder的输出

数据准备

结果类似还是采用前面的结构和数据

seq_example = [“你认识我吗”, “你住在哪里”, “你知道我的名字吗”, “你是谁”, “你会唱歌吗”, “谁是张学友”]
seq_answer = [“当然认识”, “我住在成都”, “我不知道”, “我是机器人”, “我不会”, “她旁边那个就是”]

分词、index2word、word2index、vocab_size

分词然后做基础准备,包括数据:index2word、word2index、vocab_size、最长的句子长度seq_length,和一些超参数的设置

输入模型的数据构造

  1. 长度要统一
  2. 问答的句子以EOS结尾,不足补0,如

tensor([[ 3, 4, 5, 6, 2, 0, 0],
[ 3, 7, 8, 9, 2, 0, 0],
[ 3, 10, 5, 11, 12, 6, 2],
[ 3, 13, 14, 2, 0, 0, 0],
[ 3, 15, 16, 6, 2, 0, 0],
[14, 13, 17, 2, 0, 0, 0]])

注意力模型

可以复用,用官方的即可

# Bahdanau
# query=hidden [layer_num,batch_size,hidden_size] keys=encoder_outputs  [seq_len,batch_size,hidden_size]
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys))) #[seq_len,batch_size,1]
        scores = scores.permute(1,0,2).squeeze(2).unsqueeze(1)#[batch_size,1,seq_len]

        weights = nn.functional.softmax(scores, dim=-1)#[batch_size,1,seq_len]
        context = torch.bmm(weights, keys.permute(1,0,2))#[batch_size,1,hidden_size]

        return context, weights

decoder的编写

思路是,获得encoder的输出和hn后,计算得到向量,然后使用向量和目标的每一个字做cat计算,输入decoder的模型中,然后得出一个字的预测,循环完了以后,就会得到最大句子长度,最后做cat和softmax计算得到输出。另外,这里要区分训练和测试,训练的时候有target,测试的没有target数据。

关于损失函数和优化器

NLLLoss+Adam的组合优于CrossEntropyLoss+SGD的组合

在预测时

获取到模型输出,size是[batch_size,seq_len,vocab_size]后,对结果做topk计算,会得到每一字在vocab_size的概率,连接起来就是一句话

完整代码

# def getAQ():
#     ask=[]
#     answer=[]
#     with open("./data/flink.txt","r",encoding="utf-8") as f:
#         lines=f.readlines()
#         for line in lines:
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n",""))
#     return answer,ask

# seq_answer,seq_example=getAQ()



import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdm
 
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "谁是张学友"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "她旁

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

某神,云手机启动?

某神自从上线之后,热度不减,以其丰富的内容和独特的魅力吸引着众多玩家; 但是随着剧情无法跳过,长草期过长等原因,近年脱坑的玩家多之又多,之前米家推出了一款云某神的app,目标是为了减少用户手…

Android9.0 MTK平台如何增加一个系统应用

在安卓定制化开发过程中,难免遇到要把自己的app预置到系统中,作为系统应用使用,其实方法有很多,过程很简单,今天分享一下我是怎么做的,共总分两步: 第一步:要找到当前系统应用apk存…

PostgreSQL基本使用Schema

参考文章:PostgreSQL基本使用(3)Schema_pg数据库查询schema-CSDN博客 PostgreSQL 模式(Schema)可以理解为是一个表的集合(或者所属者)。 例如:在 MySQL 中,Scheam 是库&…

储能服务系统架构:实现能源可持续利用的科技之路

随着可再生能源的快速发展和能源系统的智能化需求增加,储能技术作为能源转型和可持续发展的关键支撑之一,备受各界关注。储能服务系统架构的设计和实现将对能源行业产生深远影响。本文将探讨储能服务系统架构的重要性和关键组成部分,旨在为相…

安卓开发--安卓使用Echatrs绘制折线图

安卓开发--安卓使用Echatrs绘制折线图 前期资料安卓使用Echarts绘制折线图1.1 下载 Echarts 安卓资源1.2 新建assets文件1.3 新建布局文件1.4 在布局文件中布局WebView1.5 在活动文件中调用 最终效果 前期资料 Echarts 官网样式预览: https://echarts.apache.org/examples/zh/…

使用Webcam实现摄像头的开启和关闭,并保存和复制图片

实现思路 0,将webcam的jar文件传入项目中 1,显示摄像头的地方:创建一个画板,在画板上添加开启和关闭按钮 2,设置开启和关闭功能:创建一个类实现动作监听器,进而实现监听动作按钮 3&#xff…

《我的阿勒泰》读后感

暂没时间写,记录在此,防止忘记,后面补上!!! 【经典语录】 01、如果天气好的话,阳光广阔地照耀着世界,暖洋洋又懒洋洋。这样的阳光下,似乎脚下的每一株草都和我一样,也把身子完全舒展开了。 02、…

Jmeter预习第1天

Jmeter参数化(重点) 本质:使用参数的方式来替代脚本中的固定为测试数据 实现方式: 定义变量(最基础) 文件定义的方式(所有测试数据都是固定的情况下[死数据],eg:注册登录&#xff0…

为了“降本增效”,我用AI 5天将SpringBoot迁移到了Nodejs

背景 大环境不好,各行各业都在流行“降本增效”,IT行业大肆执行“开猿节流”,一顿操作效果如何?普通搬砖人谁会在乎呢。 为了收紧我的口袋,决定从头学习NodejsTypeScript,来重写我的Java后端服务。 其实这…

【ECharts】数据可视化

目录 ECharts介绍ECharts 特点Vue2使用EChats步骤安装 ECharts引入 ECharts创建图表容器初始化图表更新图表 示例基本柱状图后台代码vue2代码配置 组件代码运行效果 基本折线图示例代码组件 基础饼图示例代码后台前端配置组件运行效果 其他 ECharts介绍 ECharts 是一个由百度开…

找不到msvcr110.dll无法继续执行代码的原因分析及解决方法

在计算机使用过程中,我们经常会遇到一些错误提示,其中之一就是找不到msvcr110.dll文件。这个错误通常发生在运行某些程序或游戏时,系统无法找到所需的动态链接库文件。为了解决这个问题,下面我将介绍5种常见的解决方法。 一&#…

重学java 44.多线程 Lock锁的使用

昨日之深渊,今日之浅谈 —— 24.5.25 一、Lock对象的介绍和基本使用 1.概述 Lock是一个接口 2.实现类 ReentrantLock 3.方法 lock()获取锁 unlock()释放锁 4.Lock锁的使用 package S78Lock;import java.util.concurrent.locks.Lock; import java.util.concurrent.lo…

类与对象:接口

一.概念 接口(英文:Interface),在JAVA编程语言中是一个抽象类型,是抽象方法的集合,接口通常以interface来声明。 二.语法规则 与定义类相似,使用interface关键词。 Idea可以在开始时直接创建…

【控制实践——二轮平衡车】【三】基于PID的直立控制

传送门 系列博客前言直立运动分析基于PID控制器的直立控制角度环控制角速度控制总结 电机转速的控制前言电机转速控制 结语 系列博客 【控制实践——二轮平衡车】【一】运动分析及动力学建模 【控制实践——二轮平衡车】【二】实物设计和开源结构&代码 【控制实践——二轮…

常见 JVM 面试题补充

原文地址 : 26 福利:常见 JVM 面试题补充 (lianglianglee.com) CMS 是老年代垃圾回收器? 初步印象是,但实际上不是。根据 CMS 的各个收集过程,它其实是一个涉及年轻代和老年代的综合性垃圾回收器。在很多文章和书籍的划分中&…

Scrapy顺序执行多个爬虫

Scrapy顺序执行多个爬虫 有两种方式: 第一种:bat方式运行 新建bat文件 cd C:\python_web\spiders\tiktokSelenium & C: & scrapy crawl spider1 & scrapy crawl spider2 & scrapy crawl spider3 & scrapy crawl spider4 第二种&a…

kubeadm部署k8s v1.28

一、主机准备 主机硬件配置说明 作用IP地址操作系统配置k8s-master01192.168.136.55openEuler-22.03-LTS-SP12颗CPU 4G内存 50G硬盘k8s-node01192.168.136.56openEuler-22.03-LTS-SP12颗CPU 4G内存 50G硬盘k8s-node02192.168.136.57openEuler-22.03-LTS-SP12颗CPU 4G内存 50G…

windows 11上自带时间管理-番茄工作法

在 Windows 11 中,你可以使用 专注 功能来最大程度地减少干扰,帮助你保持专注。 专注的工作原理 专注时段打开后,将会出现以下情况: 专注计时器将显示在屏幕上 请勿打扰将打开 任务栏中的应用不会闪烁发出提醒 任务栏中应用的…

MySQL 8.4.0 LTS 变更解析:I_S 表、权限、关键字和客户端

↑ 关注“少安事务所”公众号,欢迎⭐收藏,不错过精彩内容~ MySQL 8.4.0 LTS 已经发布 ,作为发版模型变更后的第一个长期支持版本,注定要承担未来生产环境的重任,那么这个版本都有哪些新特性、变更,接下来少…

UIKit之猜图器Demo

需求 实现猜图器Demo 功能分解: 1>下一题切换功能 2>点击图片后能放大并背景变暗(本质是透明度的变化)。再次点击则缩小,删除暗色背景。 3> 答案区按钮点击则文字消失,选择区对应文字恢复。 4> 选择区…