论文速览【序列模型 seq2seq】—— 【Ptr-Net】Pointer Networks

news2024/11/24 20:09:30
  • 标题:Pointer Networks
  • 文章链接:Pointer Networks
  • 参考代码(非官方):keon/pointer-networks
  • 发表:NIPS 2015
  • 领域:序列模型(RNN seq2seq)改进 / 深度学习解决组合优化问题
  • 【本文为速览笔记,仅记录核心思想,具体细节请看原文】

  • 摘要:我们引入了一种新的神经网络结构,用于学习一个输出序列的条件概率,其中输出序列的元素是对应于输入序列位置的离散标记。这类问题不能通过Seq2Seq模型和神经图灵机等现有方法轻松解决,因为(这些问题中)输出每一步的目标类别数量取决于输入的长度,而输入的长度是可变的。排序可变长度序列、各种组合优化问题等属于这类问题。我们的模型使用了最近提出的神经注意力机制,来解决可变长度输出字典的问题。与先前的注意力机制不同的是,我们的模型不是在每个解码器步骤中将编码器的隐藏单元与上下文向量混合,而是使用注意力作为指针来选择输入序列的成员作为输出。我们将这种结构称为指针网络(Ptr-Net)。我们在平面凸包(planar convex hulls)、计算德劳内三角剖分(computing Delaunay triangulations)和旅行商问题(TSP)三个有挑战性的几何问题上验证了指针网络有能力以 Data-driven 的形式学到近似解。指针网络不仅改进了具有输入注意力的序列到序列模型,还能够推广到可变长度输出字典。我们展示了训练的模型在超过其训练最大长度的情况下也能泛化

文章目录

  • 0. 本文考虑的问题
  • 1. 传统方法及其问题
    • 1.1 Sequence-to-Sequence Model
    • 1.2 Content Based Input Attention
    • 1.3 问题
  • 2. 本文方法
  • 3. 实验
  • 4. 总结

0. 本文考虑的问题

  • 本文主要考虑那些 “输出序列是离散的,并对应于输入序列中位置” 的 Seq2Seq 问题。实验的问题包括

    1. 平面凸包问题planar convex hulls:给定平面上若干个点的坐标,输出一组点的索引,使得这些点围成的多边形可以覆盖所有点
    2. 计算德劳内三角剖分computing Delaunay triangulations:给定平面上若干个点的坐标,以点索引形式输出德劳内三角剖分结果(这是一种以最近的三点形成三角形,且各线段皆不相交的三角网剖分方式)
      在这里插入图片描述

    除了以上两个例子外,很多组合优化问题也具有这种形式,作者测试了 Tsp 问题。作者开源了以上三类问题的数据集

  • 注意这类问题的特点是:输出序列的每个元素都是输入序列包含的位置索引,输入序列长度 = 输出索引范围

1. 传统方法及其问题

