9. 什么是 Beam Search?深入理解模型生成策略

news2024/9/22 7:36:13

是不是总感觉很熟悉?
在之前第5,7,8篇文章中,我们都曾经用到过与它相关的参数,而对于早就有着实操经验的同学们,想必见到的更多。这篇文章将从示例到数学原理和代码带你进行理解。

Beam Search 对应的中文翻译为“集束搜索”或“束搜索”。你可以将其当作是贪心算法的拓展,其实是很简单的概念:贪心算法每次只选择最好的,而 Beam Search 会在多个候选中进行选择。通过这篇文章,你将了解到:

  • Beam Width(束宽) 的实际作用,常对应于参数名 num_beams
  • 所有候选序列生成结束标记 的含义,常对应于参数名 early_stopping
  • Beam Search 的基本原理和工作机制

强烈建议访问:Beam Search Visualizer,这是一个非常 Amazing 的交互式项目,在即将完成这个文章攥写的时候我通过官方文档发现了它,让理论与实际搭上了桥。
计划后续补上数学和与其他一些算法的比较。

Beam Search 的基本概念

Beam Search 是一种宽度优先搜索算法,通过保留多个候选序列(即“束”)来探索可能的输出空间。不同于贪心搜索(Greedy Search)每次只选择当前最优的一个候选序列,Beam Search 可以同时保留多个(由束宽 k k k 决定),从而减少陷入局部最优解的风险。

Beam Search 的工作原理

Beam Search 的核心思想是在每一步生成过程中,保留束宽 k k k 个最有可能的候选序列,而不是仅保留一个最优序列(这种是贪心算法,也就是说束宽 k k k 为 1 的时候 Beam Search 就是 Greedy Search)。以下是 Beam Search 的基本步骤:

  1. 初始化:从一个初始序列(通常为空或特殊起始标记)开始,设定束宽 k k k,初始化候选序列集 B 0 = { start } B_0 = \{ \text{start} \} B0={start}
  2. 迭代生成:对于当前所有候选序列 B t − 1 B_{t-1} Bt1,扩展一个新的词汇或符号,生成所有可能的下一个词汇组合,并计算每个序列的概率。
  3. 选择顶束:从所有扩展的候选序列中,选择得分最高的 k k k 个序列,作为下一步的候选序列 B t B_t Bt
  4. 终止条件:当所有候选序列都生成了结束标记(如 <eos>)或达到设定的最大长度 T T T 时,停止生成。
  5. 选择最终序列:从最终的候选序列集中,选择得分最高的序列作为输出。

:以GPT为例,扩展实际对应于去获取 tokens 的概率。

