LLM文本生成—解码策略(Top-k Top-p Temperature)

news2024/12/23 10:06:25
{
 "top_k": 5,
 "temperature": 0.8,
 "num_beams": 1,
 "top_p": 0.75,
 "repetition_penalty": 1.5,
 "max_tokens": 30000,
 "message": [
        {
 "content": "你好",
 "role": "user"
        }
    ]
}

在大模型训练好之后,如何对训练好的模型进行解码(decode)是一个火热的研究话题。

在自然语言任务中,我们通常使用一个预训练的大模型(比如GPT)来根据给定的输入文本(比如一个开头或一个问题)生成输出文本(比如一个答案或一个结尾)。为了生成输出文本,我们需要让模型逐个预测每个 token ,直到达到一个终止条件(如一个标点符号或一个最大长度)。在每一步,模型会给出一个概率分布,表示它对下一个单词的预测。例如,如果输入的文本是“我最喜欢的”,那么模型可能会给出下面的概率分布:

那么,我们应该如何从这个概率分布中选择下一个单词呢?以下是几种常用的方法:

  • 贪心解码(Greedy Decoding):直接选择概率最高的单词。这种方法简单高效,但是可能会导致生成的文本过于单调和重复。
  • 随机采样(Random Sampling):按照概率分布随机选择一个单词。这种方法可以增加生成的多样性,但是可能会导致生成的文本不连贯和无意义。
  • Beam Search:维护一个大小为 k 的候选序列集合,每一步从每个候选序列的概率分布中选择概率最高的 k 个单词,然后保留总概率最高的 k 个候选序列。这种方法可以平衡生成的质量和多样性,但是可能会导致生成的文本过于保守和不自然。

以上方法都有各自的问题,而 top-k 采样和 top-p 采样是介于贪心解码和随机采样之间的方法,也是目前大模型解码策略中常用的方法。

top-k采样

在上面的例子中,如果使用贪心策略,那么选择的 token 必然就是“女孩”。

贪心解码是一种合理的策略,但也有一些缺点。例如,输出可能会陷入重复循环。想想智能手机自动建议中的建议。当你不断地选择建议最高的单词时,它可能会变成重复的句子。

Top-k 采样是对前面“贪心策略”的优化,它从排名前 k 的 token 中进行抽样,允许其他分数或概率较高的token 也有机会被选中。在很多情况下,这种抽样带来的随机性有助于提高生成质量。

top-k 采样的思路是,在每一步,只从概率最高的 k 个单词中进行随机采样,而不考虑其他低概率的单词。例如,如果 k=2,那么我们只从女孩、鞋子中选择一个单词,而不考虑大象、西瓜等其他单词。这样可以避免采样到一些不合适或不相关的单词,同时也可以保留一些有趣或有创意的单词。

下面是 top-k 采样的例子:

通过调整 k 的大小,即可控制采样列表的大小。“贪心策略”其实就是 k = 1的 top-k 采样。

下面是top-k 的代码实现:

import torch
from labml_nn.sampling import Sampler

# Top-k Sampler
class TopKSampler(Sampler):
    # k is the number of tokens to pick
    # sampler is the sampler to use for the top-k tokens
    # sampler can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler`.
    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler

    # Sample from logits
    def __call__(self, logits: torch.Tensor):
        # New logits filled with −∞; i.e. zero probability
        zeros = logits.new_ones(logits.shape) * float('-inf')
        # Pick the largest k logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
        # Set the values of the top-k selected indices to actual logits.
        # Logits of other tokens remain −∞
        zeros.scatter_(-1, indices, values)
        # Sample from the top-k logits with the specified sampler.
        return self.sampler(zeros)

总结一下,top-k 有以下有点:

  • 它可以根据不同的输入文本动态调整候选单词的数量,而不是固定为 k 个。这是因为不同的输入文本可能会导致不同的概率分布,有些分布可能比较平坦,有些分布可能比较尖锐。如果分布比较平坦,那么前 k 个单词可能都有相近的概率,那么我们就可以从中进行随机采样;如果分布比较尖锐,那么前 k 个单词可能会占据绝大部分概率,那么我们就可以近似地进行贪心解码。
  • 它可以通过调整 k 的大小来控制生成的多样性和质量。一般来说,k 越大,生成的多样性越高,但是生成的质量越低;k 越小,生成的质量越高,但是生成的多样性越低。因此,我们可以根据不同的任务和场景来选择合适的k 值。
  • 它可以与其他解码策略结合使用,例如温度调节(Temperature Scaling)、重复惩罚(Repetition Penalty)、长度惩罚(Length Penalty)等,来进一步优化生成的效果。