1.1 Sequence-to-Sequence Model

  • 本文是对 Seq2Seq 模型的一个改进,Seq2Seq 模型用于把一个序列转换成另外一个序列,且不要求输入序列和输出序列等长,典型应用有机器翻译等
    在这里插入图片描述
    如上图所示,Seq2Seq 模型通常使用 RNN 及其变种(LSTM/GRU)以 encoder-decoder 结构构建。RNN 类模型内部有一个隐状态代表目前积累的信息,每读入一个序列样本就将其更新,隐变量值可以较好地捕获前驱序列特征。通常使用两个独立的 RNN 模型,一个作为 encoder 提取输入序列的特征,另一个作为 decoder 以 Autoregress 形式解码得到输出序列
  • 考虑第 0 节的问题,形式化地讲,给定训练样例 ( P , C P ) (\mathcal{P}, \mathcal{C^P}) (P,CP),Seq2Seq 模型使用参数模型计算条件概率
    p ( C P ∣ P ; θ ) = p ( C 2 ∣ C 1 , P ; θ ) ⋅ p ( C 3 ∣ C 2 , C 1 , P ; θ ) ⋅ ⋅ ⋅ p ( C m ( P ) ∣ C m ( P ) − 1 , C m ( P ) − 2 , . . . , C 0 , P ; θ ) = ∏ i = 1 m ( P ) p ( C i ∣ C 1 , … , C i − 1 , P ; θ ) \begin{aligned} p(\mathcal{C}^{\mathcal{P}}|\mathcal{P};\theta) &= p (C_2|C_1,\mathcal{P};\theta)·p (C_3|C_2,C_1,\mathcal{P};\theta) ···p (C_{m(\mathcal{P})}|C_{m(\mathcal{P})-1},C_{m(\mathcal{P})-2},...,C_0,\mathcal{P};\theta) \\ &=\prod_{i=1}^{m(\mathcal{P})}p (C_{i}|C_{1},\ldots,C_{i-1},\mathcal{P};\theta) \end{aligned} p(CPP;θ)=p(C2C1,P;θ)p(C3C2,C1,P;θ)⋅⋅⋅p(Cm(P)Cm(P)1,Cm(P)2,...,C0,P;θ)=i=1m(P)p(CiC1,,Ci1,P;θ) 其中 P = { P 1 , … , P n } \mathcal{P}=\{P_{1},\ldots,P_{n}\} P={P1,,Pn} 是包含 n n n 个向量的输入序列(上图中的 v v v), C P = { C 1 , … , C m ( P ) } \mathcal{C}^{\mathcal{P}}=\{C_{1},\ldots,C_{m(\mathcal{P})}\} CP={C1,,Cm(P)} 是由 m ( P ) m(\mathcal{P}) m(P) 个索引组成的序列,每个索引取值范围为 [ 1 , n ] [1,n] [1,n]。直观地看这个条件概率就是模型解码出目标序列的概率。注意目标序列的长度 m ( P ) m(\mathcal{P}) m(P) 通常取决于 P \mathcal{P} P。学习目标是最大化训练集中所有样本的上述概率之和,即
    θ ∗ = arg max ⁡ θ ∑ P , C P log ⁡ p ( C P ∣ P ; θ ) , \theta^{*}=\operatorname*{arg\,max}_{\theta}\sum_{\mathcal{P},\mathcal{C}^{ \mathcal{P}}}\log p(\mathcal{C}^{\mathcal{P}}|\mathcal{P};\theta), θ=θargmaxP,CPlogp(CPP;θ), 训练之后评估阶段,给定输入序列 P \mathcal{P} P,使用学习到的参数 θ ∗ \theta^{*} θ 选择具有最高概率的序列
    C ^ P = arg max ⁡ C P p ( C P ∣ P ; θ ∗ ) \hat{\mathcal{C}}^{\mathcal{P}}=\operatorname*{arg\,max}_{\mathcal{C}^{\mathcal{P}}}p(\mathcal{C}^{\mathcal{P}}|\mathcal{P};\theta^{*}) C^P=CPargmaxp(CPP;θ) 由于输出序列空间大小为 n m ( P ) n^{m(\mathcal{P})} nm(P),找到真正的最大概率输出序列的计算量太大,工程上通常使用贪心或者 beam search 方法进行解码

