【深度学习】序列生成模型(二):束搜索

news2025/1/23 9:17:18

文章目录

    • 序列生成
    • 束搜索
      • 理论基础
      • 算法步骤
      • python实现

序列生成

  在进行最大似然估计训练后的模型 p θ ( x ∣ x 1 : ( t − 1 ) ) p_\theta(x | \mathbf{x}_{1:(t-1)}) pθ(xx1:(t1)),我们可以使用该模型进行序列生成。生成的过程是按照时间顺序逐步生成序列样本。假设在第 t t t 步,我们已经生成了前 t − 1 t-1 t1 步的序列前缀 x 1 : ( t − 1 ) = x 1 , … , x t − 1 \mathbf{x}_{1:(t-1)} = x_1, \ldots, x_{t-1} x1:(t1)=x1,,xt1,我们希望在当前步生成下一个词 x t x_t xt。生成的过程可以用以下概率分布表示:

x t ∼ p θ ( x ∣ x 1 : ( t − 1 ) ) x_t \sim p_\theta(x | \mathbf{x}_{1:(t-1)}) xtpθ(xx1:(t1))

其中, x 1 : ( t − 1 ) \mathbf{x}_{1:(t-1)} x1:(t1) 是已经生成的前缀序列, x t x_t xt 是在给定前缀序列的条件下,由模型生成的当前时刻的词。

  这个过程可以迭代进行,直到生成完整的序列样本。在每一步,模型根据已经生成的前缀序列生成当前时刻的词,然后将当前时刻的词添加到前缀序列中,用于生成下一个时刻的词。

生成的序列样本可以用如下方式表示:

x ^ = x ^ 1 , x ^ 2 , … , x ^ T \mathbf{\hat{x}} = \hat{x}_1, \hat{x}_2, \ldots, \hat{x}_T x^=x^1,x^2,,x^T

其中, x ^ t \hat{x}_t x^t 是在第 t t t 步生成的词, x ^ \mathbf{\hat{x}} x^ 是完整的生成序列。这个过程是根据训练得到的模型对数据分布进行采样,从而生成新的符合训练数据分布的序列。

  自回归的方式可以生成一个无限长度的序列.为了避免这种情况,通常会设置一个特殊的符号⟨𝐸𝑂𝑆⟩(End-of-Sequence)来表示序列的结束.在训练时,每个序列样本的结尾都会加上结束符号 ⟨ EOS ⟩ \langle \text{EOS} \rangle EOS。训练模型时,这有助于模型学习何时停止生成。在测试时,一旦生成了结束符号 ⟨ EOS ⟩ \langle \text{EOS} \rangle EOS,模型就会中止生成过程。

束搜索

理论基础

  在每个时间步,自回归模型贪婪搜索选择当前条件概率分布中具有最高概率的词作为生成的词。具体而言,对于每个时间步 t t t,生成的词 x ^ t \hat{x}_t x^t是:

x ^ t = arg ⁡ max ⁡ x ∈ V p θ ( x ∣ x 1 : ( t − 1 ) ) \hat{x}_t = \arg\max_{x \in \mathcal{V}} p_\theta(x | \mathbf{x}_{1:(t-1)}) x^t=argxVmaxpθ(xx1:(t1))

其中, V \mathcal{V} V 是词表, x 1 : ( t − 1 ) = x ^ 1 , … , x ^ t − 1 \mathbf{x}_{1:(t-1)} = \hat{x}_1, \ldots, \hat{x}_{t-1} x1:(t1)=x^1,,x^t1 是前 t − 1 t-1 t1 步中已经生成的前缀序列。

  这种贪婪搜索策略是一种简单且直观的方法,但它有一个主要的缺点,即可能导致生成的序列不是全局最优的。由于在每个时间步都选择了局部最大概率的词,生成的序列并不保证是整个序列的全局最大概率。这种策略可能导致生成的序列缺乏一致性或流畅性。
  为了改善这种情况,束搜索(Beam Search)是一种常用的启发式方法,特别在序列生成任务中应用广泛。在束搜索中,每个时间步生成多个备选序列,而不仅仅是一个。这样可以在每个时间步维持一个集合,称为束(beam),其中包含多个备选序列。束的大小由超参数 K K K 决定,通常被称为束大小。
在这里插入图片描述
  在每个时间步,算法选择概率最高的 K K K 个序列作为备选,并将它们作为下一个时间步的输入。这样,算法在整个生成过程中维持了 K K K 条备选序列,允许更全面地探索可能的序列空间。
  束搜索有助于减少搜索空间,提高搜索的效率。然而,束大小 K K K 的选择是一个权衡,较小的 K K K 可能导致搜索空间不够广泛,而较大的 K K K 则会增加计算开销。因此,束大小的选择通常需要根据具体任务和性能需求进行调整。