但是 top-k 也有一些缺点,比如:

  • 它可能会导致生成的文本不符合常识或逻辑。这是因为 top-k 采样只考虑了单词的概率,而没有考虑单词之间的语义和语法关系。例如,如果输入文本是“我喜欢吃”,那么即使饺子的概率最高,也不一定是最合适的选择,因为可能用户更喜欢吃其他食物。
  • 它可能会导致生成的文本过于简单或无聊。这是因为 top-k 采样只考虑了概率最高的 k 个单词,而没有考虑其他低概率但有意义或有创意的单词。例如,如果输入文本是“我喜欢吃”,那么即使苹果、饺子和火锅都是合理的选择,也不一定是最有趣或最惊喜的选择,因为可能用户更喜欢吃一些特别或新奇的食物。

因此,我们通常会考虑 top-k 和其它策略结合,比如 top-p。

top-p采样

top-k 有一个缺陷,那就是“k 值取多少是最优的?”非常难确定。于是出现了动态设置 token 候选列表大小策略——即核采样(Nucleus Sampling)。

top-p 采样的思路是,在每一步,只从累积概率超过某个阈值 p 的最小单词集合中进行随机采样,而不考虑其他低概率的单词。这种方法也被称为核采样(nucleus sampling),因为它只关注概率分布的核心部分,而忽略了尾部部分。例如,如果 p=0.9,那么我们只从累积概率达到 0.9 的最小单词集合中选择一个单词,而不考虑其他累积概率小于 0.9 的单词。这样可以避免采样到一些不合适或不相关的单词,同时也可以保留一些有趣或有创意的单词。

下图展示了 top-p 值为 0.9 的 Top-p 采样效果:

top-p 值通常设置为比较高的值(如0.75),目的是限制低概率 token 的长尾。我们可以同时使用 top-k 和 top-p。如果 k 和 p 同时启用,则 p 在 k 之后起作用。

下面是 top-p 代码实现的例子:

import torch
from torch import nn

from labml_nn.sampling import Sampler


class NucleusSampler(Sampler):
    """
    ## Nucleus Sampler
    """
    def __init__(self, p: float, sampler: Sampler):
        """
        :param p: is the sum of probabilities of tokens to pick $p$
        :param sampler: is the sampler to use for the selected tokens
        """
        self.p = p
        self.sampler = sampler
        # Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits with Nucleus Sampling
        """

        # Get probabilities $P(x_i | x_{1:i-1})$
        probs = self.softmax(logits)

        # Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)

        # Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)

        # Find the cumulative sums less than $p$.
        nucleus = cum_sum_probs < self.p

        # Prepend ones so that we add one token after the minimum number
        # of tokens with cumulative probability less that $p$.
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

        # Get log probabilities and mask out the non-nucleus
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')

        # Sample from the sampler
        sampled_sorted_indexes = self.sampler(sorted_log_probs)

        # Get the actual indexes
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

        #
        return res.squeeze(-1)

Temperature采样

Temperature 采样受统计热力学的启发,高温意味着更可能遇到低能态。在概率模型中,logits 扮演着能量的角色,我们可以通过将 logits 除以温度来实现温度采样,然后将其输入 Softmax 并获得采样概率。

越低的温度使模型对其首选越有信心,而高于1的温度会降低信心。0温度相当于 argmax 似然,而无限温度相当于均匀采样。

Temperature 采样中的温度与玻尔兹曼分布有关,其公式如下所示:

\rho_i = \frac{1}{Q}e^{-\epsilon_i/kT}=\frac{e^{-\epsilon i/kT}}{\sum{j=1}^M e^{-\epsilon_j/kT}}\\

其中 

\rho_i

 是状态 

i

 的概率, 

\epsilon_i

 是状态 

i

 的能量, 

k

 是波兹曼常数, 

T

 是系统的温度, 

M

 是系统所能到达的所有量子态的数目。

有机器学习背景的朋友第一眼看到上面的公式会觉得似曾相识。没错,上面的公式跟 Softmax 函数 :

\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{c=1}^Ce^{z_c}}\\

很相似,本质上就是在 Softmax 函数上添加了温度(T)这个参数。Logits 根据我们的温度值进行缩放,然后传递到 Softmax 函数以计算新的概率分布。

上面“我喜欢漂亮的___”这个例子中,初始温度 

T=1

 ,我们直观看一下 

T

 取不同值的情况下,概率会发生什么变化:

通过上图我们可以清晰地看到,随着温度的降低,模型愈来愈越倾向选择”女孩“;另一方面,随着温度的升高,分布变得越来越均匀。当 

T=50

 时,选择”西瓜“的概率已经与选择”女孩“的概率相差无几了。

通常来说,温度与模型的“创造力”有关。但事实并非如此。温度只是调整单词的概率分布。其最终的宏观效果是,在较低的温度下,我们的模型更具确定性,而在较高的温度下,则不那么确定

下面是 Temperature 采样的代码实现:

import torch
from torch.distributions import Categorical

from labml_nn.sampling import Sampler


class TemperatureSampler(Sampler):
    """
    ## Sampler with Temperature
    """
    def __init__(self, temperature: float = 1.0):
        """
        :param temperature: is the temperature to sample with
        """
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits
        """

        # Create a categorical distribution with temperature adjusted logits
        dist = Categorical(logits=logits / self.temperature)

        # Sample
        return dist.sample()

