Transformer应用之构建聊天机器人(二)

news2024/11/26 15:50:07

四、模型训练解析

在PyTorch提供的“Chatbot Tutorial”中,关于训练提到了2个小技巧:

  • 使用”teacher forcing”模式,通过设置参数“teacher_forcing_ratio”来决定是否需要使用当前标签词汇来作为decoder的下一个输入,而不是把decoder当前预测出来的词汇当做decoder的下一个输入,这是因为存在这样的情况,如果当前预测出来的词汇跟输入词汇从语义上来讲没有多大关联时,如果继续使用预测出来的词汇来训练模型,有可能就会造成比较大的预测偏差,从而导致模型训练后的预测效果很差,如果改为直接使用输入词汇对应的目标词汇(标签)来作为decoder的下一个输入,相当于进行强制纠偏,使decoder训练时输出与输入之间不至于出现偏差很大的情况。
  • 第2个小技巧是使用梯度裁剪(Gradient Clipping),这是一种常用的防止梯度爆炸的技术。在深度学习训练过程中,因为网络层数较多,梯度可能会非常大,导致模型无法收敛。梯度裁剪的目的就是限制梯度的大小,使其不超过一个预设的阈值,从而避免梯度爆炸的问题。

训练过程如下:

  1. 输入语句正向传播通过encoder
  2. 使用SOS token作为decoder的初始输入,使用encoder的final hidden state来初始化decoder的hidden state
  3. Decoder端根据输入单步执行产生输出
  4. 如果执行”teacher forcing”模式,则把当前对应的目标词汇(标签)作为decoder的下一个输入,否则使用当前decoder的输出词汇作为decoder的下一个输入
  5. 计算并累加损失
  6. 执行反向传播
  7. 执行梯度裁剪
  8. 更新decoder和encoder的模型参数

以下是代码示例:

以下是Transformer模型训练代码示例,

  • 首先把输入sequence(对话输入),输出sequence(对话输出),以及各自的mask传入模型做正向传播
  • 计算预测结果与标签的损失,然后反向传播更新模型参数
  • 训练时可以使用验证集(dev dataset)对训练效果进行评估

五、模型预测(推理)过程解析

下面这个图描述了Transformer的预测推理过程:

  • 假设使用两个encoder和两个decoder来构成这个Transformer模型,首先把输入语句转为embedding词向量,并加入位置编码信息
  • 正向传播通过encoder1,它的输出再通过encoder2,期间会使用多头注意力机制对输入序列中的每个词向量并行地进行注意力Q,K,V的计算
  • Decoder1使用<START> token进行初始化,并使用带掩码多头注意力机制进行计算,并且需要根据前面encoder2的输出进行注意力的计算,然后输出预测得到的词汇
  • Decoder1输出的词汇作为decoder2的输入,同样decoder2在进行多头注意力计算时也需要使用encoder2的注意力计算输出结果
  • Decoder2的输出传入线性层,之后使用Softmax函数转为0到1之间的概率,然后可以使用greedy search(贪心解码)算法得到概率最高的词汇作为预测结果

下面是预测相关代码的示例:

再来看下PyTorch提供的聊天机器人样例的预测操作:

  • 用户输入正向传播通过encoder模型
  • 把encoder的final hidden layer作为decoder模型的first hidden input
  • 使用SOS_token作为decoder的第一个输入来初始化模型
  • decoder根据encoder的输出(上篇文章提到的“Luong attention”注意力机制计算),以及当前decoder的输入,hidden state来输出预测得到的词汇(迭代操作)
  • 使用Softmax计算概率并根据概率获取最有可能出现的词汇
  • 把当前预测得到的词汇作为下一个decoder的输入
  • 收集所有预测得到的词汇

以下是预测相关代码的示例:

六、聊天机器人对话效果解析

基于Transformer的聊天机器人和PyTorch提供的聊天机器人都使用同样的训练语料(“Cornell Movie-Dialogs Corpus.”)进行训练,基于Transformer的聊天机器人模型训练了20个epochs,输入语句最大长度设置为60,PyTorch提供的聊天机器人训练配置如下:

clip = 50.0

teacher_forcing_ratio = 1.0

