文本生成模型如何解码

news2025/1/6 20:38:27

文章目录

    • 解码方法
      • Greedy Search
      • Beam Search
      • sampling
      • Temperature Sampling
      • top-k sampling
      • Top-p (nucleus) sampling
      • Contrastive search
    • 总结
    • 相关资源

语言模型如何对于一个给定输入生成相应的输出呢?答案是使用解码策略(decoding strategy)。这里对现有的解码策略做一个记录。

解码方法

与huggingface的how to generate 一样,用流行的transformers包和GPT2模型来对各个解码方法测试生成效果,先加载模型:

# transformers的安装命令: pip install -q transformers
# 导入对象
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# 确定推理设备
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 加载分词器,第一次调用会先下载
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# 加载模型,第一次调用会先下载
# add the EOS token as PAD token to avoid warnings
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(torch_device)

Greedy Search

Greedy Search贪心搜索就是在每个时间步,都选择概率最大的词汇作为下一个词。比如下面的图片,从词"The"开始,算法先贪心的选择概率最大的词"nice",接着选择概率最大的"women"。

在这里插入图片描述

如果使用transformers的generate函数来生成文本,不指定参数的话,默认就是使用贪心搜索。

# encode context the generation is conditioned on
model_inputs = tokenizer('I enjoy playing badminton', return_tensors='pt').to(torch_device)

# generate 40 new tokens
greedy_output = model.generate(**model_inputs, max_new_tokens=40)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton, but I'm not a big fan of the idea of playing badminton. I think it's a bit too much of a distraction. I think it's a distraction that's not going to
  • 贪心搜索得到的最终序列不一定是最优的句子,因为最优的句子的前面的词的概率可能会比较低,但是句子整体的概率更高。就像上面图片中的[‘the’, ‘dog’, ‘has’] 的概率比[‘the’, ‘nice’, ‘women’]要高。
  • 从上面的示例结果中发现生成的内容有重复,这是语言模型生成存在的一个问题,在贪心搜索和beam search中会更常见。
  • 使用LLM生成结果时,有一个Temperature参数,比如openai 的 api ,当Temperature=0时就是使用的贪心搜索。

Beam Search

因为贪心搜索每次选择概率最大的词可能会错过整体概率更高的句子;为了减轻这个风险,Beam Search 通过在每个时间步保留num_beams个概率最高的词,最终选择整体概率最大的句子。

下面的图片示意了num_beams=2的情形:

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使num_beams>1并且do_sample=False(默认即为False),就是使用的beam search方法。

