在 Transformers 中使用约束波束搜索引导文本生成

news2025/1/4 19:54:30

引言

本文假设读者已经熟悉文本生成领域波束搜索相关的背景知识,具体可参见博文 如何生成文本: 通过 Transformers 用不同的解码方法生成文本。

与普通的波束搜索不同,约束 波束搜索允许我们控制所生成的文本。这很有用,因为有时我们确切地知道输出中需要包含什么。例如,在机器翻译任务中,我们可能通过查字典已经知道哪些词必须包含在最终的译文中; 而在某些特定的场合中,虽然某几个词对于语言模型而言差不多,但对最终用户而言可能却相差很大。这两种情况都可以通过允许用户告诉模型最终输出中必须包含哪些词来解决。

这事儿为什么这么难

然而,这个事情操作起来并不容易,它要求我们在生成过程中的 某个时刻 在输出文本的 某个位置 强制生成某些特定子序列。

假设我们要生成一个句子 S,它必须按照先 再   的顺序包含短语 。以下定义了我们希望生成的句子 :

期望

问题是波束搜索是逐词输出文本的。我们可以大致将波束搜索视为函数 ,它根据当前生成的序列 预测下一时刻 的输出。但是这个函数在任意时刻 怎么知道,未来的某个时刻 必须生成某个指定词?或者当它在时刻 时,它如何确定当前那个指定词的最佳位置,而不是未来的某一时刻 ?

f6045aa7a18f8f547b429115f2a26dd2.png
为何约束搜索很难

如果你同时有多个不同的约束怎么办?如果你想同时指定使用短语 短语 怎么办?如果你希望模型在两个短语之间 任选一个 怎么办?如果你想同时指定使用短语 以及短语列表 中的任一短语怎么办?

上述需求在实际场景中是很合理的需求,下文介绍的新的约束波束搜索功能可以满足所有这些需求!

我们会先简要介绍一下新的 约束波束搜索 可以做些什么,然后再深入介绍其原理。

例 1: 指定包含某词

假设我们要将 "How old are you?" 翻译成德语。它对应两种德语表达,其中 "Wie alt bist du?" 是非正式场合的表达,而 "Wie alt sind Sie?" 是正式场合的表达。

不同的场合,我们可能倾向于不同的表达,但我们如何告诉模型呢?

使用传统波束搜索

我们先看下如何使用 传统波束搜索 来完成翻译。

!pip install -q git+https://github.com/huggingface/transformers.git
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"

input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
)