算法步骤

  1. 初始化: 设置束大小 K K K,初始化一个束(beam)用于存储备选序列。初始时,束中包含一个空序列。

  2. 逐步生成: 对于每个时间步 t t t,执行以下步骤:

    a. 对于束中的每个备选序列,生成下一个词的备选集合。计算条件概率 p θ ( x t ∣ context ) p_\theta(x_t | \text{context}) pθ(xtcontext)

    b. 对于所有的备选序列和它们的备选词,计算在当前时间步的累积概率。

    c. 从所有的备选序列中选择累积概率最高的 K K K个序列作为新的束。

    d. 如果生成了结束符号或达到了最大生成长度,则停止生成。

  3. 输出: 选择束中最终累积概率最高的序列作为最终的生成结果。

python实现

def beam_search(model, initial_context, beam_size, max_length):
    # 初始化束,初始时包含一个空序列
    beam = [([], 1.0)]  # 初始序列和初始概率

    # 逐步生成
    for t in range(max_length):
        new_beam = []

        # 对于束中的每个备选序列
        for sequence, score in beam:
            # 生成备选词
            candidates = generate_candidates(model, sequence, initial_context)

            # 计算累积概率
            for candidate in candidates:
                new_sequence = sequence + [candidate]
                new_score = score * calculate_probability(model, new_sequence, initial_context)

                new_beam.append((new_sequence, new_score))

        # 选择累积概率最高的 K 个序列作为新的束
        beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_size]

        # 判断是否生成了结束符号或达到最大生成长度
        if is_finished(beam):
            break

    # 选择最终累积概率最高的序列作为结果
    best_sequence = max(beam, key=lambda x: x[1])[0]
    return best_sequence

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1320279.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

adb: error: cannot create file/directory ‘d:/1.png‘: No such file or directory

将文件从设备读取到PC 由于权限问题&#xff0c;不能直接pull到电脑磁盘根目录&#xff0c;否则会报错&#xff1a; adb pull <remote> <local> eg: C:\Users\admin>adb pull /sdcard/server.log C:\Users\admin\Desktop /sdcard/server.log: 1 file pulled.…

LeedCode刷题---二分查找类问题

顾得泉&#xff1a;个人主页 个人专栏&#xff1a;《Linux操作系统》 《C/C》 《LeedCode刷题》 键盘敲烂&#xff0c;年薪百万&#xff01; 一、二分查找 题目链接&#xff1a;二分查找 题目描述 给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一…

基于STC89C51单片机实现的森林防火系统源码+仿真+原理图+设计报告,含视频讲解

森林防火 摘要 森林防火是非常必要的,火灾对森林的破坏是具有毁灭性的,有着很大的危害,在春秋季节森林火灾高发期,若发生火灾,对人民生活带来极大危害,不仅危害人们生产生活,而且对地球环境产生影响.本课题研究的内容是以单片机STC89C51为控制核心&#xff0c;以MQ-2型半导体电…

Android hilt使用

一&#xff0c;添加依赖库 添加依赖库app build.gradle.kts implementation("com.google.dagger:hilt-android:2.49")annotationProcessor("com.google.dagger:hilt-android:2.49")annotationProcessor("com.google.dagger:hilt-compiler:2.49"…

关于前端学习的思考-浮动元素嵌套块级元素12.18

1、块级元素嵌套浮动元素 先摆图片&#xff0c;当橘色的盒子高度减少的时候&#xff0c;NK AD TB PK NN并不会减少。如何解决呢&#xff1f; 加一个overflow&#xff1a;clip或者hidden 2、浮动元素嵌套块级元素 加一个overflow&#xff1a;clip或者hidden 综上所述&#xff0…

2020 年网络安全应急响应分析报告