举个例子

  1. 初始化

    • 束宽 ( k k k): 2
    • 当前候选集 ( B 0 B_0 B0): { (空) } \{\text{(空)}\} {(空)}
    • 词汇表 { A , B , C , ‘<eos>‘ } \{A, B, C, \text{`<eos>`}\} {A,B,C,‘<eos>‘}
    • 扩展(生成所有可能的下一个词汇):
      扩展结果概率
      A 0.4 \textbf{0.4} 0.4
      B 0.3 \textbf{0.3} 0.3
      C 0.2 0.2 0.2
      <eos> 0.1 0.1 0.1
    • 选择顶束 ( k = 2 k=2 k=2):
      • A A A 0.4 0.4 0.4
      • B B B 0.3 0.3 0.3
    • 新的候选集 ( B 1 B_1 B1): { A ( 0.4 ) , B ( 0.3 ) } \{A (0.4), B (0.3)\} {A(0.4),B(0.3)}
  2. 扩展 A A A B B B

    • 扩展 A A A

      • 生成概率: { A : 0.3 , B : 0.1 , C : 0.4 , ‘<eos>‘ : 0.2 } \{A: 0.3, B: 0.1, C: 0.4, \text{`<eos>`}: 0.2\} {A:0.3,B:0.1,C:0.4,‘<eos>‘:0.2}
      扩展结果概率计算概率
      A A AA AA 0.4 × 0.3 0.4 \times 0.3 0.4×0.3 0.12 \textbf{0.12} 0.12
      A B AB AB 0.4 × 0.1 0.4 \times 0.1 0.4×0.1 0.04 0.04 0.04
      A C AC AC 0.4 × 0.4 0.4 \times 0.4 0.4×0.4 0.16 \textbf{0.16} 0.16
      A <eos> A\text{<eos>} A<eos> 0.4 × 0.2 0.4 \times 0.2 0.4×0.2 0.08 0.08 0.08
    • 扩展 B B B

      • 生成概率: { A : 0.1 , B : 0.1 , C : 0.3 , ‘<eos>‘ : 0.5 } \{A: 0.1, B: 0.1, C: 0.3, \text{`<eos>`}: 0.5\} {A:0.1,B:0.1,C:0.3,‘<eos>‘:0.5}
      扩展结果概率计算概率
      B A BA BA 0.3 × 0.1 0.3 \times 0.1 0.3×0.1 0.03 0.03 0.03
      B B BB BB 0.3 × 0.1 0.3 \times 0.1 0.3×0.1 0.03 0.03 0.03
      B C BC BC 0.3 × 0.3 0.3 \times 0.3 0.3×0.3 0.09 \textbf{0.09} 0.09
      B <eos> B\text{<eos>} B<eos> 0.3 × 0.5 0.3 \times 0.5 0.3×0.5 0.15 \textbf{0.15} 0.15
    • 所有扩展序列及其概率

      序列概率
      A C AC AC 0.16 \textbf{0.16} 0.16
      A A AA AA 0.12 0.12 0.12
      B <eos> B\text{<eos>} B<eos> 0.15 \textbf{0.15} 0.15
      B C BC BC 0.09 0.09 0.09
      A <eos> A\text{<eos>} A<eos> 0.08 0.08 0.08
      A B AB AB 0.04 0.04 0.04
      B A BA BA 0.03 0.03 0.03
      B B BB BB 0.03 0.03 0.03
    • 选择顶束 ( k = 2 k=2 k=2):

      • A C AC AC 0.16 0.16 0.16
      • B <eos> B\text{<eos>} B<eos> 0.15 0.15 0.15
    • 新的候选集 ( B 2 B_2 B2): { A C ( 0.16 ) , B <eos> ( 0.15 ) } \{AC (0.16), B\text{<eos>} (0.15)\} {AC(0.16),B<eos>(0.15)}

  3. 仅扩展 A C AC AC

    • 生成概率: { A : 0.1 , B : 0.2 , C : 0.5 , ‘<eos>‘ : 0.2 } \{A: 0.1, B: 0.2, C: 0.5, \text{`<eos>`}: 0.2\} {A:0.1,B:0.2,C:0.5,‘<eos>‘:0.2}
    扩展结果概率计算概率
    A C A ACA ACA 0.16 × 0.1 0.16 \times 0.1 0.16×0.1 0.016 0.016 0.016
    A C B ACB ACB 0.16 × 0.2 0.16 \times 0.2 0.16×0.2 0.032 0.032 0.032
    A C C ACC ACC 0.16 × 0.5 0.16 \times 0.5 0.16×0.5 0.080 0.080 0.080
    A C <eos> AC\text{<eos>} AC<eos> 0.16 × 0.2 0.16 \times 0.2 0.16×0.2 0.032 0.032 0.032
    • 由于 B <eos> B\text{<eos>} B<eos> 已完成,我们选择扩展结果中的顶束:
      • A C C ACC ACC 0.064 0.064 0.064
      • 以某种规则选择 A C B ACB ACB A C <eos> AC\text{<eos>} AC<eos> 0.032 0.032 0.032
    • 新的候选集 ( B 3 B_3 B3): { A C C ( 0.064 ) , A C B ( 0.032 ) } \{ACC (0.064), ACB (0.032)\} {ACC(0.064),ACB(0.032)}
  4. 后续步骤

    • 继续扩展:重复上述过程,直到所有候选序列都生成了 <eos> 或达到设定的最大长度。

过程演示

现在是你访问它的最好时机:Beam Search Visualizer

处理 <eos> 的逻辑

在每一步生成过程中,如果某个序列生成了 <eos>,则将其标记为完成,不再进行扩展。以下是处理 <eos> 的示例:

  • 假设在某一步,序列 A C B ACB ACB 扩展出 A C B <eos> ACB\text{<eos>} ACB<eos> 0.032 × 1 = 0.032 0.032 \times 1 = 0.032 0.032×1=0.032),则:
    • A C B <eos> ACB\text{<eos>} ACB<eos> 保留在最终候选集,但不再扩展。
    • Beam Search 继续扩展其他未完成的序列,直到所有序列完成或达到最大长度。

问题如果有一个序列被标记为完成(生成了 <eos>),在下一个扩展步骤中,Beam Search 应该扩展多少个候选序列?

答:束宽 k k k

示例图(k=3):

你可以在下图中看到,即便有一个序列生成了 <eos>,下一个扩展步骤中还是会扩展 k=3 个候选序列。

image-20240915235014101

实际应用中的 Beam Search

在机器翻译,文本生成,语音转识别等生成式模型领域,你都能看见Beam Search,它被广泛地应用。

代码示例

使用 Hugging Face Transformers 库的简单示例:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# 指定模型名称
model_name = "distilgpt2"

# 加载分词器和模型
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 移动模型到设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 设置模型为评估模式
model.eval()

