transformers之text generation解码策略

news2024/12/26 21:26:49

目录

  • 参数
    • Temperature
    • Top-p and Top-k
      • 1. 选择最上面的token:贪婪解码
      • 2. 从最上面的tokens中选择:top-k
      • 3. 从概率加起来为15%的top token中选择:top-p
    • Frequency and Presence Penalties
  • transformers库中的解码策略
    • 贪婪搜索
    • 对比搜索
    • 多项式采样
    • beam搜索解码
    • beam搜索多项式采样
    • 多样beam搜索
    • 推测解码
  • 参考

文本生成对于许多NLP任务至关重要,例如开放式文本生成、摘要、翻译等。它还在各种混合模态应用程序中发挥作用,这些应用程序将文本作为输出,如语音到文本和视觉到文本。一些可以生成文本的模型包括GPT2、XLNet、OpenAI GPT、CTRL、TransformerXL、XLM、Bart、T5、GIT、Whisper。

请注意,generate方法的输入依赖于模型的模态。它们由模型的preprocessor类返回,例如AutoTokenizer或AutoProcessor。如果模型的preprocessor创建了不止一种输入,则将所有输入传递给generate()。

选择输出token以生成文本的过程称为解码,您可以定制generate()方法将使用的解码策略。修改解码策略不会改变任何可训练参数的值。但是,它会对生成的输出的质量产生明显的影响。它可以帮助减少文本中的重复,使其更加连贯。

transformers.generation.GenerationMixin