2020 年全年奇安信集团安服团队共参与和处置了全国范围内 660起网络安全应急响应事件。2020 年全年应急响应处置事件行业 TOP3 分别为:政府部门行业(146 起)医疗卫生行业(90 起)以及事业单位(61 起&#xff0c;事件处置数分别占应急处置所有行业的 22.1%、13.6%、9.2%。2020 年…

修改npm源码解决服务端渲染环境中localstorage报错read properties of undefined (reading getItem)

现象&#xff1a; 这个问题是直接指向了我使用的第三方库good-storage&#xff0c;这是一个对localStorage/sessionStorage做了简单封装的库&#xff0c;因为项目代码有一个缓存cache.ts有用到 原因分析&#xff1a; 从表象上看是storage对象找不到getItem方法&#xff0c; 但…

Vue3使用Three.js导入gltf模型并解决模型为黑色的问题

背景 如今各类数字孪生场景对三维可视化的需求持续旺盛&#xff0c;因为它们可以用来创建数字化的双胞胎&#xff0c;即现实世界的物体或系统的数字化副本。这种技术在工业、建筑、医疗保健和物联网等领域有着广泛的应用&#xff0c;可以帮助人们更好地理解和管理现实世界的事…

Selenium框架的使用心得(一)

最近使用selenium框架实现业务前端的UI自动化&#xff0c;在使用selenium时&#xff0c;有一些心得想要和大家分享一下~ Selenium是一款用于web应用程序测试的工具&#xff0c;常用来实现稳定业务的UI自动化。这里&#xff0c;不想对其发展历史做介绍&#xff0c;也不想用官方…

EXCEL SUM类函数

参考资料 万能函数SUMPRODUCT超实用的10种经典用法 目录 一. SUM二. SUMIF2.1 统计贾1的销售额2.2 > 900 的销售总额2.3 计算贾1和贾22的销售总额2.4 多区域计算 三. SUMIFS3.1 统计苹果&#xff0c;在第一季度的总数量3.2 统计苹果&#xff0c;在第一季度&#xff0c;>…

智能家居和智能家居控制设备有什么区别?

智能家居和智能家居控制设备在功能和用途伤的区别&#xff1a; 智能家居是一种整体的概念&#xff0c;它涵盖了整个家庭环境的智能化&#xff0c;包括智能家电、智能照明、智能安防等设备的互联互通和协同工作。智能家居的目标是通过中央控制器或智能音箱等设备&#xff0c;实现…

Python内置函数一览表

为了提高程序员的开发效率&#xff0c;Python 提供了很多可以直接拿来用的函数&#xff08;初学者可以先理解为方法&#xff09;&#xff0c;每个函数都可以帮助程序员实现某些具体的功能。 举个例子&#xff0c;在 Python 2.x 中 print 只是一个关键字&#xff0c;但在 Pytho…

cefsharp120.1.8(cef120.1.8,Chromium120.0.6099.109)版本升级测试,其他版本H264版本

此版本最新版cef120.1.8,Chromium120.0.6099.109 此更新包括一个高优先级安全更新 This update includes a high priority security update. 说明&#xff1a;本版本暂时不支持264&#xff0c;其他H264版本参考119,116&#xff0c;114&#xff0c;110&#xff0c;109等版本 c…

Spring 原理(一)

Spring 原理 它是一个全面的、企业应用开发一站式的解决方案&#xff0c;贯穿表现层、业务层、持久层。但是 Spring仍然可以和其他的框架无缝整合。 Spring 特点 轻量级控制反转面向切面容器框架集合 Spring 核心组件 Spring 常用模块 Spring 主要包 Spring 常用注解 bean …

CUDA C:线程、线程块与线程格

相关阅读 CUDA Chttps://blog.csdn.net/weixin_45791458/category_12530616.html?spm1001.2014.3001.5482 第一百篇博客&#xff0c;写点不一样的。 当核函数在主机端被调用时&#xff0c;它会被转移到设备端执行&#xff0c;此时设备会根据核函数的调用格式产生对应的线程(…

如何应用基础故障编排?

基础故障编排是保障系统稳定性和可用性的关键环节。通过有效应用基础故障编排&#xff0c;组织能够更快速、更智能地应对系统故障&#xff0c;从而提升业务的可靠性和竞争力。本文将介绍如何应用基础故障编排! 1、选择合适的工具&#xff1a; 选择适合组织需求的基础故障编排工…

9. DashBoard

9. DashBoard 文章目录 9. DashBoard9.1 部署Dashboard9.2 使用DashBoard 在kubernetes中完成的所有操作都是通过命令行工具kubectl完成的。 为了提供更丰富的用户体验&#xff0c;kubernetes还开发了一个基于web的用户界面&#xff08;Dashboard&#xff09;。 用户可以使用…

Mysql之Specified key was too long; max key length is xx bytes异常

问题原因&#xff1a;mysq索引的字段都太长了 767字节是 MySQL 版本5.6(以及以前版本)中 InnoDB 表的最大索引前缀长度限制&#xff0c;MyISAM 表的长度为1,000字节。在 MySQL 版本5.7及以上版本中&#xff0c;这个限制增加到了3072字节。 如果对 utf8mb4编码的 varchar 字段设…

python+torch线性回归模型机器学习

程序示例精选 pythontorch线性回归模型机器学习 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对《pythontorch线性回归模型机器学习》编写代码&#xff0c;代码整洁&#xff0c;规则&#xf…

【操作系统】实验四 进程调度

实验名称&#xff1a; 实验四 进程调度 实验目的&#xff1a; 1. 加深理解有关进程控制块、进程队列的概念 2. 体会和了解优先级和时间片轮转调度算法的具体实施办法 实验内容&#xff1a; 1. 设计进程控制块 PCB 表结构&#xff08;与实验一的结构相同&#xff09;&#xff…