learning_rate = 0.0001

decoder_learning_ratio = 5.0

n_iteration = 4000

print_every = 1

save_every = 500

使用同样的测试对话语料分别对两个模型进行测试,基于Transformer模型的对话测试结果如下:

PyTorch提供的聊天机器人对话测试结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/573673.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux:查看进程。

Linux&#xff1a;查看进程。 windows linux TTY如果是&#xff1f;说明是不是终端(控制台)启动的&#xff0c;而是系统内部自己启动的。 TIME是启动Linux后&#xff0c;这个进程一共占用了cpu多少时间00…

《Spring Guides系列学习》guide46 - guide50

要想全面快速学习Spring的内容&#xff0c;最好的方法肯定是先去Spring官网去查阅文档&#xff0c;在Spring官网中找到了适合新手了解的官网Guides&#xff0c;一共68篇&#xff0c;打算全部过一遍&#xff0c;能尽量全面的了解Spring框架的每个特性和功能。 接着上篇看过的gui…

《Python安全攻防:渗透测试实战指南》极致经典,学完即可包吃包住

前言 网络江湖&#xff0c;风起云涌&#xff0c;攻防博弈&#xff0c;从未间断&#xff0c;且愈演愈烈。从架构安全到被动纵深防御&#xff0c;再到主动防御、安全智能&#xff0c;直至进攻反制&#xff0c;皆直指安全的本质——攻防。未知攻&#xff0c;焉知防! 每一位网络安…

【Python】循环语句 ② ( while 嵌套循环 | 代码示例 - while 嵌套循环 )

文章目录 一、while 嵌套循环1、while 嵌套循环语法2、代码示例 - while 嵌套循环 一、while 嵌套循环 1、while 嵌套循环语法 while 嵌套循环 就是 在 外层循环 中 , 嵌套 内层循环 ; while 嵌套循环 语法格式 : while 外层循环条件:外层循环操作1外层循环操作2while 内存循…

VuePress + GitHub Actions 自动部署

文章目录 前言背景GitHub Actions简介基本概念引用 Actionworkflow 文件 自动部署创建 Action权限问题 小结参考文献 前言 我的第二本开源书籍《后台开发命令 365》上线啦。 为了方便阅读&#xff0c;使用 VuePress 将之前记录的后台常用 Linux 命令博文整理成一个系统的开源…

路径规划算法:基于阴阳对优化的路径规划算法- 附代码

路径规划算法&#xff1a;基于阴阳对优化的路径规划算法- 附代码 文章目录 路径规划算法&#xff1a;基于阴阳对优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要&#xff1a;本文主要介绍利用智能优化算法…

Compose 没有 inputType 怎么过滤(限制)输入内容?这题我会!

前言 闲话 在我之前的文章 《Compose For Desktop 实践&#xff1a;使用 Compose-jb 做一个时间水印助手》 中&#xff0c;我埋了一个坑&#xff0c;关于在 Compose 中如何过滤 TextField 的输入内容。时隔好几个月了&#xff0c;今天这篇文章就是来填这个坑的。 为什么需要…

Doris

Aggregate 模型 是相同key的数据进行自动聚合的表模型。表中的列按照是否设置了 AggregationType&#xff0c;分为 Key&#xff08;维度列&#xff09;和 Value&#xff08;指标列&#xff09;&#xff0c;没有设置 AggregationType 的称为 Key&#xff0c;设置了 Aggregation…

散列表(哈希表)

目录 散列表 散列函数 散列表常用函数 1. 直接定址法 2. 除留余数法 2.1. exmple 3. 数字分析法 4. 平方取中法 5. 折叠法 处理冲突的方法 1. 开放定址法---线性探测 2. 二次探测法 3. 再Hash法 4. 拉链法(链地址法) 散列表&#xff08;Hash table&#xff0c;也…

Redis缓存击穿及解决问题

缓存击穿的意思是对于设置了过期时间的key,缓存在某个时间点过期的时候&#xff0c;恰好这时间点对这个 Key有大量的并发请求过来&#xff0c;这些请求发现缓存过期- -般都会从后端DB加载数据并回设到缓存&#xff0c;这个时候大并发的请求可能会瞬间把DB压垮。 解决方案有两种…