1.2 Content Based Input Attention

  • RNN 类模型只能利用隐状态间接地获取之前序列的信息,由于隐藏状态维度一定远远小于之前的变长序列所有样本的连接维度,这种做法无可避免地会损失一些信息。一种补偿方式是引入额外的 Attention 模块,它对整个输入序列的所有 hidden state e 1 , . . . , e n e_1,...,e_n e1,...,en 构造 key 向量,之后在任意第 i i i 个解码位置,用其 hidden state d i d_i di 构造 query 并和整个输入序列计算 attention,根据 attention 结果汇聚(加权平均)整个输入序列的 hidden state,最后用得到的结果和 d i d_i di 做 concatenate 来增强解码时信息输入,缓解信息损耗问题。下图是一个示意
    在这里插入图片描述

  • 形式化地,设 encoder 和 decoder 的隐藏状态为 ( e 1 , … , e n ) (e_{1},\ldots,e_{n}) (e1,,en) ( d 1 , … , d m ( P ) ) (d_{1},\ldots,d_{m(\mathcal{P})}) (d1,,dm(P)),如下计算每个输出时刻 i i i 的附加信息
    u j i = v T tanh ⁡ ( W 1 e j + W 2 d i ) j ∈ ( 1 , … , n ) a j i = softmax ⁡ ( u j i ) j ∈ ( 1 , … , n ) d i ′ = ∑ j = 1 n a j i e j \begin{aligned} u_{j}^{i} & =v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) & j \in(1, \ldots, n) \\ a_{j}^{i} & =\operatorname{softmax}\left(u_{j}^{i}\right) & j \in(1, \ldots, n) \\ d_{i}^{\prime} & =\sum_{j=1}^{n} a_{j}^{i} e_{j} & \end{aligned} ujiajidi=vTtanh(W1ej+W2di)=softmax(uji)=j=1najiejj(1,,n)j(1,,n) 注意这里使用了比较早期的加性注意力,向量 v v v 和矩阵 W 1 , W 2 W_1,W_2 W1,W2 是三组要学习的参数。最后用增强后的 [ d i , d i ′ ] [d_i, d_i'] [di,di] 作为隐状态进行解码。这种方式相对 1.1 节的朴素 Seq2Seq 方法有显著性能提高

1.3 问题

  • 以上两个方法虽然也能部分解决第 0 节的问题,但它们都有一个显著缺陷,即处理问题的尺度无法随着输入泛化:对于每个不同的输入长度 n n n 都要单独训练一个模型。这个问题的本质在于模型无法动态地从输入序列中构造词表,训练时都是事先根据问题规模设置好词表大小的

2. 本文方法

  • 作者解决 1.3 节问题的思路很直接,他注意随着输入序列长度的变化,attention 范围可以自适应地变化,所以解码过程中只要想办法自回归地让 attention 像指针一样指出输入序列中的目标位置即可。这其实是对 1.2 节 decoder 的一种简化,我们不再需要根据 attention 汇聚特征再做分类任务,而是直接用经过 softmax 的 attention 向量做分类,这样输出空间可以根据输入序列长度自动调整。示意图如下
    在这里插入图片描述
    图(a) 是 1.1 节的朴素 Seq2Seq 模型,图(b) 是作者提出的指针网络模型
  • 形式化地,如下用 attention 机制改写 1.1 节中对 p ( C i ∣ C 1 , … , C i − 1 , P ) p(C_{i}|C_{1},\ldots,C_{i-1},\mathcal{P}) p(CiC1,,Ci1,P) 建模的方式
    u j i = v T tanh ⁡ ( W 1 e j + W 2 d i ) j ∈ ( 1 , … , n ) p ( C i ∣ C 1 , … , C i − 1 , P ) = softmax ⁡ ( u i ) \begin{aligned} u_{j}^{i} & =v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \\ p\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) & =\operatorname{softmax}\left(u^{i}\right) \end{aligned} ujip(CiC1,,Ci1,P)=vTtanh(W1ej+W2di)j(1,,n)=softmax(ui) 这里 u j i u_j^i uji 是长度为 n n n 的注意力得分向量, softmax ⁡ \operatorname{softmax} softmax 操作将其转换为输入序列上的分布,直接把这个 attention 分布看作在尺寸为 n n n 的词表上做分类时的 softmax 分布,使用交叉熵损失进行优化即可

3. 实验

  • 这里仅介绍 TSP 上的结果,另外两个问题详见论文

    TSP 问题是说给定一个城市列表,希望找到一个最短的路线,要求把每个城市访问一次并能返回到起点。作者假设两个城市之间的距离是对称的,即 A->B 的距离 = B-> A 的距离。

    1. 数据生成: 任意训练样本 ( P , C P ) (\mathcal{P}, \mathcal{C^P}) (P,CP) 中, P = { P 1 , … , P n } \mathcal{P}=\{P_{1},\ldots,P_{n}\} P={P1,,Pn} 是在 [0,1]×[0,1] 区间中随机采样的 n n n 个笛卡尔坐标, C P = { C 1 , . . . , C n } \mathcal{C^P} = \{C_1,...,C_n\} CP={C1,...,Cn} 是一个从 1 到 n 的排列,代表最优路线。为了一致性,在训练数据集中,数据集总是从第一个城市开始。为了生成精确的数据,城市数量不一样,所构建的数据集输出结果方式也不一样,具体地说: 在城市数量 n ≤ 20 n\leq 20 n20 的情况下,采用 Held-Karp 算法;对于 n > 20 n>20 n>20 的情况,作者考虑了 A1 A2 A3 三种启发式搜索算法,其中 A3 算法保证在离最优长度1.5倍的范围内找到一个解
    2. 模型设置:所有模型都使用了具有 256 或 512 个隐藏单元的单层 LSTM;使用随机梯度下降(SGD)训练;学习率为1.0;batch_size为128;随机均匀权重初始化从 -0.08 到 0.08;L2正则化梯度裁剪为 2.0
  • 实验结果如下
    在这里插入图片描述
    1. 由于 TSP 问题是有约束的(不能重复访问城市,也不能忽略城市),作者在解码时的波束搜索(beam Search)过程中过滤有效的结果。这种过滤过程在 n > 20 n>20 n>20 时是必须的。当 n = 30 n=30 n=30 时(超过训练时城市数量),失败率到达 30%;当 n = 40 n=40 n=40 时失败率上升到 98%
    2. 表中 OPTIMAL 列是真实最优结果,缺少 n = 50 n=50 n=50 的结果是因为计算复杂度太高了
    3. 表中第一组行显示了在 n n n 相同时使用最优数据训练的结果。注意到使用最差的算法 (A1) 数据来训练 Ptr-Net 时,模型优于其试图模仿的A1算法(6.42 < 6.46)
    4. 表中第二组行显示了在5~20个城市的最佳数据上训练的 Ptr-Nets 如何能够推广到更多的城市。 结果对于n=25来说几乎是完美的,对于n=30来说是好的,但在40或更长的时间里似乎会崩溃(尽管如此,结果还是比随机策略要好得多)