# 输入文本
input_text = "Hello GPT"

# 编码输入文本
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

# 生成文本,使用 Beam Search
beam_width = 5
with torch.no_grad():
    outputs = model.generate(
        inputs,
        max_length=50,
        num_beams=beam_width,  # 你可以看到 beam_width 对应的参数名为 num_beams
        no_repeat_ngram_size=2,
        early_stopping=True  # 开启 early_stopping,当所有候选序列生成<eos>停止
    )

# 解码生成的文本
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("生成的文本:")
print(generated_text)

输出

生成的文本:
Hello GPT.

This article was originally published on The Conversation. Read the original article.

对比不同束宽的输出

# 输入文本
input_text = "Hello GPT"

# 编码输入文本
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

# 设置束宽不同的生成策略
beam_widths = [1, 3, 5]  # 使用不同的束宽

# 生成并打印结果
for beam_width in beam_widths:
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=50,
            num_beams=beam_width,  
            no_repeat_ngram_size=2,
            early_stopping=True,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"束宽 {beam_width} 的生成结果:")
    print(generated_text)
    print('-' * 50)
束宽 1 的生成结果:
Hello GPT is a free and open source software project that aims to provide a platform for developers to build and use GPGP-based GPSP based GPCs. GPP is an open-source software development platform that is designed to
--------------------------------------------------
束宽 3 的生成结果:
Hello GPT.

This article is part of a series of articles on the topic, and will be updated as more information becomes available.
--------------------------------------------------
束宽 5 的生成结果:
Hello GPT.

This article was originally published on The Conversation. Read the original article.
--------------------------------------------------

参考链接

  • Beam-search decoding
  • Beam Search Visualizer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2139393.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工厂模式(一):简单工厂模式

一、概念 顾名思义&#xff0c;带着工厂&#xff0c;两字肯定就是有标准、快速、统一等等一些工厂独有的特点。 那么什么是简单工厂模式呢&#xff1f; 定义&#xff1a;简单工厂模式是一种创建对象的设计模式&#xff0c;它定义了一个工厂类通过某个静态方法来生成不同类型的…

基于AutoDL部署langchain-chatchat-0.3.1实战

一、租用AutoDL云服务器&#xff0c;配置环境 1.1 配置AutoDL环境 注册好autodl账户之后&#xff0c;开始在上面租服务器&#xff0c;GPU我选择的是RTX4090*2&#xff0c;西北B区&#xff0c;基础镜像选择的是Pytorch-2.3.0-python-3.12&#xff08;ubuntu22.04&#xff09;-…

夸克网盘电脑端和手机端如何查看自己分享的文件

夸克网盘有些地方做的还是有点抽象&#xff0c;好多东西是真的找不到。 找了半天终于找到了自己分享的文件&#xff0c;给大家分享下。 电脑端 点击左侧栏的“快传”&#xff0c;然后点击“我分享的” 手机端 手机端也是类似&#xff0c;点击“快传”后再点击“我分享的”&a…

白月光git

感觉bug好多干脆直接从头到脚梳理 感冒不嘻嘻 近况是&#xff1a; 早起学习 开车去沟里 把蜜蜂拍到狗身上 把车开回来 吃席 安装git和VScode 都是从官网上装的&#xff0c;不说那么多咯&#xff0c;之前说过&#xff1a; 进程间也要唠一唠-CSDN博客https://blog.csdn.net…

Spring4-IoC3-手写IoC

Spring框架的IoC是基于Java反射机制实现的 Java反射机制是在运行状态中&#xff0c;对于任意一个类&#xff0c;都能够知道这个类的所有属性和方法&#xff1b;对于任意一个对象&#xff0c;都能够调用它的任意方法和属性&#xff0c;这种动态获取信息以及动态调用对象方法的功…

【AI学习笔记】初学机器学习西瓜书的知识点概要记录

初学机器学习西瓜书的知识点概要记录 1.1 机器学习1.2 典型的机器学习过程1.2 机器学习理论1.3 基本术语1.4 归纳偏好1.5 NFL定理2.1 泛化能力2.2 过拟合和欠拟合2.3 三大问题2.4 评估方法2.5 调参与验证集2.6 性能度量2.7 比较检验 以下内容出自周志华老师亲讲西瓜书 1.1 机器…

复习:数组

目录 数组名 一般性理解 下标引用与间接访问 例外 一维数组 声明与初始化 下标引用 内存分配 长度计算 二维数组 内存分配 长度计算 声明与初始化 数组指针 引入 数组指针 一级指针 引入 一级指针 章尾问题 数组名 一般性理解 数组名是一个指向&#x…

DockerLinux安装DockerDocker基础