class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
          `do_sample=False`
        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
          `top_k>1`
        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
          `do_sample=True`
        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
          `do_sample=False`
        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
          and `do_sample=True`
        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
          and `num_beam_groups>1`
        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`
        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
            `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`

    You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
    learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
    """

参数

有一些模型参数会影响模型生成输出的可预测性。这些参数包括temperaturetop-ptop-kfrequency_penaltypresence_penalty

Temperature

从生成模型中抽样包含随机性,因此每次点击“generate”时,相同的提示可能会产生不同的输出。温度是一个用来调整随机性程度的数字。

更低的温度值意味着更少的随机性;温度值为0将始终产生相同的输出。较低的温度值(小于1)更适合执行具有“正确”答案的任务,如问答或摘要。如果模型开始自我重复,这是温度值可能太低的迹象。
高温度值意味着更多的随机性。这可以帮助模型提供更多创造性的输出,但如果您使用检索增强生成(RAG),也可能意味着它没有正确使用您提供的上下文。如果模型开始偏离主题,给出无意义的输出,这是温度值过高的迹象。

在这里插入图片描述
温度值可以针对不同的问题进行调整,但大多数人会发现温度值为1是一个很好的起点。

随着序列变长,模型对其预测自然会变得更有信心,因此您可以对较长的提示提高温度值而不会跑题。相比之下,在简短的提示中使用高温度值可能会导致输出非常不稳定。

Top-p and Top-k

用于选择输出标记的方法是使用语言模型成功生成文本的重要组成部分。有几种方法(也称为解码策略)用于选择输出token,其中最主要的两种是top-k采样和top-p采样。
让我们看一下这个例子,其中模型的输入prompt 是The name of that country is the

在这里插入图片描述
在本例中,输出的token为United 。这是在语言模型处理了输入并为其词汇表中的每个token计算了可能性得分之后输出的。这个分数表明它是句子中下一个token的可能性(基于模型所训练的所有文本)。

在这里插入图片描述
该模型计算其词汇表中每个token的可能性。使用解码策略选择一个作为输出。

1. 选择最上面的token:贪婪解码

在这里插入图片描述
总是选择得分最高的token被称为“贪婪解码”。它很有用,但也有一些缺点。
贪婪解码是一种合理的策略,但存在一些缺陷;例如,输出可能会陷入重复的循环中。想想你的智能手机输入法的自动建议。当你不断地选择最高的建议词时,它可能会演变成重复的句子。

2. 从最上面的tokens中选择:top-k

另一个常用的策略是3个top tokens的候选列表中取样。这种方法允许其他高分toekn有机会被选中。这种抽样引入的随机性有助于在许多场景中提高生成的质量。

在这里插入图片描述
添加一些随机性有助于使输出文本更自然。在top-3解码中,我们首先列出三个token,然后通过考虑它们的似然分数(likelihood)对其中一个进行采样。

更广泛地说,选择前三个标记意味着将top-k参数设置为3。更改top-k参数将设置模型在输出每个token时从中采样的候选列表的大小。将top-k设置为1时得到的是贪婪解码。

在这里插入图片描述
注意,当k被设置为0时,模型禁用k采样并使用p。

3. 从概率加起来为15%的top token中选择:top-p

由于选择最佳top-k值的难度很大,因此另一种流行的解码策略诞生了,该策略可以动态设置token短列表的大小。这种方法称为“核心抽样(Nucleus Sampling)”,通过选择可能性总和不超过某一特定值的top tokens来创建候选名单。top-p值为0.15的简单示例如下:

在这里插入图片描述
在top-p中,候选名单的大小是根据达到某个阈值的似然得分的总和动态选择的。
Top-p通常设置为一个高值(如0.75),目的是限制可能采样的低概率token的长尾。我们可以同时使用top-k和top-p。

如果kp都启用,则pk之后起作用。

Frequency and Presence Penalties

最后一组参数是frequency_penaltypresence_penalty,它们都对token的对数(log)概率(即“logits”)起作用,以影响给定token在输出中出现的频率。

频率惩罚–惩罚之前文本中已经出现的token(包括提示),并根据该token出现的次数进行缩放。因此,已经出现10次的令牌比只出现一次的令牌得到更高的惩罚(这降低了它出现的概率)。

出现惩罚–不管出现的频率如何,只要这个token之前出现过一次,就会被惩罚。

transformers库中的解码策略

默认文本生成配置模型的解码策略在其生成配置中定义。当在pipeline()中使用预训练模型进行推理时,模型会调用PreTrainedModel.generate()方法,该方法在后台应用默认的生成配置。当模型中没有保存自定义配置时,也使用默认配置。

  • 查看生成配置:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
print(model.generation_config)

# GenerationConfig {
   
#   "bos_token_id": 50256,
#   "eos_token_id": 50256
# }
# <BLANKLINE>

Generation_config只显示与默认生成配置不同的值,而没有列出任何默认值。默认的生成配置将输出和输入token的大小限制为最多20个token,以避免

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1814567.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中国大模型站起来了!甚至被美国团队反向抄袭

一直以来&#xff0c;美国是公认的AI领域强者&#xff0c;我国AI技术虽然差不多&#xff0c;但始终落人一步。然而&#xff0c;近日斯坦福团队的AI模型却被指控抄袭中国AI模型&#xff0c;这下许多人都坐不住了。 被实锤抄袭的&#xff0c;是斯坦福大学AI团队&#xff0c;他们…

WWDC 2024及其AI功能的引入对中国用户和开发者的影响

WWDC&#xff08;Apple Worldwide Developers Conference&#xff09;是苹果公司一年一度的重要活动&#xff0c;吸引了全球开发者的关注。WWDC 2024引入了许多新技术和功能&#xff0c;尤其是AI功能的加入&#xff0c;引发了广泛讨论。本文将深入探讨中国开发者如何看待WWDC 2…

四川赤橙宏海商务信息咨询有限公司揭秘抖音电商新风口

在数字化浪潮席卷全球的今天&#xff0c;电商行业作为新时代经济的生力军&#xff0c;正以前所未有的速度发展。作为抖音电商服务的佼佼者&#xff0c;四川赤橙宏海商务信息咨询有限公司凭借其专业的服务团队和前瞻的市场洞察&#xff0c;不断刷新行业纪录&#xff0c;助力商家…

docker安装rabbitmq和延迟插件(不废话版)

1.下载镜像 docker pull rabbitmq:3.8-management 2.启动 docker run -e RABBITMQ_DEFAULT_USERlicoos -e RABBITMQ_DEFAULT_PASSlicoosrabbitmq -v mq-plugins:/plugins --name mq --hostname mq -p 15672:15672 -p 5672:5672 -d rabbitmq:3.8-management 3.下载对…

跨海交流丨台湾混凝土行业参访团与上海思伟软件共筑“智慧砼厂”梦 !

每一次跨越地域的握手 都是行业革新与智慧交融的序曲 台湾优质混凝土参访团 2024年5月29日&#xff0c;财团法人台湾营建研究院院长吕良正先生&#xff0c;率领着由61名行业精英组成的台湾商砼参访团&#xff0c;跨越海峡抵达上海&#xff0c;开展了一场连接两岸的学习交流活动…

苹果的股票都飙升7%了,谷歌仍在建议你往披萨上加胶水|TodayAI

最近&#xff0c;谷歌&#xff08;Google&#xff09;的人工智能再次引发了一场笑话。这次&#xff0c;它建议用户在披萨上添加胶水&#xff0c;引起了广泛关注和讨论。事情的起因源自一位互联网传奇人物Katie Notopoulos&#xff0c;她实际上制作并食用了一个胶水披萨&#xf…

【源码】2024最新陪诊小程序uniapp+thinkphp

20 2024最新陪诊小程序uniappthinkphp资源来源&#xff1a;52codes.cc 20最新陪诊小程序uniappthinkphp 简介&#xff1a;随着社会逐渐步进入老龄化越来越多的老年人或者不经常去医院的用户对于医院繁琐的流程很是苦劳于是陪诊这个行业开始兴起。小白陪诊开发理念&#xff0…

PS2045L-ASEMI低Low VF肖特基PS2045L

编辑&#xff1a;ll PS2045L-ASEMI低Low VF肖特基PS2045L 型号&#xff1a;PS2045L 品牌&#xff1a;ASEMI 封装&#xff1a;TO-277 最大平均正向电流&#xff08;IF&#xff09;&#xff1a;20A 最大循环峰值反向电压&#xff08;VRRM&#xff09;&#xff1a;45V 最大…

CVE-2019-20933-influxdb未授权访问-vulhub

1.原理 参考&#xff1a;https://blog.csdn.net/tqlisno1/article/details/109110644 InfluxDB 未授权访问 漏洞复现_influxdb未授权访问复现-CSDN博客 InfluxDB 是一个开源分布式时序、时间和指标数据库&#xff0c;使用 Go 语言编写&#xff0c;无需外部依赖。其设计目标是…

铝合金板件加工迎来3D视觉新时代

在制造业的浩瀚星空中&#xff0c;铝合金板件加工一直以其轻质、高强度、耐腐蚀的特性&#xff0c;扮演着举足轻重的角色。然而&#xff0c;随着市场竞争的加剧和产品需求的多样化&#xff0c;传统的加工方式已难以满足现代制造业对高效率、高精度的追求。在这个关键时刻&#…

【Java】解决Java报错:ArithmeticException during Division

文章目录 引言一、ArithmeticException的定义与概述1. 什么是ArithmeticException&#xff1f;2. ArithmeticException的常见触发场景3. 示例代码 二、解决方案1. 检查除数是否为零2. 使用异常处理3. 使用浮点数除法4. 使用自定义方法进行安全除法 三、最佳实践1. 始终检查除数…

RAG实操教程,LangChain + Llama2 | 创造你的个人LLM

本文将逐步指导您创建自己的RAG&#xff08;检索增强生成&#xff09;系统&#xff0c;使您能够上传自己的PDF文件并向LLM询问有关PDF的信息。本教程侧重于图中蓝色部分&#xff0c;即暂时不涉及Gradio&#xff08;想了解已接入Gradio的&#xff0c;请参考官网&#xff09;。相…

计算机网络-BGP路由优选原则四-优选AS_Path属性值最短的路由

一、优选AS_Path属性值最短的路由 AS_Path&#xff1a;这是BGP中最重要的属性之一&#xff0c;它记录了路由信息经过的所有自治系统。AS_Path属性帮助接收路由信息的路由器了解该路由的来源和路径。AS_Path由一系列的自治系统号组成&#xff0c;这些自治系统号代表了路由信息在…

SAP PP学习笔记17 - MTS(Make-to-Stock) 按库存生产 的策略70,策略59

上几章讲了几种策略&#xff0c;策略10&#xff0c;11&#xff0c;30&#xff0c;40。 SAP PP学习笔记14 - MTS&#xff08;Make-to-Stock) 按库存生产&#xff08;策略10&#xff09;&#xff0c;以及生产计划的概要-CSDN博客 SAP PP学习笔记15 - MTS&#xff08;Make-to-St…

架构设计-跨域问题的根源及解决方式

前面文章《架构设计-web项目中跨域问题涉及到的后端和前端配置》中说明了处理跨域问题的一种方式&#xff0c;本文详细说明下产生跨域问题的原因及处理方式。 一、产生跨域问题的原因&#xff1a; 浏览器的同源策略&#xff1a;这是跨域问题的根本原因。同源策略是浏览器对Jav…

C语言 sizeof 和 strlen

目录 一、sizeof 和 strlen 的区别 a.sizeof b.strlen c.sizeof与strlen的区别 二、数组和指针笔试题解析(32位环境) a.一维数组( int a[ ] { 1 , 2 , 3 , 4 } ) b.字符数组 &#xff08;char arr[ ] {a , b , c , d , e , f }&#xff09; &#xff08; char arr[ …

吴恩达2022机器学习专项课程C2W3:实验Lab_01模型评估与选择

这里写目录标题 导入模块与实验环境配置回归1.构建并可视化数据集2.分割数据集3.重新绘制数据集3.特征缩放4.评估模型&#xff1a;计算训练集的误差5.评估模型&#xff1a;计算交叉验证集的误差 添加多项式1.构建多项式特征集2.缩放特征3.使用标准化的计训练集和交叉验证集&…

[next.js]移动端调试vconsole

一般最简单的调试方式当然是使用vconsole来输出想要的数据啦&#xff1b; next.js如果想使用的话需要在客户端环境里调用才行&#xff08;服务端直接看cmd控制台就够了&#xff09;&#xff1b; 先安装vconsole npm i -D vconsolenext.js不像react cli或者vue一样有一个main.…

Python学习从0开始——Kaggle计算机视觉001

Python学习从0开始——Kaggle计算机视觉001 一、卷积分类器1.分类器2.训练分类器3.使用 二、卷积和RELU1.特征提取2.带卷积的过滤器定义3.激活&#xff1a;4.用ReLU检测5.使用 三、最大池化1.最大池压缩2.使用3.平移不变性 四、滑动窗口1.介绍2.步长3.边界4.使用 五、自定义Con…

[linux]如何跟踪linux 内核运行的流程呢

前面已经可以把内核编译出来&#xff0c;但是作为技术狗想看到内核是怎么运行的怎么办&#xff1f; 内核很多代码都是C语言写的&#xff0c;那简单&#xff0c;添加2行代码&#xff1a; include/linux/printk.h 529和530原来的&#xff1a; #define pr_info(fmt, ...) \ …