联合采样(top-k & top-p & Temperature)

通常我们是将 top-k、top-p、Temperature 联合起来使用。使用的先后顺序是 top-k->top-p->Temperature。

我们还是以前面的例子为例。

首先我们设置 top-k = 3,表示保留概率最高的3个 token。这样就会保留女孩、鞋子、大象这3个 token。

  • 女孩:0.664
  • 鞋子:0.199
  • 大象:0.105

接下来,我们可以使用 top-p 的方法,保留概率的累计和达到 0.8 的单词,也就是选取女孩和鞋子这两个 token。接着我们使用 Temperature = 0.7 进行归一化,变成:

  • 女孩:0.660
  • 鞋子:0.340

接着,我们可以从上述分布中进行随机采样,选取一个单词作为最终的生成结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1525387.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

十八、软考-系统架构设计师笔记-真题解析-2022年真题

软考-系统架构设计师-2022年上午选择题真题 考试时间 8:30 ~ 11:00 150分钟 1.云计算服务体系结构如下图所示&#xff0c;图中①、②、③分别与SaaS、PaaS、IaaS相对应&#xff0c;图中①、②、③应为( )。 A.应用层、基础设施层、平台层 B.应用层、平台层、基础设施层 C.平…

python知识点总结(二)

这里写目录标题 1、什么是解释性语言&#xff0c;什么是编译性语言&#xff1f;2、说说中作用域是怎么划分的3、type和isinstance方法的区别4、浅拷贝和深拷贝5、python中变量在内存中存储方式6、python中的封装、继承、多态7、python中内存管理机制是怎么样的&#xff1f;8、简…

Halcon 深度图片==>点云图

文件路径 链接:https://pan.baidu.com/s/1UfFyZ6y-EFq9jy0T_DTJGA 提取码:ewdi * 1.读取深度图片 *****************

ubuntu16.04上pycharm卡住关不了

在使用pycharm的过程中&#xff0c;突然卡住&#xff0c;黑屏&#xff0c;手动界面关闭失败&#xff0c;可尝试以下方法解决。 输入以下命令&#xff0c;查看所有和pycharm有关的进程 ps -ef | grep pycharm得到以下结果 根据相应的PID&#xff0c;输入以下命令&#xff0c;强…

php双端交易所

php双端交易所&#xff0c;如需联系 完美修复版&#xff0c;带所有 PHP双端交易所完美版: PHP双端交易所完美版,带前端源码https://gitee.com/ycsw/ex.git

【REST2SQL】13 用户角色功能权限设计

【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 【REST2SQL】05 GO 操作 达梦 数据库 【REST2SQL】06 GO 跨包接口重构代码 【REST2SQL】07 GO 操作 Mysql 数据库 【RE…

ping和telnet的区别

ping是ICMP协议&#xff0c;只包含控制信息没有端口&#xff0c;用于测试两个网络主机之间网络是否畅通 telnet是TCP协议&#xff0c;用于查看目标主机某个端口是否开发。 总结&#xff1a;ping是物理计算机间的网络互通检查&#xff0c;telnet是应用服务间的访问连通检查&am…

AS-V1000视频监控平台如何加强系统安全,满足等保2.0规范要求

目 录 一、概述 &#xff08;一&#xff09;信息安全技术网络安全等级保护标准 &#xff08;二&#xff09;解读 1、等级保护工作的内容 2、等级保护的等级划分 3、不同等级的安全保护能力 第一级安全保护能力 第二级安全保护能力 第三级安全保护能力 第…