# activate beam search and early_stopping
beam_output = model.generate(
    **model_inputs,
    max_new_tokens=40,
    num_beams=5,
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I don't like to play badm

我们也可以尝试将beam search 生成的句子都打印出来(用参数return_num_sequences,注意要小于等于num_beams),可以发现生成的几个句子差别不太大。

# set return_num_sequences > 1
beam_outputs = model.generate(
    **model_inputs,
    max_new_tokens=40,
    num_beams=5,
    num_return_sequences=5,
    early_stopping=True
)

# now we have 5 output sequences
print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

Output:
----------------------------------------------------------------------------------------------------
0: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I don't like to play badm
1: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I like to play badminton.
2: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I don't like to play goodm
3: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I like to play badminton."
4: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I like to play badminton,
  • Beam Search 可以保证比贪心搜索生成概率更高的句子,但是仍然不能保证找到最有可能的句子。
  • Beam Search的重复句子生成可以用n-grams惩罚来减轻,n-gram惩罚保证每个n-gram不会出现两次,方法是如果看到当前候选词与其上文所组成的 n-gram 已经出现过了,就将该候选词的概率设置为 0 。transformers包可以使用参数no_repeat_ngram_sizeno_repeat_ngram_size=2就是任意2-gram不会出现两次。
  • 在机器翻译或摘要等任务中,因为所需生成的文本长度或多或少都是可预测的,所以beam search效果比较好 - 参见 Murray et al. (2018) 和 Yang et al. (2018)的工作。但开放域文本生成情况有所不同,其输出文本长度可能会有很大差异,如对话和故事生成的输出文本长度就有很大不同。

sampling

采样就意味着不确定性,它根据当前条件概率分布随机选择下一个词。也就是每一个单词都有一定的几率会被选择,比如上面的图片中的例子,可视化出来就如下图,单词”car"从条件概率分布P(w|"The")中被采样到,接下来"drive"从P(w|"the", "car")被采样。

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使do_sample=Truetop_k=0,就是使用采样方式解码:

# set seed to reproduce results. Feel free to change the seed though to get different results
from transformers import set_seed
set_seed(42)

# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    **model_inputs,
    max_new_tokens=40,
    do_sample=True,
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton more than any other sport. I know more about winning than any other athlete and coach would agree, it's a lot tougher than most other AC athletes. American hockey Salaries

The Miami Dolphins
  • sampling 方法的问题模型可能会生成一些不太连贯的胡言乱语

Temperature Sampling

我们知道softmax的表达式如下式
p i = e x p ( z i ) ∑ j = 1 N e x p ( z j ) p_i = \frac {exp(z_i)} {\sum^N_{j=1} exp(z_j)} pi=j=1Nexp(zj)exp(zi)
而带Temperature的softmax的表达式如下式:
p i = e x p ( z i / τ ) ∑ j = 1 N e x p ( z j / τ ) p_i = \frac {exp(z_i/\tau)} {\sum^N_{j=1} exp(z_j/\tau)} pi=j=1Nexp(zj/τ)exp(zi/τ)

Temperature=1时就是普通的softmax,加了temperature之后可以让原本的概率分布更加两级分化(Temperature<1)或更平缓(Temperature>1)。

用如下代码生成的下图可以直观感受一下Temperature的效果:
在这里插入图片描述

import math
from matplotlib import pyplot as plt
import numpy as np
import torch

def softmax(vec, temperature):
    """
    turn vec into normalized probability
    """
    sum_exp = sum(math.exp(x/temperature) for x in vec)
    return [math.exp(x/temperature)/sum_exp for x in vec]

def main():
    vec = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ts = [0.1, 0.3, 0.6, 1, 1.5, 10, 100, 10000]

    for t in ts:
        result = softmax(vec, t)
        print(t, result)
        plt.plot(result, label=t)
    plt.legend()
    plt.show()

if __name__ == "__main__":
    main()
-----------输出结果-----------------------
0.1 [8.193640616392913e-40, 1.8047694477191753e-35, 3.975269250769863e-31, 8.75611321772293e-27, 1.9286622828562907e-22, 4.2481613803067925e-18, 9.357198133414645e-14, 2.0610600462088695e-09, 4.5397868608862414e-05, 0.9999546000702376]
0.3 [9.023799189303686e-14, 2.5295175399808997e-12, 7.090648684486909e-11, 1.987624041824023e-09, 5.5716331571752974e-08, 1.5618193071184212e-06, 4.37803329701724e-05, 0.0012272338715773265, 0.034401359545912634, 0.964326006652751]
0.6 [2.48124849643664e-07, 1.3136945477127512e-06, 6.9553427122218854e-06, 3.682499278746801e-05, 0.00019496955792188005, 0.0010322643845619335, 0.00546531351351773, 0.028936048020019006, 0.15320161834191354, 0.811124444027169]
1 [7.801341612780742e-05, 0.00021206245143623275, 0.0005764455082375902, 0.0015669413501390804, 0.004259388198344144, 0.0115782175399118, 0.031472858344688034, 0.08555209892803112, 0.23255471590259755, 0.6321492583604866]
1.5 [0.0012076552782540224, 0.002352191295314716, 0.0045814430569569645, 0.008923432599188675, 0.017380473436496794, 0.03385253976191134, 0.06593574407043169, 0.12842529324824872, 0.25013831539204334, 0.4872029118611537]
10 [0.06120702456008912, 0.0676442235257524, 0.07475842861647011, 0.08262084118795704, 0.09131015090787675, 0.10091332330848407, 0.11152647016690201, 0.12325581142409142, 0.136218738269722, 0.150544988032655]
100 [0.09556032473672185, 0.09652072196694327, 0.09749077134979559, 0.09847056989102544, 0.09946021557130351, 0.10045980735602247, 0.10146944520519384, 0.10248923008344388, 0.10351926397011023, 0.10455964986943994]
10000 [0.09995500600033737, 0.0999650020007291, 0.09997499900077084, 0.09998499700056258, 0.09999499600020427, 0.10000499599979594, 0.10001499699943757, 0.10002499899922916, 0.10003500199927076, 0.10004500599966237]

如果使用transformers的generate函数来生成文本,使do_sample=True时,可以设置Temperature参数(默认值为1),比如使temperature=0.6:

# set seed to reproduce results. Feel free to change the seed though to get different results
from transformers import set_seed
set_seed(42)

# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    **model_inputs,
    max_new_tokens=40,
    do_sample=True,
    top_k=0,
    temperature=0.6
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton, and I was delighted to have the opportunity to play against the best players from the world.

"I'm looking forward to the challenge of playing against some of the best players from the country

  • 可以发现将Temperature降低(例子从1变成0.6)后,因为将分布变得更两极化(增加高概率单词的可能性,降低低概率词的可能性),所以这次的生成内容更连贯了。如果还是用前一节的可视化例子的话,示意图类似如下

    在这里插入图片描述

  • 当设置 T e m p e r a t u r e → 0 Temperature \rightarrow 0 Temperature0时temperature采样也就等同于贪心搜索,比如在LLAMA代码中temperature=0时就是用的贪心搜索

top-k sampling

论文《Hierarchical Neural Story Generation》中提出top-k sampling方法 ,它在每个时间步先选出K个最可能的下一个词,将它们的概率进行缩放调整后在这K个词中进行采样。在GPT-2的论文中生成故事的时候就是使用的top-k采样方法。

将前面的例子中的下一个词从3个扩展到10个来可视化top-k sampling,设k=6,如下图所示:

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使do_sample=Truetop_k>0,就是使用top-k采样方式解码:

# set seed to reproduce results. Feel free to change the seed though to get different results
set_seed(42)

# set top_k to 50
sample_output = model.generate(
    **model_inputs,
    max_new_tokens=40,
    do_sample=True,
    top_k=50
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton more than any other sport. I know more about winning than any other athlete and I would much rather spend my time here. I play much more than most American hockey players and I appreciate the community.
  • top-k 采样的结果看起来更自然

  • top-k采样的问题是因为不能动态调整单词的个数,有时候会像上图右图一样包括一些不太适合的词。

Top-p (nucleus) sampling

top-p采样方法出自论文《The Curious Case of Neural Text Degeneration》, 它在每个时间步,选出累积概率和超过概率p的最小单词集,将它们的概率进行缩放调整后在这个单词集中进行采样。这样得到的单词集的大小会根据下一个词的概率分布动态增加或减少。

比如如果设p=0.92,与前面top-k采样中同样的例子,如下图所示进行采样的候选词集是不一样的

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使do_sample=True0<top_p<1,就是使用top-p采样方式解码:

# set seed to reproduce results. Feel free to change the seed though to get different results
set_seed(42)

# set top_p to 0.92
sample_output = model.generate(
    **model_inputs,
    max_new_tokens=40,
    do_sample=True,
    top_p=0.92,
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton more than any other sport. I know more about winning than any other athlete and coach would agree, it's a lot tougher than most other sports because everyone is playing badminton. So I'm

在LLAMA的生成代码中,top-p的实现如下:

def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.

    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

Contrastive search

待学习总结,可参考huggingface blog。

总结

每种解码方法各有优点,都有适应的场景,可根据实际测试情况选择最适合自己的方法。

相关资源

  1. huggingface transformers关于文本生成的文档:

    • how to generate (本文笔记中的大部分代码和图片来自此文)
    • 生成相关文档的GitHub issue讨论
    • transformers里的解码策略
    • transformers 文本生成相关的类的说明文档
  2. https://nn.labml.ai/sampling/index.html

  3. https://finisky.github.io/illustrated-decoding-strategies/

  4. https://blog.csdn.net/muyao987/article/details/125917234

  5. openai 的 api文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/994898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在vx1000中对目标属性值的函数修改方法

通过在函数中编辑GFX对象属性值&#xff0c;实现Y坐标相反的操作方法 有时需要对目标属性的x 、y坐标做负方向转换&#xff0c;就需要以下方法来实现 return input.ProY*(-1); return input.C0*(-1);

验收测试的内容和流程有哪些?

验收测试 信息化项目验收确认测试内容一般包括&#xff1a;测试(复核 ),资料评审 ,质量鉴定三部分。 (一)验收评测工作主要包括 :文档分析 ,方案制定 ,现场测试 ,问题单提交 ,测试报告 ; (二)验收测试内容主要包括 :检查 "合同 " 或"验收标准 "要求的所…

【Redis】2、Redis持久化和性能管理

Redis 高可用 在web服务器中&#xff0c;高可用是指服务器可以正常访问的时间&#xff0c;衡量的标准是在多长时间内可以提供正常服务&#xff08;99.9%、99.99%、99.999%等等&#xff09;。 但是在Redis语境中&#xff0c;高可用的含义似乎要宽泛一些&#xff0c;除了保证提供…

C++之构造函数列表使用默认值(一百九十一)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

发布订阅机制和点对点机制

【Go项目】25. 在 gin 中引入 WebSocket 和 Hub_哔哩哔哩_bilibili gorilla/websocket: Package gorilla/websocket is a fast, well-tested and widely used WebSocket implementation for Go. (github.com) 1.订阅发布机制 引用上面链接的内容 发布订阅的基本工作原理 在分…

AQS源码剖析,完整流程解读

目录 1 AQS是什么2 AQS加锁流程3 结构4 AQS方法概览5 AQS源码剖析5.1 加锁方法5.2 释放锁5.3 await等待5.4 signal唤醒 1 AQS是什么 ​ AQS即AbstractQueuedSynchronizer缩写&#xff0c;翻译为抽象队列同步器&#xff0c;是一种用来构建锁和同步器的框架。 平时使用较多的Ree…

【C++】常用排序算法

0.前言 1.sort #include <iostream> using namespace std;// 常用排序算法 sort #include<vector> #include<algorithm>//利用仿函数 打印输出 class myPrint { public:void operator()(int val){cout << val << " ";} };//利用普通函…

车载网络测试 - UDS诊断篇 - CANTP常用缩写

CANTP层规范常用缩写 缩写英文全称中文注释BRSbit rate switch比特率开关BSBlockSize块大小CAN controller area network控制器局域网CAN_DL CAN frame data link layer data length in bytesCAN 帧数据链路层数据长度&#xff08;以字节为单位&#xff09;CAN FDcontroller a…

[kingbase运维之奇怪的现象]

#[kingbase运维之奇怪的现象] ##奇怪的现象 某银行数据中心应用反馈&#xff0c;业务接口日志记录了很多执行慢的SQL&#xff0c;出现的时间是随机的&#xff0c;单独在数据库客户端工具执行会很快返回结果。根据之前的经验推断是业务代码传入的参数类型与数据库表结构字段定义…

HDD-FAT32 ZIP-FAT32 HDD-FAT16 ZIP-FAT16 HDD-NTFS

FAT32、FAT16指的是分区格式&#xff0c; FAT16单个文件最大2G FAT32单个文件最大4G NTFS单个文件大于4G HDD是硬盘启动 ZIP是软盘启动 U盘选HDD HDD-NTFS

buuctf crypto 【还原大师】解题记录

1.打开题目就能直接看到密文 2.感觉爆破直接能解&#xff0c;试试爆破&#xff08;参考文章&#xff1a;[buuctf] crypto全解——前84道&#xff08;不建议直接抄flag&#xff09;_buuctf crypto_咸鱼壹号的博客-CSDN博客&#xff09; import hashlib k TASC?O3RJMV?WDJKX?…

建筑模板9层板和7层板的区别

建筑模板是建筑施工过程中不可或缺的一环&#xff0c;而在建筑模板的选择中&#xff0c;常见的有9层板和7层板两种选项。它们在结构、特性和应用方面存在一些区别。下面将详细探讨9层板和7层板之间的区别。 首先&#xff0c;9层板和7层板的名称源自其板材的层数。9层板由9层木片…

Docker容器技术实战-1

1.docker容器 docker就好比传统的货运集装箱 每个虚拟机都有独立的操作系统&#xff0c;互不干扰&#xff0c;在这个虚拟机里可以跑任何东西 如应用 文件系统随便装&#xff0c;通过Guest OS 做了一个完全隔离&#xff0c;所以安全性很好&#xff0c;互不影响 容器 没有虚拟化…

Tomcat配置ssl、jar包

Tomcat配置ssl 部署tomcat服务&#xff0c;项目做到用https访问&#xff0c;使用nginx去做&#xff0c;访问任意一个子网站&#xff0c;都是https 或者 医美项目需要 上传jdk 456 tomcat war包 [nginx-stable] namenginx stable repo baseurlhttp://nginx.org/packages/…

AI绘画:StableDiffusion实操教程-斗破苍穹-云韵-常服(附高清图下载)

前段时间我分享了StableDiffusion的非常完整的教程&#xff1a;“AI绘画&#xff1a;Stable Diffusion 终极宝典&#xff1a;从入门到精通 ” 不久前&#xff0c;我与大家分享了StableDiffusion的全面教程&#xff1a;“AI绘画&#xff1a;Stable Diffusion 终极宝典&#xff…

用Navicat备份Mysql演示系统数据库的时候出:Too Many Connections

今天用Navicat进行数据备份的时候&#xff0c;发现由于数据库连接数目过多导致连接锁定&#xff0c;这种情况在多人协同开发的场景中很常见。当然我这里也因为多个应用使用了数据库连接&#xff0c;所以出现了Too Many Connections。 可能是超过最大连接数了。 1、进入Navicat…

【JAVA-Day03】JDK安装与IntelliJ IDEA安装、配置环境变量

JDK安装与IntelliJ IDEA安装、配置环境变量 一、JDK 版本介绍1.1 JDK 版本选择JDK 8JDK 11JDK 16JDK 171.2 JDK 下载1.3 JDK 安装1.4 配置环境变量1.5 验证 JDK 安装 二、开发利器——IntelliJ IDEA 的安装2.1 IntelliJ IDEA下载2.2 IntelliJ IDEA安装2.3 IntelliJ IDEA启动2.4…

编译原理:编译原理简明教程知识点梳理(应对考试版)

前言 姜老师是一个好老师&#xff0c;编译原理没有过是我的问题&#xff0c;我爱姜老师。 写这篇博客涉及到好多符号&#xff0c;可以参考这篇文章latex数学公式详细教程 因为打字过于麻烦&#xff0c;很多内容用平板的手写截图&#xff0c;还有电脑截图替代&#xff0c;不太习…

【刷题篇】贪心算法(一)

文章目录 分割平衡字符串买卖股票的最佳时机Ⅱ跳跃游戏钱币找零 分割平衡字符串 class Solution { public:int balancedStringSplit(string s) {int lens.size();int cnt0;int balance0;for(int i0;i<len;i){if(s[i]R){balance--;}else{balance;}if(balance0){cnt;}}return …

MyBatis框架中各种参数类型绑定的方式

MyBatis框架中各种参数类型绑定的方式 一、MyBatis参数绑定 MyBatis框架中&#xff0c;通过Mapper接口和Mapper映射文件的方式来操作数据库的时候&#xff0c;可能需要通过Mapper接口中的方法传递相应的参数拼接到SQL语句上面&#xff0c;那么Mybatis将传递的参数映射到对应S…