Linux软件安装 yum命令安装 通过yum命令安装软件,是直接把软件安装到Linux系统中 安装和卸载都比较麻烦,因为软件和系统是强关联的 Docker docker是一种容器技术,可以解决软件和系统强关联关系,使得软件的安装和卸载更方便,它可以将我们的应用以及依赖进行打包,制作出一个镜…

算法:TopK问题

题目 有10亿个数字&#xff0c;需要找出其中的前k大个数字。 为了方便讲解&#xff0c;这里令k为5。 思路分析&#xff08;以找前k大个数字为例&#xff09; 很容易想到&#xff0c;进行排序&#xff0c;然后取前k个数字即可。 但是&#xff0c;难点在于&#xff0c;10亿个数…

人工智能GPT____豆包使用的一些初步探索步骤 体验不一样的工作

豆包工具是我使用比较频繁的一款软件&#xff0c;其集合了很多功能。对话 图像 AI搜索 伴读等等使用都非常不错。电脑端安装集合了很多功能。 官网直达&#xff1a;豆包 使用我的文案创作能力&#xff0c;您可以注意以下几个技巧&#xff1a; 明确需求&#xff1a; 尽可能具…

Linux进阶 把用户加入和移除用户组

1、Linux 单用户多任务,多用户多任务概念 Linux 是一个多用户、多任务的操作系统。 单用户多任务、多用户多任务 概念; Linux 的 单用户、多任务以 beinan 登录系统,进入系统后,我要打开gedit 来写文档,但在写文档的过程中,我感觉少点音乐,所以又打开xmms 来点音乐;当…

C语言:联合和枚举

一. 联合体 1.联合体的声明 1. 像结构体一样&#xff0c;联合体也是由一个或者多个成员构成&#xff0c;这些成员可以不同的类型。 union { 成员1&#xff1b; 成员2&#xff1b; ........ }; //联合体类型 union S {char c;int i; }; 2.联合体的特点和大小计算 像结构体一…

emWin5的图片半透明之旅

文章目录 目标过程直接使用png (失败了)通过 BmpCvt.exe 转换一下&#xff08;成功了&#xff09;通过bmp转 &#xff08;半成功吧&#xff09; 补充工程结构整理 目标 显示半透明效果&#xff0c;类似png那种&#xff0c;能透过去&#xff0c;看到背景。 过程 直接使用png …

【STM32】单级与串级PID控制的C语言实现

【STM32】单级与串级PID的C语言实现 前言PID理论什么是PIDPID计算过程PID计算公式Pout、Iout、Dout的作用单级PID与串级PID PID应用单级PID串级PID 前言 笔者最近在学习PID控制器&#xff0c;本文基于Blog做以总结。CSDN上已有大量PID理论知识的优秀文章&#xff0c;因此本文将…

基于HPLC的低压电力采集方案

1. 组网部署 2. 组网部件 3. 原理

✔2848. 与车相交的点

代码实现&#xff1a; 方法一&#xff1a;哈希表 #define fmax(a, b) ((a) > (b) ? (a) : (b))int numberOfPoints(int **nums, int numsSize, int *numsColSize) {int hash[101] {0};int max 0;for (int i 0; i < numsSize; i) {max fmax(max, nums[i][1]);for …

基于SSM+Vue+MySQL的新生报到管理系统

系统展示 用户界面 管理员界面 系统背景 在当今高等教育日益普及的背景下&#xff0c;新生报到管理成为高校日常管理中的重要环节。为了提升报到效率、优化管理流程并确保数据的准确性与安全性&#xff0c;我们设计并实现了一个基于SSM&#xff08;SpringSpring MVCMyBatis&…

JavaScript高级——作用域和作用链

1、概念理解&#xff1a; —— 就是一块“地盘”&#xff0c;一个代码所在的区域 —— 静态的&#xff08;相对于上下文对象&#xff09;&#xff0c;在编写代码时就确定了 2、分类 ① 全局作用域 ② 函数作用域 ③ 没有块作用域&#xff08;ES6有了&#xff09; 3、作用 …

app抓包 chrome://inspect/#devices

一、前言&#xff1a; 1.首先不支持flutter框架&#xff0c;可支持ionic、taro 2.初次需要翻墙 3.app为debug包&#xff0c;非release 二、具体步骤 1.谷歌浏览器地址&#xff1a;chrome://inspect/#devices qq浏览器地址&#xff1a;qqbrowser://inspect/#devi…

C#开发基础之单例模式下的集合数据,解决并发访问读写冲突的问题

1. 前言 在C#中&#xff0c;使用单例模式管理集合数据时&#xff0c;如果多线程同时访问集合&#xff0c;容易产生并发访问的读写冲突问题。单例模式下集合数据的并发访问读写冲突是如何产生的&#xff1f; 单例模式确保一个类在整个应用运行期间只有一个实例&#xff0c;这使…