STM32信息安全 1.2 课程架构介绍:芯片生命周期管理与安全调试

STM32信息安全 1.2 课程架构介绍&#xff1a;STM32H5 芯片生命周期管理与安全调试 下面开始学习课程的第二节&#xff0c;简单介绍下STM32H5芯片的生命周期和安全调试&#xff0c;具体课程大家可以观看STM32官方录制的课程&#xff0c;链接&#xff1a;1.2. 课程架构介绍&…

Flex最后一行左对齐

<!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>Flex最后一行左对齐</title><style&…

HDFS磁盘写满问题分析

HDFS磁盘写满问题分析 1. 问题说明1.1 namenode常规分配datanode策略1.2 DFS Used很大时是否能够继续写入数据 2 问题修复2.1 集群均衡操作2.2 配置系统预留参数 3. 疑问和思考3.1. 是否需要配置dfs.datanode.du.reserved&#xff1f; 4. 参考文档 探讨hdfs的datanode节点磁盘被…

【vue项目中点击下载】弹窗提示:离开此网站?系统可能不会保存您所做的更改,改为直接下载,不提示此弹窗内容,已解决

项目中用的是window.location.href实现下载 在Web浏览器中&#xff0c;当尝试通过window.location.href重定向到一个文件下载URL时&#xff0c;浏览器通常会显示一个确认对话框&#xff0c;询问用户是否要离开当前页面&#xff0c;因为下载的文件通常是在新窗口或新标签页中打…

【C语言】空心正方形图案

思路&#xff1a; 1&#xff0c;两行两列打印* &#xff1a;第一行和最后一行&#xff0c;第一列和最后一列。 2&#xff0c;其他地方打印空格。 代码如下&#xff1a; #include<stdio.h> int main() { int n 0; int i 0; int j 0; while (scanf("…

avue-crud顶部操作按钮插槽;avue-crud列数据插槽;avue-crud行操作按钮插槽

1.avue-crud顶部操作按钮插槽&#xff1b; <template slot"menuLeft" slot-scope"{ size }"><div class"left"><div class"btn"><el-button type"primary" size"small" click"onBatchR…

彻底学会系列:一、机器学习之梯度下降(1)

1 梯度下降概念 1.1 概念 梯度下降是一种优化算法&#xff0c;用于最小化一个函数的值&#xff0c;特别是用于训练机器学习模型中的参数&#xff0c;其基本思想是通过不断迭代调整参数的值&#xff0c;使得函数值沿着梯度的反方向逐渐减小&#xff0c;直至达到局部或全局最小…

【Linux】一文解决如何在终端查看 python解释器 的位置

【Linux】一文解决如何在终端查看 python解释器 的位置 &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1f448; 希望得到您的订阅…

【Liunx-后端开发软件安装】Liunx安装nginx

【Liunx-后端开发软件安装】Liunx安装nginx 使用安装包安装 一、简介 nginx&#xff0c;这个家伙可不是你厨房里的那位大厨&#xff0c;它可是互联网世界的“煎饼果子摊主”。想象一下&#xff0c;在熙熙攘攘的网络大街上&#xff0c;nginx挥舞着它的锅铲——哦不&#xff0c;是…

KVM安装-kvm彻底卸载-docker安装Webvirtmgr

KVM安装和使用 一、安装 检测硬件是否支持KVM需要硬件的支持,使用命令查看硬件是否支持KVM。如果结果中有vmx(Intel)或svm(AMD)字样,就说明CPU的支持的 egrep ‘(vmx|svm)’ /proc/cpuinfo关闭selinux将 /etc/sysconfig/selinux 中的 SELinux=enforcing 修改为 SELinux=d…

python知识点总结(三)

python知识点总结三 1、有一个文件file.txt大小约为10G&#xff0c;但是内存只有4G&#xff0c;如果在只修改get_lines 函数而其他代码保持不变的情况下&#xff0c;应该如何实现? 需要考虑的问题都有那些?2、交换2个变量的值3、回调函数4、Python-遍历列表时删除元素的正确做…

编曲学习:如何编写钢琴织体 Cubase12逻辑预置 需要弄明白的问题

钢琴织体是指演奏形式、方式,同一个和弦进行可以用很多种不同的演奏方法。常用织体有分解和弦,柱式和弦,琶音织体,混合织体。 在编写钢琴织体前,先定好歌曲的调。 Cubase小技巧:把钢琴轨道向上拖动打和弦轨道,就可以显示和弦!如果你有一些参考工程,不知道用了哪些和…