4. 总结

  • Pointer Networks 天生具备从输入序列中提取元素的能力,因此它非常适合用来实现 “复制” 这个功能。NLP 领域很多研究者也确实把它用于复制源文本中的一些词汇。比如摘要任务,由于所需的词汇较多,非常适合使用复制的方法来复制一些词,目前Pointer Networks 已经称为了文本摘要方法中的利器。
  • 此外,在组合优化领域,Ptr-Nets 也得到了广泛的应用,并已成为组合优化问题的端到端方法的入门模型,后来基于此模型,研究者也进行了很多改进,比如与强化学习结合,将 Attention 换成 Transformer 中采用的Self- Attention等。总之,Ptr-Nets为组合优化的端到端解决办法起了一个好头,并促使广大研究者进行更加深入的研究

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1040773.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

StableAudio-大模型创作音乐的工具

音乐产业即将发生革命。 今天Stability AI&#xff0c;开源人工智能工具和模型之王&#xff0c;例如Stable DIffusion和StableLM&#xff0c;推出Stable Audio&#xff0c;其首款用于音乐和声音生成的人工智能产品。 音乐行业是出了名的难以进入。即使您有才华和动力&#x…

RFID技术在质量控制和生产追溯中的关键应用

在现代制造业中&#xff0c;质量控制和生产追溯是确保产品质量和合规性的关键环节。RFID技术已经成为实现这一目标的强大工具。本文将探讨RFID技术在质量控制和生产追溯中的关键应用&#xff0c;以及如何利用它来提高生产效率、确保产品质量和满足合规性要求。 生产过程追溯 …

Android11 适配

一、修改targetSdkVersion为30 将build.gradle的目标版本targetSdkVersion修改为30&#xff08;Android 11&#xff09; targetSdkVersion 30Android11的改变改变主要影响以Adnroid11 为目标版本的应用&#xff08;targetSdkVersion>30才有影响&#xff09;&#xff0c;和所…

OpenCV实现模板匹配和霍夫线检测,霍夫圆检测

一&#xff0c;模板匹配 1.1代码实现 import cv2 as cv import numpy as np import matplotlib.pyplot as plt from pylab import mplmpl.rcParams[font.sans-serif] [SimHei]#图像和模板的读取 img cv.imread("cat.png") template cv.imread(r"E:\All_in\o…

18672-2014 枸杞 学习记录

声明 本文是学习GB-T 18672-2014 枸杞. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了枸杞的质量要求、试验方法、检验规则、标志、包装、运输和贮存。 本标准适用于经干燥加工制成的各品种的枸杞成熟果实。 2 规范性引用文件…

无代码解决信息孤岛,云表实现软件开发"书同文,车同轨"

什么是信息孤岛&#xff1f;信息孤岛就是一个组织或系统内部的信息资源无法与其他部分或外部系统共享、互操作&#xff0c;从而使得这些信息无法在整个组织或系统中发挥最大作用的现象。这种现象通常发生在不同部门、不同业务领域或不同系统之间&#xff0c;导致信息重复、浪费…

精彩回顾 | 迪捷软件亮相2023世界智能网联汽车大会

2023年9月24日&#xff0c;2023世界智能网联汽车大会&#xff08;以下简称大会&#xff09;在北京市圆满落幕。迪捷软件北京参展之行圆满收官。 本次大会由工业和信息化部、公安部、交通运输部、中国科学技术协会、北京市人民政府联合主办&#xff0c;是我国首个经国务院批准的…

【编码魔法师系列_构建型1.2 】工厂方法模式(Factory Method)

学会设计模式&#xff0c;你就可以像拥有魔法一样&#xff0c;在开发过程中解决一些复杂的问题。设计模式是由经验丰富的开发者们&#xff08;GoF&#xff09;凝聚出来的最佳实践&#xff0c;可以提高代码的可读性、可维护性和可重用性&#xff0c;从而让我们的开发效率更高。通…

基于微信小程序的竞赛管理平台设计与实现(开题报告+任务书+源码+lw+ppt +部署文档+讲解)