第五十四天学习记录:C语言进阶:动态内存管理Ⅱ

常见的动态内存错误 1、对NULL指针的解引用操作 int* p(int*)malloc(4); //p进行相关的判断 *p10;//malloc开辟空间失败&#xff0c;有可能对NULL指针解引用 free(p); pNULL;2、对动态开辟的内存的越界访问 int* p(int*)malloc(40);//10个int if(p!NULL) {int i0;//越界for(…

微服务项目租房网

文章目录 一、租房网项目的介绍1、使用的技术介绍2、使用的组件和开发工具的版本以及作用3、项目模块结构4、项目总体架构 二、环境搭建1、启动前端服务2、CentOS7各个组件的安装2.1 安装Docker2.2 安装JDK2.3 安装Redis(6390)2.4 安装FastDFS(8888)2.5 安装MongoDB(27017)2.6 …

Niagara—— 概述

目录 一&#xff0c;核心组件 Systems Emitters Modules Parameters 二&#xff0c;创建系统或发射器向导 System向导 Emetter向导 三&#xff0c;Niagara VFX工作流程 创建系统 创建或添加发射器 创建或添加模块 Niagara是最新一代VFX系统&#xff0c;无需程序员…

Junit测试框架详解

目录 Junit框架 导入Junit到项目 Junit注解 Test Disabled BeforeAll / AfterAll BeforeEach / AfterEach 参数化 单参数 多参数 CSV获取参数 方法获取参数 断言 assertEquals / assertNotEquals assertNull / assertNotNull 用例执行顺序 测试套件Suite 指定…

使用IIS创建WEB服务

文章目录 前言一、Web服务是什么&#xff1f;1.Web服务概述2.如何获取网页资源3.常见Web服务端软件4.什么是IIS 二、安装IIS1.安装Web服务器角色2.准备网页文件3.配置Web站点4.客户端浏览例&#xff1a;配置IIS站点 三、虚拟主机概述1.虚拟Web主机2.虚拟主机的几种类型3.基于端…

软考信管高级——进度管理

进度管理内容 缩短活动工期方法 赶工&#xff0c;投入更多资源或增加工作时间&#xff0c;以缩短关键活动的工期快速跟进&#xff0c;并行施工&#xff0c;以缩短关键路径长度使用高素质的资源或经验更丰富的人员减小活动范围或降低活动要求改进方法或技术&#xff0c;以提高…

活动回顾|解锁 AIGC 密码,探寻企业发展新商机

5月24日&#xff0c;Google Cloud 与 Cloud Ace 联合主办的线下活动顺利落下帷幕。 本次活动&#xff0c;有近 40 位企业精英到场支持。三位 Google Cloud 演讲嘉宾就本次活动主题&#xff0c;为大家带来了比较深度的演讲内容&#xff0c;干货满满。 &#xff08;*以下的嘉宾演…

期末复习总结【MySQL】聚合查询 + 多表联合查询(重点)

文章目录 前言一、聚合查询1, 聚合函数2, 聚合函数使用示例3, GROUP BY 子句4, HAVING 子句 二、联合查询(重点)1, 笛卡尔积2, 内连接2.1, 示例12.2, 示例22.3, 示例3 3, 外连接4, 自连接 总结 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: &#…

存量时代下,互联网玩家如何“自我造血”?

毫无疑问&#xff0c;互联网已经进入存量时代。 在过去高增长的增量时代&#xff0c;许多互联网企业追求规模效应&#xff0c;痴迷于“先规模后盈利”的打法&#xff0c;力图用规模构建护城河。然而&#xff0c;随着行业整体增长速度放缓&#xff0c;规模扩张变得更为艰难&…

面了个字节跳动拿 38K 出来的测试,让我见识到了跳槽的天花板

最近内卷严重&#xff0c;各种跳槽裁员&#xff0c;相信很多小伙伴也在准备金九银十的面试计划。 作为一个入职5年的老人家&#xff0c;目前工资比较乐观&#xff0c;但是我还是会选择跳槽&#xff0c;因为感觉在一个舒适圈待久了&#xff0c;人过得太过安逸&#xff0c;晋升涨…