print("Output:\n" + 100 *'-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
Wie alt bist du?

使用约束波束搜索

但是如果我们想要一个正式的表达而不是非正式的表达呢?如果我们已经先验地知道输出中必须包含什么,我们该如何 将其 注入到输出中呢?

我们可以通过 model.generate()force_words_ids 参数来实现这一功能,代码如下:

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"

force_words = ["Sie"]

input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids

outputs = model.generate(
    input_ids,
    force_words_ids=force_words_ids,
    num_beams=5,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
)

print("Output:\n" + 100 *'-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
Wie alt sind Sie?

如你所见,现在我们能用我们对输出的先验知识来指导文本的生成。以前我们必须先生成一堆候选输出,然后手动从中挑选出符合我们要求的输出。现在我们可以直接在生成阶段做到这一点。

例 2: 析取式约束

在上面的例子中,我们知道需要在最终输出中包含哪些单词。这方面的一个例子可能是在神经机器翻译过程中结合使用字典。

但是,如果我们不知道要使用哪种 _词形_呢,我们可能希望使用单词 rain 但对其不同的词性没有偏好,即 ["raining", "rained", "rains", ...] 是等概的。更一般地,很多情况下,我们可能并不刻板地希望 逐字母一致 ,此时我们希望划定一个范围由模型去从中选择最合适的。

支持这种行为的约束叫 析取式约束 (Disjunctive Constraints) ,其允许用户输入一个单词列表来引导文本生成,最终输出中仅须包含该列表中的 至少一个 词即可。

下面是一个混合使用上述两类约束的例子:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]

force_words_ids = [
    tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
    tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]

starting_text = ["The soldiers", "The child"]

input_ids = tokenizer(starting_text, return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    force_words_ids=force_words_ids,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
)

print("Output:\n" + 100 *'-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(tokenizer.decode(outputs[1], skip_special_tokens=True))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.

Output:
----------------------------------------------------------------------------------------------------
The soldiers, who were all scared and screaming at each other as they tried to get out of the
The child was taken to a local hospital where she screamed and scared for her life, police said.

如你所见,第一个输出里有 "screaming" ,第二个输出里有 "screamed" ,同时它们都原原本本地包含了 "scared" 。注意,其实 ["screaming", "screamed", ...] 列表中不必一定是同一单词的不同词形,它可以是任何单词。使用这种方式,可以满足我们只需要从候选单词列表中选择一个单词的应用场景。

传统波束搜索

以下是传统 波束搜索 的一个例子,摘自之前的 博文:

451e0c3c42de380ba801a6360c621188.png
波束搜索

与贪心搜索不同,波束搜索会保留更多的候选词。上图中,我们每一步都展示了 3 个最可能的预测词。

num_beams=3 时,我们可以将第 1 步波束搜索表示成下图:

d345d1d6cb77775a4555fd8295b823ef.jpeg
波束搜索第 1 步

波束搜索不像贪心搜索那样只选择 "The dog" ,而是允许将 "The nice""The car" 留待进一步考虑

下一步,我们会为上一步创建的三个分支分别预测可能的下一个词。

43c914a824f6b2ee8150b620d8848c40.jpeg
波束搜索第 2 步

虽然我们 考查 了明显多于 num_beams 个候选词,但在每步结束时,我们只会输出 num_beams 个最终候选词。我们不能一直分叉,那样的话, beams 的数目将在 步后变成 个,最终变成指数级的增长 (当波束数为 时,在 步之后就会变成 个分支!)。

接着,我们重复上述步骤,直到满足中止条件,如生成 <eos> 标记或达到 max_length 。整个过程可以总结为: 分叉、排序、剪枝,如此往复。

约束波束搜索

约束波束搜索试图通过在每一步生成过程中 _注入_所需词来满足约束。

假设我们试图指定输出中须包含短语 "is fast"

在传统波束搜索中,我们在每个分支中找到 k 个概率最高的候选词,以供下一步使用。在约束波束搜索中,除了执行与传统波束搜索相同的操作外,我们还会试着把约束词加进去,以 _看看我们是否能尽量满足约束_。图示如下:

81189bc7d4bd7a179f624274318c1784.jpeg
约束搜索第 1 步

上图中,我们最终候选词除了包括像 "dog""nice" 这样的高概率词之外,我们还把 "is" 塞了进去,以尽量满足生成的句子中须含 "is fast" 的约束。

第二步,每个分支的候选词选择与传统的波束搜索大部分类似。唯一的不同是,与上面第一步一样,约束波束搜索会在每个新分叉上继续强加约束,把满足约束的候选词强加进来,如下图所示:

054bf7c4f64c9ba90f5da5c74117288a.jpeg
约束搜索第 2 步

组 (Banks)

在讨论下一步之前,我们停下来思考一下上述方法的缺陷。

在输出中野蛮地强制插入约束短语 is fast 的问题在于,大多数情况下,你最终会得到像上面的 The is fast 这样的无意义输出。我们需要解决这个问题。你可以从 huggingface/transformers 代码库中的这个 问题 中了解更多有关这个问题及其复杂性的深入讨论。

组方法通过在满足约束和产生合理输出两者之间取得平衡来解决这个问题。

我们把所有候选波束按照其 满足了多少步约束分到不同的组中,其中组 里包含的是 满足了 步约束的波束列表 。然后我们按照顺序轮流选择各组的候选波束。在上图中,我们先从组 2 (Bank 2) 中选择概率最大的输出,然后从组 1 (Bank 1) 中选择概率最大的输出,最后从组 0 (Bank 0) 中选择最大的输出; 接着我们从组 2 (Bank 2) 中选择概率次大的输出,从组 1 (Bank 1) 中选择概率次大的输出,依此类推。因为我们使用的是 num_beams=3,所以我们只需执行上述过程三次,就可以得到 ["The is fast", "The dog is", "The dog and"]

这样,即使我们 强制 模型考虑我们手动添加的约束词分支,我们依然会跟踪其他可能更有意义的高概率序列。尽管 The is fast 完全满足约束,但这并不是一个有意义的短语。幸运的是,我们有 "The dog is""The dog and" 可以在未来的步骤中使用,希望在将来这会产生更有意义的输出。

图示如下 (以上例的第 3 步为例):

f505dd36ba10e1a2c28cad5b2cbb284d.jpeg
约束搜索第 3 步

请注意,上图中不需要强制添加 "The is fast",因为它已经被包含在概率排序中了。另外,请注意像 "The dog is slow""The dog is mad" 这样的波束实际上是属于组 0 (Bank 0) 的,为什么呢?因为尽管它包含词 "is" ,但它不可用于生成 "is fast" ,因为 fast 的位子已经被 slowmad 占掉了,也就杜绝了后续能生成 "is fast" 的可能性。从另一个角度讲,因为 slow 这样的词的加入,该分支 满足约束的进度 被重置成了 0。

最后请注意,我们最终生成了包含约束短语的合理输出: "The dog is fast"

起初我们很担心,因为盲目地添加约束词会导致出现诸如 "The is fast" 之类的无意义短语。然而,使用基于组的轮流选择方法,我们最终隐式地摆脱了无意义的输出,优先选择了更合理的输出。

关于 Constraint 类的更多信息及自定义约束

我们总结下要点。每一步,我们都不断地纠缠模型,强制添加约束词,同时也跟踪不满足约束的分支,直到最终生成包含所需短语的合理的高概率序列。

在实现时,我们的主要方法是将每个约束表示为一个 Constraint 对象,其目的是跟踪满足约束的进度并告诉波束搜索接下来要生成哪些词。尽管我们可以使用 model.generate() 的关键字参数 force_words_ids ,但使用该参数时后端实际发生的情况如下:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"

constraints = [
    PhrasalConstraint(
        tokenizer("Sie", add_special_tokens=False).input_ids
    )
]

input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    constraints=constraints,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
)

print("Output:\n" + 100 *'-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
Wie alt sind Sie?

你甚至可以定义一个自己的约束并将其通过 constraints 参数输入给 model.generate() 。此时,你只需要创建 Constraint 抽象接口类的子类并遵循其要求即可。你可以在 此处 的 Constraint 定义中找到更多信息。

我们还可以尝试其他一些有意思的约束 (尚未实现,也许你可以试一试!) 如  OrderedConstraintsTemplateConstraints 等。目前,在最终输出中约束短语间是无序的。例如,前面的例子一个输出中的约束短语顺序为 scared -> screaming ,而另一个输出中的约束短语顺序为 screamed -> scared 。如果有了 OrderedConstraints, 我们就可以允许用户指定约束短语的顺序。TemplateConstraints 的功能更小众,其约束可以像这样:

starting_text = "The woman"
template = ["the", "", "School of", "", "in"]

possible_outputs == [
   "The woman attended the Ross School of Business in Michigan.",
   "The woman was the administrator for the Harvard School of Business in MA."
]

或是这样:

starting_text = "The woman"
template = ["the", "", "", "University", "", "in"]

possible_outputs == [
   "The woman attended the Carnegie Mellon University in Pittsburgh.",
]
impossible_outputs == [
  "The woman attended the Harvard University in MA."
]

或者,如果用户不关心两个词之间应该隔多少个词,那仅用 OrderedConstraint 就可以了。

总结

约束波束搜索为我们提供了一种将外部知识和需求注入文本生成过程的灵活方法。以前,没有一个简单的方法可用于告诉模型 1. 输出中需要包含某列表中的词或短语,其中 2. 其中有一些是可选的,有些必须包含的,这样 3. 它们可以最终生成至在合理的位置。现在,我们可以通过综合使用 Constraint 的不同子类来完全控制我们的生成!

该新特性主要基于以下论文:

  • Guided Open Vocabulary Image Captioning with Constrained Beam Search

  • Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation

  • Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting

  • Guided Generation of Cause and Effect

与上述这些工作一样,还有许多新的研究正在探索如何使用外部知识 (例如 KG (Knowledge Graph) 、KB (Knowledge Base) ) 来指导大型深度学习模型输出。我们希望约束波束搜索功能成为实现此目的的有效方法之一。

感谢所有为此功能提供指导的人: Patrick von Platen 参与了从 初始问题 讨论到 最终 PR 的全过程,还有 Narsil Patry,他们二位对代码进行了详细的反馈。

本文使用的图标来自于 Freepik - Flaticon。


英文原文: https://hf.co/blog/constrained-beam-search

原文作者: Chan Woo Kim

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

审校/排版: zhongdongy (阿东)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/622267.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习笔记之MySQL索引

1、引言 索引是数据库用来提高性能最常用的工具&#xff0c;一般索引本身也很大&#xff0c;不可能全部存于内存中&#xff0c;因此所以往往以文件形式存于磁盘上。 左表是数据表&#xff0c;共两列七条数据。为了加快Col2的查找&#xff0c;可以维护一个右表所示的二叉查找树…

图论与算法(7)最短路径问题

1.最短路径问题 1.1 带权图的最短路径 最短路径问题是指在一个加权图中寻找两个顶点之间的最短路径&#xff0c;其中路径的长度由边的权重确定。 常见的最短路径算法包括&#xff1a; Dijkstra算法&#xff1a;适用于解决单源最短路径问题&#xff0c;即从一个固定的起点到图…

meethigher-阿里邮箱POP3/SMTP服务

最近发现一个问题&#xff0c;小伙伴给我发的邮件&#xff0c;收和回都不及时。于是我现在将所有的邮箱&#xff0c;通过POP3/SMTP协议整合到了一起。再配合小米手环&#xff0c;就能做到邮件无遗漏。 一、邮箱常用协议 邮箱中常用三类协议 POP3 Post Office Protocol versi…

chatgpt赋能python:Python就业学历要求

Python 就业学历要求 Python 是一门广泛应用于数据科学、人工智能、Web 开发和自动化等领域的编程语言&#xff0c;正在迅速成为行业内最受欢迎的语言之一。如果你想进入这些领域从事相关职业&#xff0c;那么 Python 编程技能将是你的一个优势。但是&#xff0c;Python 就业所…

基于SSM+JSP的毕业生就业信息管理系统设计与实现

博主介绍&#xff1a; 大家好&#xff0c;我是一名在Java圈混迹十余年的程序员&#xff0c;精通Java编程语言&#xff0c;同时也熟练掌握微信小程序、Python和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架下…

软考A计划-系统架构师-官方考试指定教程-(3/15)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

记录--纯CSS实现一个简单又不失优雅的步骤条

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 步骤条是一种用于引导用户按照特定流程完成任务的导航条&#xff0c;在各种分步表单交互场景中广泛应用。先来看一下几个主流前端 UI 框架中步骤条组件的样子&#xff1a; ElementPlus AntDesign Ope…

BCM和board的引脚的区别是什么?如何查看GPIO的BCM和board之间的关系

在树莓派(Raspberry Pi)上使用 GPIO(通用输入输出)时,引脚可以使用两种不同的编号方式:BCM(Broadcom SOC Channel)和board。 BCM 编号:BCM 编号是基于 Broadcom 芯片的引脚编号方式。它使用芯片上的引脚功能编号来标识 GPIO 引脚,这种编号方式是树莓派广泛使用的默认…

Spring事务简介及相关案例

目录 一、事务简介 二、准备数据库 三、创建maven项目&#xff0c;引入依赖和完成相关配置 1. pom.xml文件 2. 创建配置文件 四、编写Java代码 1. Account实体类 2. AccountDao接口 3. AccountService业务类 五、测试 1. 测试方法 2. 测试结果​编辑 往期专栏&…

判断数组中的每个元素是否为正无穷大或负无穷大 numpy.isinf()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 判断数组中的每个元素 是否为正无穷大或负无穷大 numpy.isinf() [太阳]选择题 请问关于以下代码的最后输出的是&#xff1f; import numpy as np a np.array([-np.inf,0,np.inf]) print(&q…

chatgpt赋能python:Python实现文件复制到另一个文件夹下的方法

Python实现文件复制到另一个文件夹下的方法 如果你经常需要复制文件并将它们保存到不同的文件夹下&#xff0c;那么使用Python脚本来执行此任务是一个非常好的选择。Python提供了强大的文件操作功能&#xff0c;使得编写脚本来完成文件操作变得相对简单。在本篇文章中&#xf…

【网站 seo 排名优化】typecho Handsome 主题高排名权重优化方案

前言 前一篇优化文章主要是完成了对于 typecho 各个方面的美化与简单优化&#xff0c;如下&#xff1a; 构造你独一无二的博客美化&#xff1a;typecho joe主题优化日志 而现在博主采用的是 Handsome 主题&#xff0c;相比较 joe 主题&#xff0c;编辑、定制功能更为强大、方便…

华为OD机试真题 JavaScript 实现【合法IP】【牛客练习题】

一、题目描述 IPV4地址可以用一个32位无符号整数来表示&#xff0c;一般用点分方式来显示&#xff0c;点将IP地址分成4个部分&#xff0c;每个部分为8位&#xff0c;表示成一个无符号整数&#xff08;因此正号不需要出现&#xff09;&#xff0c;如10.137.17.1&#xff0c;是我…

Python中函数的介绍

在Python中&#xff0c;函数的三个要素是&#xff1a;函数名参数返回值 函数名&#xff1a;函数名是函数的标识符&#xff0c;用于唯一标识函数。在定义函数时&#xff0c;需要给函数一个名字&#xff0c;以便后续调用和引用。函数名应遵循命名规则&#xff0c;例如以字母或下划…

HDSLB VPP 23.04 is formally released

1 摘要 近年来随着数字化技术的发展&#xff0c;数据中心以及边缘设备的网络带宽需求越来越高。作为部署在服务入口位置的4层负载均衡器&#xff0c;其性能要求也随之水涨船高。为了应对当前的市场需求&#xff0c;充分利用Intel的软硬件技术和优势&#xff0c;针对4层负载均衡…

一个奇葩的问题

大家好&#xff0c;这里是极客重生&#xff0c;最近遇到一个奇葩的网络问题&#xff0c;分享给大家&#xff0c;看完一定会觉得很奇葩。 问题现象 客户反馈有一个server端S&#xff0c; 两个client端C1, C2, S的iptables规则对C1, C2都是放通的&#xff0c;但是C2无法连接上S&a…

有奖征文 | 夙兴夜寐,铸梦网安

出品&#xff5c;MS08067实验室&#xff08;www.ms08067.com&#xff09; 本文作者&#xff1a;潜龙勿用 01 时光荏苒&#xff0c;流年岁月如白驹过隙&#xff0c;不停飞逝于眼前&#xff0c;在这车马星驰的人间&#xff0c;踏入网络安全领域已然三年有余。我也终于从一开始的不…

左移右移 2022年国赛 思维

思路&#xff1a; 简单的思维题&#xff0c;应该从后往前遍历操作。如果后面的对数i操作过&#xff0c;则前面对数i的操作都可以无视。可以通过栈这种数据结构实现后往前遍历。 AC代码&#xff1a; import java.io.*; import java.util.*; public class Main{public static …

Linux常用命令——groupdel命令

在线Linux命令查询工具 groupdel 用于删除指定的工作组 补充说明 groupdel命令用于删除指定的工作组&#xff0c;本命令要修改的系统文件包括/ect/group和/ect/gshadow。若该群组中仍包括某些用户&#xff0c;则必须先删除这些用户后&#xff0c;方能删除群组。 语法 gro…