文章目录 前言运行环境说明学生微信端的主要功能有&#xff1a;竞赛负责人的主要功能&#xff1a;管理员的主要功能有&#xff1a;具体实现截图详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考论文参考源码获取 前…

33 排序链表

排序链表 题解1 STL - multiset题解2 归并【自顶向下】题解3 归并【自底向上】自底向上&#xff1a;子串长度 l 从1开始&#xff0c;合并后的串长度*2&#xff0c;11 -> 22 -> 44 ->... 给你链表的头结点 head &#xff0c;请将其按 升序 排列并返回 排序后的链表 …

如何选择一款高性价比的便携式明渠流量计

如何选择一款精度高、测量准确、易操作的便携式明渠流量计 如何选择一款精度高、测量准确、易操作的便携式明渠流量计 便携式明渠流量计&#xff1a;是一款对现有在线水监测系统中流量监测的对比装置。该便携式明渠流量计实现了比对在线系统的液位误差及流量误差。引导式的操作…

基于微信小程序的背单词学习激励系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言用户微信端的主要功能有&#xff1a;管理员的主要功能有&#xff1a;具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉…

2023年思维100秋季赛报名中,比赛安排、阶段、形式和5年真题资源

家有小学生的魔都的爸爸妈妈们&#xff0c;上海市含金量比较高的奥数比赛——思维100秋季比赛正在报名中。如果你想让你的孩子多个证书、多个经历、以赛促学&#xff0c;来了解一下吧。 一、思维100比赛是什么&#xff1f; 思维100是原来的“中环杯”&#xff0c;全称"中…

ChatGPT实战-Embeddings打造定制化AI智能客服

本文介绍Embeddings的基本概念&#xff0c;并使用最少但完整的代码讲解Embeddings是如何使用的&#xff0c;帮你打造专属AI聊天机器人&#xff08;智能客服&#xff09;&#xff0c;你可以拿到该代码进行修改以满足实际需求。 ChatGPT的Embeddings解决了什么问题&#xff1f; …

蓝桥杯 题库 简单 每日十题 day10

01 最少砝码 最少砝码 问题描述 你有一架天平。现在你要设计一套砝码&#xff0c;使得利用这些砝码 可以出任意小于等于N的正整数重量。那么这套砝码最少需要包含多少个砝码&#xff1f; 注意砝码可以放在天平两边。 输入格式 输入包含一个正整数N。 输出格式 输出一个整数代表…

面部情绪识别Facial Emotion Recognition:从表情到情绪的全面解析与代码实现

面部情绪识别&#xff08;FER&#xff09;是指根据面部表情对人类情绪进行识别和分类的过程。通过分析面部特征和模式&#xff0c;机器可以有依据地推测一个人的情绪状态。这一面部识别子领域是一个高度跨学科的领域&#xff0c;它借鉴了计算机视觉、机器学习和心理学的见解。 …

蓝桥杯每日一题2023.9.25

4406. 积木画 - AcWing题库 题目描述 分析 在完成此问题前可以先引入一个新的问题 291. 蒙德里安的梦想 - AcWing题库 我们发现16的二进制是 10000 15的二进制是1111 故刚好我们可以从0枚举到1 << n(相当于二的n次方的二进制表示&#xff09; 注&#xff1a;奇数个0…

学生用什么光的灯最好?2023最适合学生用的台灯推荐

学生当然用全光谱的台灯最好。全光谱台灯主要还是以护眼台灯为主&#xff0c;因为不仅色谱丰富&#xff0c;贴近自然色的全光谱色彩&#xff0c;通常显色指数都能达到Ra95以上&#xff0c;显色能力特别强&#xff0c;而且还具有其他防辐射危害、提高光线舒适度的特性&#xff0…

Unity之Hololens如何实现传送功能

一.前言 什么是Hololens? Hololens是由微软开发的一款混合现实头戴式设备,它将虚拟内容与现实世界相结合,为用户提供了沉浸式的AR体验。Hololens通过内置的传感器和摄像头,能够感知用户的环境,并在用户的视野中显示虚拟对象。这使得用户可以与虚拟内容进行互动,将数字信…

实施主品牌进化战略(四):升级顾客认知驱动力

很多企业常常会陷入困境&#xff0c;有足够优秀的产品&#xff0c;但没有有效地提升顾客认知驱动力。产品优秀&#xff0c;但如果顾客对品牌的认知度不高、信任度不足&#xff0c;主品牌定位与目标顾客的认知产生距离&#xff0c;那么增长将无处可寻。所以&#xff0c;提升顾客…